refactor: cleanup provider base class; cleanup shared responsibilities; cleanup public contract

docs: document provider base class
refactor: cleanup custom provider, it should only deal with `is_custom` model configurations
fix: make sure openrouter provider does not load `is_custom` models
fix: listmodels tool cleanup
This commit is contained in:
Fahad
2025-10-02 12:59:45 +04:00
parent 6ec2033f34
commit 693b84db2b
15 changed files with 509 additions and 751 deletions

View File

@@ -7,7 +7,7 @@ This guide explains how to add support for a new AI model provider to the Zen MC
Each provider:
- Inherits from `ModelProvider` (base class) or `OpenAICompatibleProvider` (for OpenAI-compatible APIs)
- Defines supported models using `ModelCapabilities` objects
- Implements a few core abstract methods
- Implements the minimal abstract hooks (`get_provider_type()` and `generate_content()`)
- Gets registered automatically via environment variables
## Choose Your Implementation Path
@@ -15,11 +15,11 @@ Each provider:
**Option A: Full Provider (`ModelProvider`)**
- For APIs with unique features or custom authentication
- Complete control over API calls and response handling
- Required methods: `generate_content()`, `get_capabilities()`, `validate_model_name()`, `get_provider_type()` (override `count_tokens()` only when you have a provider-accurate tokenizer)
- Implement `generate_content()` and `get_provider_type()`; override `get_all_model_capabilities()` to expose your catalogue and extend `_lookup_capabilities()` / `_ensure_model_allowed()` only when you need registry lookups or custom restriction rules (override `count_tokens()` only when you have a provider-accurate tokenizer)
**Option B: OpenAI-Compatible (`OpenAICompatibleProvider`)**
- For APIs that follow OpenAI's chat completion format
- Only need to define: model configurations, capabilities, and validation
- Supply `MODEL_CAPABILITIES`, override `get_provider_type()`, and optionally adjust configuration (the base class handles alias resolution, validation, and request wiring)
- Inherits all API handling automatically
⚠️ **Important**: If using aliases (like `"gpt"``"gpt-4"`), override `generate_content()` to resolve them before API calls.
@@ -62,8 +62,7 @@ logger = logging.getLogger(__name__)
class ExampleModelProvider(ModelProvider):
"""Example model provider implementation."""
# Define models using ModelCapabilities objects (like Gemini provider)
MODEL_CAPABILITIES = {
"example-large": ModelCapabilities(
provider=ProviderType.EXAMPLE,
@@ -87,51 +86,47 @@ class ExampleModelProvider(ModelProvider):
aliases=["small", "fast"],
),
}
def __init__(self, api_key: str, **kwargs):
super().__init__(api_key, **kwargs)
# Initialize your API client here
def get_capabilities(self, model_name: str) -> ModelCapabilities:
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
return dict(self.MODEL_CAPABILITIES)
def get_provider_type(self) -> ProviderType:
return ProviderType.EXAMPLE
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.MODEL_CAPABILITIES:
raise ValueError(f"Unsupported model: {model_name}")
# Apply restrictions if needed
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
raise ValueError(f"Model '{model_name}' is not allowed.")
return self.MODEL_CAPABILITIES[resolved_name]
def generate_content(self, prompt: str, model_name: str, system_prompt: Optional[str] = None,
temperature: float = 0.7, max_output_tokens: Optional[int] = None, **kwargs) -> ModelResponse:
resolved_name = self._resolve_model_name(model_name)
# Your API call logic here
# response = your_api_client.generate(...)
return ModelResponse(
content="Generated response", # From your API
content="Generated response",
usage={"input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
model_name=resolved_name,
friendly_name="Example",
provider=ProviderType.EXAMPLE,
)
def get_provider_type(self) -> ProviderType:
return ProviderType.EXAMPLE
def validate_model_name(self, model_name: str) -> bool:
resolved_name = self._resolve_model_name(model_name)
return resolved_name in self.MODEL_CAPABILITIES
```
`ModelProvider.count_tokens()` uses a simple 4-characters-per-token estimate so
providers work out of the box. Override the method only when you can call into
the provider's real tokenizer (for example, the OpenAI-compatible base class
already integrates `tiktoken`).
`ModelProvider.get_capabilities()` automatically resolves aliases, enforces the
shared restriction service, and returns the correct `ModelCapabilities`
instance. Override `_lookup_capabilities()` only when you source capabilities
from a registry or remote API. `ModelProvider.count_tokens()` uses a simple
4-characters-per-token estimate so providers work out of the box—override it
only when you can call the provider's real tokenizer (for example, the
OpenAI-compatible base class integrates `tiktoken`).
#### Option B: OpenAI-Compatible Provider (Simplified)
@@ -172,26 +167,16 @@ class ExampleProvider(OpenAICompatibleProvider):
def __init__(self, api_key: str, **kwargs):
kwargs.setdefault("base_url", "https://api.example.com/v1")
super().__init__(api_key, **kwargs)
def get_capabilities(self, model_name: str) -> ModelCapabilities:
resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.MODEL_CAPABILITIES:
raise ValueError(f"Unsupported model: {model_name}")
return self.MODEL_CAPABILITIES[resolved_name]
def get_provider_type(self) -> ProviderType:
return ProviderType.EXAMPLE
def validate_model_name(self, model_name: str) -> bool:
resolved_name = self._resolve_model_name(model_name)
return resolved_name in self.MODEL_CAPABILITIES
def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
# IMPORTANT: Resolve aliases before API call
resolved_model_name = self._resolve_model_name(model_name)
return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs)
```
`OpenAICompatibleProvider` already exposes the declared models via
`MODEL_CAPABILITIES`, resolves aliases through the shared base pipeline, and
enforces restrictions. Most subclasses only need to provide the class metadata
shown above.
### 3. Register Your Provider
Add environment variable mapping in `providers/registry.py`:
@@ -237,15 +222,11 @@ DISABLED_TOOLS=debug,tracer
Create basic tests to verify your implementation:
```python
# Test model validation
provider = ExampleModelProvider("test-key")
assert provider.validate_model_name("large") == True
assert provider.validate_model_name("unknown") == False
# Test capabilities
caps = provider.get_capabilities("large")
assert caps.context_window > 0
assert caps.provider == ProviderType.EXAMPLE
provider = ExampleModelProvider("test-key")
capabilities = provider.get_capabilities("large")
assert capabilities.context_window > 0
assert capabilities.provider == ProviderType.EXAMPLE
```
@@ -259,31 +240,19 @@ When a user requests a model, providers are checked in priority order:
3. **OpenRouter** - catch-all for everything else
### Model Validation
Your `validate_model_name()` should **only** return `True` for models you explicitly support:
```python
def validate_model_name(self, model_name: str) -> bool:
resolved_name = self._resolve_model_name(model_name)
return resolved_name in self.MODEL_CAPABILITIES # Be specific!
```
`ModelProvider.validate_model_name()` delegates to `get_capabilities()` so most
providers can rely on the shared implementation. Override it only when you need
to opt out of that pipeline—for example, `CustomProvider` declines OpenRouter
models so they fall through to the dedicated OpenRouter provider.
### Model Aliases
The base class handles alias resolution automatically via the `aliases` field in `ModelCapabilities`.
Aliases declared on `ModelCapabilities` are applied automatically via
`_resolve_model_name()`, and both the validation and request flows call it
before touching your SDK. Override `generate_content()` only when your provider
needs additional alias handling beyond the shared behaviour.
## Important Notes
### Alias Resolution in OpenAI-Compatible Providers
If using `OpenAICompatibleProvider` with aliases, **you must override `generate_content()`** to resolve aliases before API calls:
```python
def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
# Resolve alias before API call
resolved_model_name = self._resolve_model_name(model_name)
return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs)
```
Without this, API calls with aliases like `"large"` will fail because your API doesn't recognize the alias.
## Best Practices
- **Be specific in model validation** - only accept models you actually support

View File

@@ -43,128 +43,37 @@ class ModelProvider(ABC):
self.api_key = api_key
self.config = kwargs
@abstractmethod
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific model."""
pass
@abstractmethod
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using the model.
Args:
prompt: User prompt to send to the model
model_name: Name of the model to use
system_prompt: Optional system prompt for model behavior
temperature: Sampling temperature (0-2)
max_output_tokens: Maximum tokens to generate
**kwargs: Provider-specific parameters
Returns:
ModelResponse with generated content and metadata
"""
pass
def count_tokens(self, text: str, model_name: str) -> int:
"""Estimate token usage for a piece of text.
Providers can rely on this shared implementation or override it when
they expose a more accurate tokenizer. This default uses a simple
character-based heuristic so it works even without provider-specific
tooling.
"""
resolved_model = self._resolve_model_name(model_name)
if not text:
return 0
# Rough estimation: ~4 characters per token for English text
estimated = max(1, len(text) // 4)
logger.debug("Estimating %s tokens for model %s via character heuristic", estimated, resolved_model)
return estimated
# ------------------------------------------------------------------
# Provider identity & capability surface
# ------------------------------------------------------------------
@abstractmethod
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
pass
"""Return the concrete provider identity."""
@abstractmethod
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported by this provider."""
pass
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Resolve capability metadata for a model name.
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
"""Validate model parameters against capabilities.
Raises:
ValueError: If parameters are invalid
This centralises the alias resolution → lookup → restriction check
pipeline so providers only override the pieces they genuinely need to
customise. Subclasses usually only override ``_lookup_capabilities`` to
integrate a registry or dynamic source, or ``_finalise_capabilities`` to
tweak the returned object.
"""
capabilities = self.get_capabilities(model_name)
# Validate temperature using constraint
if not capabilities.temperature_constraint.validate(temperature):
constraint_desc = capabilities.temperature_constraint.get_description()
raise ValueError(f"Temperature {temperature} is invalid for model {model_name}. {constraint_desc}")
resolved_name = self._resolve_model_name(model_name)
capabilities = self._lookup_capabilities(resolved_name, model_name)
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
"""Get model configurations for this provider.
if capabilities is None:
self._raise_unsupported_model(model_name)
This is a hook method that subclasses can override to provide
their model configurations from different sources.
self._ensure_model_allowed(capabilities, resolved_name, model_name)
return self._finalise_capabilities(capabilities, resolved_name, model_name)
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
"""Return the provider's statically declared model capabilities."""
Returns:
Dictionary mapping model names to their ModelCapabilities objects
"""
model_map = getattr(self, "MODEL_CAPABILITIES", None)
if isinstance(model_map, dict) and model_map:
return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)}
return {}
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model shorthand to full name.
This implementation uses the hook methods to support different
model configuration sources.
Args:
model_name: Model name that may be an alias
Returns:
Resolved model name
"""
# Get model configurations from the hook method
model_configs = self.get_model_configurations()
# First check if it's already a base model name (case-sensitive exact match)
if model_name in model_configs:
return model_name
# Check case-insensitively for both base models and aliases
model_name_lower = model_name.lower()
# Check base model names case-insensitively
for base_model in model_configs:
if base_model.lower() == model_name_lower:
return base_model
# Check aliases from the model configurations
alias_map = ModelCapabilities.collect_aliases(model_configs)
for base_model, aliases in alias_map.items():
if any(alias.lower() == model_name_lower for alias in aliases):
return base_model
# If not found, return as-is
return model_name
def list_models(
self,
*,
@@ -175,7 +84,7 @@ class ModelProvider(ABC):
) -> list[str]:
"""Return formatted model names supported by this provider."""
model_configs = self.get_model_configurations()
model_configs = self.get_all_model_capabilities()
if not model_configs:
return []
@@ -202,36 +111,155 @@ class ModelProvider(ABC):
unique=unique,
)
def close(self):
"""Clean up any resources held by the provider.
# ------------------------------------------------------------------
# Request execution
# ------------------------------------------------------------------
@abstractmethod
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using the model."""
def count_tokens(self, text: str, model_name: str) -> int:
"""Estimate token usage for a piece of text."""
resolved_model = self._resolve_model_name(model_name)
if not text:
return 0
estimated = max(1, len(text) // 4)
logger.debug("Estimating %s tokens for model %s via character heuristic", estimated, resolved_model)
return estimated
def close(self) -> None:
"""Clean up any resources held by the provider."""
Default implementation does nothing.
Subclasses should override if they hold resources that need cleanup.
"""
# Base implementation: no resources to clean up
return
# ------------------------------------------------------------------
# Validation hooks
# ------------------------------------------------------------------
def validate_model_name(self, model_name: str) -> bool:
"""Return ``True`` when the model resolves to an allowed capability."""
try:
self.get_capabilities(model_name)
except ValueError:
return False
return True
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
"""Validate model parameters against capabilities."""
capabilities = self.get_capabilities(model_name)
if not capabilities.temperature_constraint.validate(temperature):
constraint_desc = capabilities.temperature_constraint.get_description()
raise ValueError(f"Temperature {temperature} is invalid for model {model_name}. {constraint_desc}")
# ------------------------------------------------------------------
# Preference / registry hooks
# ------------------------------------------------------------------
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
"""Get the preferred model from this provider for a given category.
"""Get the preferred model from this provider for a given category."""
Args:
category: The tool category requiring a model
allowed_models: Pre-filtered list of model names that are allowed by restrictions
Returns:
Model name if this provider has a preference, None otherwise
"""
# Default implementation - providers can override with specific logic
return None
def get_model_registry(self) -> Optional[dict[str, Any]]:
"""Get the model registry for providers that maintain one.
"""Return the model registry backing this provider, if any."""
This is a hook method for providers like CustomProvider that maintain
a dynamic model registry.
return None
# ------------------------------------------------------------------
# Capability lookup pipeline
# ------------------------------------------------------------------
def _lookup_capabilities(
self,
canonical_name: str,
requested_name: Optional[str] = None,
) -> Optional[ModelCapabilities]:
"""Return ``ModelCapabilities`` for the canonical model name."""
return self.get_all_model_capabilities().get(canonical_name)
def _ensure_model_allowed(
self,
capabilities: ModelCapabilities,
canonical_name: str,
requested_name: str,
) -> None:
"""Raise ``ValueError`` if the model violates restriction policy."""
try:
from utils.model_restrictions import get_restriction_service
except Exception: # pragma: no cover - only triggered if service import breaks
return
restriction_service = get_restriction_service()
if not restriction_service:
return
if restriction_service.is_allowed(self.get_provider_type(), canonical_name, requested_name):
return
raise ValueError(
f"{self.get_provider_type().value} model '{canonical_name}' is not allowed by restriction policy."
)
def _finalise_capabilities(
self,
capabilities: ModelCapabilities,
canonical_name: str,
requested_name: str,
) -> ModelCapabilities:
"""Allow subclasses to adjust capability metadata before returning."""
return capabilities
def _raise_unsupported_model(self, model_name: str) -> None:
"""Raise the canonical unsupported-model error."""
raise ValueError(f"Unsupported model '{model_name}' for provider {self.get_provider_type().value}.")
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model shorthand to full name.
This implementation uses the hook methods to support different
model configuration sources.
Args:
model_name: Model name that may be an alias
Returns:
Model registry dict or None if not applicable
Resolved model name
"""
# Default implementation - most providers don't have a registry
return None
# Get model configurations from the hook method
model_configs = self.get_all_model_capabilities()
# First check if it's already a base model name (case-sensitive exact match)
if model_name in model_configs:
return model_name
# Check case-insensitively for both base models and aliases
model_name_lower = model_name.lower()
# Check base model names case-insensitively
for base_model in model_configs:
if base_model.lower() == model_name_lower:
return base_model
# Check aliases from the model configurations
alias_map = ModelCapabilities.collect_aliases(model_configs)
for base_model, aliases in alias_map.items():
if any(alias.lower() == model_name_lower for alias in aliases):
return base_model
# If not found, return as-is
return model_name

View File

@@ -83,117 +83,69 @@ class CustomProvider(OpenAICompatibleProvider):
aliases = self._registry.list_aliases()
logging.info(f"Custom provider loaded {len(models)} models with {len(aliases)} aliases")
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model aliases to actual model names.
# ------------------------------------------------------------------
# Capability surface
# ------------------------------------------------------------------
def _lookup_capabilities(
self,
canonical_name: str,
requested_name: Optional[str] = None,
) -> Optional[ModelCapabilities]:
"""Return custom capabilities from the registry or generic defaults."""
For Ollama-style models, strips version tags (e.g., 'llama3.2:latest' -> 'llama3.2')
since the base model name is what's typically used in API calls.
Args:
model_name: Input model name or alias
Returns:
Resolved model name with version tags stripped if applicable
"""
# First, try to resolve through registry as-is
config = self._registry.resolve(model_name)
if config:
if config.model_name != model_name:
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
return config.model_name
else:
# If not found in registry, handle version tags for local models
# Strip version tags (anything after ':') for Ollama-style models
if ":" in model_name:
base_model = model_name.split(":")[0]
logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'")
# Try to resolve the base model through registry
base_config = self._registry.resolve(base_model)
if base_config:
logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
return base_config.model_name
else:
return base_model
else:
# If not found in registry and no version tag, return as-is
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
return model_name
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a custom model.
Args:
model_name: Name of the model (or alias)
Returns:
ModelCapabilities from registry or generic defaults
"""
# Try to get from registry first
capabilities = self._registry.get_capabilities(model_name)
builtin = super()._lookup_capabilities(canonical_name, requested_name)
if builtin is not None:
return builtin
capabilities = self._registry.get_capabilities(canonical_name)
if capabilities:
# Check if this is an OpenRouter model and apply restrictions
config = self._registry.resolve(model_name)
if config and not config.is_custom:
# This is an OpenRouter model, check restrictions
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENROUTER, config.model_name, model_name):
raise ValueError(f"OpenRouter model '{model_name}' is not allowed by restriction policy.")
# Update provider type to OPENROUTER for OpenRouter models
capabilities.provider = ProviderType.OPENROUTER
else:
# Update provider type to CUSTOM for local custom models
config = self._registry.resolve(canonical_name)
if config and getattr(config, "is_custom", False):
capabilities.provider = ProviderType.CUSTOM
return capabilities
else:
# Resolve any potential aliases and create generic capabilities
resolved_name = self._resolve_model_name(model_name)
return capabilities
# Non-custom models should fall through so OpenRouter handles them
return None
logging.debug(
f"Using generic capabilities for '{resolved_name}' via Custom API. "
"Consider adding to custom_models.json for specific capabilities."
)
logging.debug(
f"Using generic capabilities for '{canonical_name}' via Custom API. "
"Consider adding to custom_models.json for specific capabilities."
)
# Infer temperature behaviour for generic capability fallback
supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings(
resolved_name
)
supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings(
canonical_name
)
logging.warning(
f"Model '{resolved_name}' not found in custom_models.json. Using generic capabilities with inferred settings. "
f"Temperature support: {supports_temperature} ({temperature_reason}). "
"For better accuracy, add this model to your custom_models.json configuration."
)
logging.warning(
f"Model '{canonical_name}' not found in custom_models.json. Using generic capabilities with inferred settings. "
f"Temperature support: {supports_temperature} ({temperature_reason}). "
"For better accuracy, add this model to your custom_models.json configuration."
)
# Create generic capabilities with inferred defaults
capabilities = ModelCapabilities(
provider=ProviderType.CUSTOM,
model_name=resolved_name,
friendly_name=f"{self.FRIENDLY_NAME} ({resolved_name})",
context_window=32_768, # Conservative default
max_output_tokens=32_768, # Conservative default max output
supports_extended_thinking=False, # Most custom models don't support this
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # Conservative default
supports_temperature=supports_temperature,
temperature_constraint=temperature_constraint,
)
# Mark as generic for validation purposes
capabilities._is_generic = True
return capabilities
generic = ModelCapabilities(
provider=ProviderType.CUSTOM,
model_name=canonical_name,
friendly_name=f"{self.FRIENDLY_NAME} ({canonical_name})",
context_window=32_768,
max_output_tokens=32_768,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False,
supports_temperature=supports_temperature,
temperature_constraint=temperature_constraint,
)
generic._is_generic = True
return generic
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
"""Identify this provider for restriction and logging logic."""
return ProviderType.CUSTOM
# ------------------------------------------------------------------
# Validation
# ------------------------------------------------------------------
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is allowed.
@@ -206,49 +158,41 @@ class CustomProvider(OpenAICompatibleProvider):
Returns:
True if model is intended for custom/local endpoint
"""
# logging.debug(f"Custom provider validating model: '{model_name}'")
if super().validate_model_name(model_name):
return True
# Try to resolve through registry first
config = self._registry.resolve(model_name)
if config:
model_id = config.model_name
# Use explicit is_custom flag for clean validation
if config.is_custom:
logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' validated via registry")
return True
else:
# This is a cloud/OpenRouter model - CustomProvider should NOT handle these
# Let OpenRouter provider handle them instead
# logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' not custom (defer to OpenRouter)")
return False
if config and not getattr(config, "is_custom", False):
return False
# Handle version tags for unknown models (e.g., "my-model:latest")
clean_model_name = model_name
if ":" in model_name:
clean_model_name = model_name.split(":")[0]
clean_model_name = model_name.split(":", 1)[0]
logging.debug(f"Stripped version tag from '{model_name}' -> '{clean_model_name}'")
# Try to resolve the clean name
if super().validate_model_name(clean_model_name):
return True
config = self._registry.resolve(clean_model_name)
if config:
return self.validate_model_name(clean_model_name) # Recursively validate clean name
if config and not getattr(config, "is_custom", False):
return False
# For unknown models (not in registry), only accept if they look like local models
# This maintains backward compatibility for custom models not yet in the registry
# Accept models with explicit local indicators in the name
if any(indicator in clean_model_name.lower() for indicator in ["local", "ollama", "vllm", "lmstudio"]):
lowered = clean_model_name.lower()
if any(indicator in lowered for indicator in ["local", "ollama", "vllm", "lmstudio"]):
logging.debug(f"Model '{clean_model_name}' validated via local indicators")
return True
# Accept simple model names without vendor prefix (likely local/custom models)
if "/" not in clean_model_name:
logging.debug(f"Model '{clean_model_name}' validated as potential local model (no vendor prefix)")
return True
# Reject everything else (likely cloud models not in registry)
logging.debug(f"Model '{model_name}' rejected by custom provider (appears to be cloud model)")
return False
# ------------------------------------------------------------------
# Request execution
# ------------------------------------------------------------------
def generate_content(
self,
prompt: str,
@@ -284,25 +228,41 @@ class CustomProvider(OpenAICompatibleProvider):
**kwargs,
)
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
"""Get model configurations from the registry.
# ------------------------------------------------------------------
# Registry helpers
# ------------------------------------------------------------------
For CustomProvider, we convert registry configurations to ModelCapabilities objects.
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve registry aliases and strip version tags for local models."""
Returns:
Dictionary mapping model names to their ModelCapabilities objects
"""
config = self._registry.resolve(model_name)
if config:
if config.model_name != model_name:
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
return config.model_name
configs = {}
if ":" in model_name:
base_model = model_name.split(":")[0]
logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'")
if self._registry:
# Get all models from registry
for model_name in self._registry.list_models():
# Only include custom models that this provider validates
if self.validate_model_name(model_name):
config = self._registry.resolve(model_name)
if config and config.is_custom:
# Use ModelCapabilities directly from registry
configs[model_name] = config
base_config = self._registry.resolve(base_model)
if base_config:
logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
return base_config.model_name
return base_model
return configs
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
return model_name
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
"""Expose registry capabilities for models marked as custom."""
if not self._registry:
return {}
capabilities: dict[str, ModelCapabilities] = {}
for model_name in self._registry.list_models():
config = self._registry.resolve(model_name)
if config and getattr(config, "is_custom", False):
capabilities[model_name] = config
return capabilities

View File

@@ -261,68 +261,10 @@ class DIALModelProvider(OpenAICompatibleProvider):
logger.info(f"Initialized DIAL provider with host: {dial_host} and api-version: {self.api_version}")
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific model.
Args:
model_name: Name of the model (can be shorthand)
Returns:
ModelCapabilities object
Raises:
ValueError: If model is not supported or not allowed
"""
resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.MODEL_CAPABILITIES:
raise ValueError(f"Unsupported DIAL model: {model_name}")
# Check restrictions
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")
# Return the ModelCapabilities object directly from MODEL_CAPABILITIES
return self.MODEL_CAPABILITIES[resolved_name]
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.DIAL
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported.
Args:
model_name: Model name to validate
Returns:
True if model is supported and allowed, False otherwise
"""
resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.MODEL_CAPABILITIES:
return False
# Check against base class allowed_models if configured
if self.allowed_models is not None:
# Check both original and resolved names (case-insensitive)
if model_name.lower() not in self.allowed_models and resolved_name.lower() not in self.allowed_models:
logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' not in allowed_models list")
return False
# Also check restrictions via ModelRestrictionService
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' blocked by restrictions")
return False
return True
def _get_deployment_client(self, deployment: str):
"""Get or create a cached client for a specific deployment.
@@ -504,7 +446,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
)
def close(self):
def close(self) -> None:
"""Clean up HTTP clients when provider is closed."""
logger.info("Closing DIAL provider HTTP clients...")

View File

@@ -131,6 +131,19 @@ class GeminiModelProvider(ModelProvider):
self._token_counters = {} # Cache for token counting
self._base_url = kwargs.get("base_url", None) # Optional custom endpoint
# ------------------------------------------------------------------
# Capability surface
# ------------------------------------------------------------------
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
"""Return statically defined Gemini capabilities."""
return dict(self.MODEL_CAPABILITIES)
# ------------------------------------------------------------------
# Client access
# ------------------------------------------------------------------
@property
def client(self):
"""Lazy initialization of Gemini client."""
@@ -146,25 +159,9 @@ class GeminiModelProvider(ModelProvider):
self._client = genai.Client(api_key=self.api_key)
return self._client
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific Gemini model."""
# Resolve shorthand
resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.MODEL_CAPABILITIES:
raise ValueError(f"Unsupported Gemini model: {model_name}")
# Check if model is allowed by restrictions
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
# IMPORTANT: Parameter order is (provider_type, model_name, original_name)
# resolved_name is the canonical model name, model_name is the user input
if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
raise ValueError(f"Gemini model '{resolved_name}' is not allowed by restriction policy.")
# Return the ModelCapabilities object directly from MODEL_CAPABILITIES
return self.MODEL_CAPABILITIES[resolved_name]
# ------------------------------------------------------------------
# Request execution
# ------------------------------------------------------------------
def generate_content(
self,
@@ -365,26 +362,6 @@ class GeminiModelProvider(ModelProvider):
"""Get the provider type."""
return ProviderType.GOOGLE
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported and allowed."""
resolved_name = self._resolve_model_name(model_name)
# First check if model is supported
if resolved_name not in self.MODEL_CAPABILITIES:
return False
# Then check if model is allowed by restrictions
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
# IMPORTANT: Parameter order is (provider_type, model_name, original_name)
# resolved_name is the canonical model name, model_name is the user input
if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
logger.debug(f"Gemini model '{model_name}' -> '{resolved_name}' blocked by restrictions")
return False
return True
def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int:
"""Get actual thinking token budget for a model and thinking mode."""
resolved_name = self._resolve_model_name(model_name)

View File

@@ -5,7 +5,6 @@ import ipaddress
import logging
import os
import time
from abc import abstractmethod
from typing import Optional
from urllib.parse import urlparse
@@ -61,6 +60,33 @@ class OpenAICompatibleProvider(ModelProvider):
"This may be insecure. Consider setting an API key for authentication."
)
def _ensure_model_allowed(
self,
capabilities: ModelCapabilities,
canonical_name: str,
requested_name: str,
) -> None:
"""Respect provider-specific allowlists before default restriction checks."""
super()._ensure_model_allowed(capabilities, canonical_name, requested_name)
if self.allowed_models is not None:
requested = requested_name.lower()
canonical = canonical_name.lower()
if requested not in self.allowed_models and canonical not in self.allowed_models:
raise ValueError(
f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
)
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
"""Return statically declared capabilities for OpenAI-compatible providers."""
model_map = getattr(self, "MODEL_CAPABILITIES", None)
if isinstance(model_map, dict):
return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)}
return {}
def _parse_allowed_models(self) -> Optional[set[str]]:
"""Parse allowed models from environment variable.
@@ -686,30 +712,6 @@ class OpenAICompatibleProvider(ModelProvider):
return super().count_tokens(text, model_name)
@abstractmethod
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific model.
Must be implemented by subclasses.
"""
pass
@abstractmethod
def get_provider_type(self) -> ProviderType:
"""Get the provider type.
Must be implemented by subclasses.
"""
pass
@abstractmethod
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported.
Must be implemented by subclasses.
"""
pass
def _is_error_retryable(self, error: Exception) -> bool:
"""Determine if an error should be retried based on structured error codes.

View File

@@ -174,106 +174,61 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
kwargs.setdefault("base_url", "https://api.openai.com/v1")
super().__init__(api_key, **kwargs)
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific OpenAI model."""
# First check if it's a key in MODEL_CAPABILITIES
if model_name in self.MODEL_CAPABILITIES:
self._check_model_restrictions(model_name, model_name)
return self.MODEL_CAPABILITIES[model_name]
# ------------------------------------------------------------------
# Capability surface
# ------------------------------------------------------------------
# Try resolving as alias
resolved_name = self._resolve_model_name(model_name)
def _lookup_capabilities(
self,
canonical_name: str,
requested_name: Optional[str] = None,
) -> Optional[ModelCapabilities]:
"""Look up OpenAI capabilities from built-ins or the custom registry."""
# Check if resolved name is a key
if resolved_name in self.MODEL_CAPABILITIES:
self._check_model_restrictions(resolved_name, model_name)
return self.MODEL_CAPABILITIES[resolved_name]
builtin = super()._lookup_capabilities(canonical_name, requested_name)
if builtin is not None:
return builtin
# Finally check if resolved name matches any API model name
for key, capabilities in self.MODEL_CAPABILITIES.items():
if resolved_name == capabilities.model_name:
self._check_model_restrictions(key, model_name)
return capabilities
# Check custom models registry for user-configured OpenAI models
try:
from .openrouter_registry import OpenRouterModelRegistry
registry = OpenRouterModelRegistry()
config = registry.get_model_config(resolved_name)
config = registry.get_model_config(canonical_name)
if config and config.provider == ProviderType.OPENAI:
self._check_model_restrictions(config.model_name, model_name)
# Update provider type to ensure consistency
config.provider = ProviderType.OPENAI
return config
except Exception as e:
# Log but don't fail - registry might not be available
logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}")
except Exception as exc: # pragma: no cover - registry failures are non-critical
logger.debug(f"Could not resolve custom OpenAI model '{canonical_name}': {exc}")
return None
def _finalise_capabilities(
self,
capabilities: ModelCapabilities,
canonical_name: str,
requested_name: str,
) -> ModelCapabilities:
"""Ensure registry-sourced models report the correct provider type."""
if capabilities.provider != ProviderType.OPENAI:
capabilities.provider = ProviderType.OPENAI
return capabilities
def _raise_unsupported_model(self, model_name: str) -> None:
raise ValueError(f"Unsupported OpenAI model: {model_name}")
def _check_model_restrictions(self, provider_model_name: str, user_model_name: str) -> None:
"""Check if a model is allowed by restriction policy.
Args:
provider_model_name: The model name used by the provider
user_model_name: The model name requested by the user
Raises:
ValueError: If the model is not allowed by restriction policy
"""
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENAI, provider_model_name, user_model_name):
raise ValueError(f"OpenAI model '{user_model_name}' is not allowed by restriction policy.")
# ------------------------------------------------------------------
# Provider identity
# ------------------------------------------------------------------
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.OPENAI
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported and allowed."""
resolved_name = self._resolve_model_name(model_name)
# First, determine which model name to check against restrictions.
model_to_check = None
is_custom_model = False
if resolved_name in self.MODEL_CAPABILITIES:
model_to_check = resolved_name
else:
# If not a built-in model, check the custom models registry.
try:
from .openrouter_registry import OpenRouterModelRegistry
registry = OpenRouterModelRegistry()
config = registry.get_model_config(resolved_name)
if config and config.provider == ProviderType.OPENAI:
model_to_check = config.model_name
is_custom_model = True
except Exception as e:
# Log but don't fail - registry might not be available.
logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}")
# If no model was found (neither built-in nor custom), it's invalid.
if not model_to_check:
return False
# Now, perform the restriction check once.
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENAI, model_to_check, model_name):
model_type = "custom " if is_custom_model else ""
logger.debug(f"OpenAI {model_type}model '{model_name}' -> '{resolved_name}' blocked by restrictions")
return False
return True
# ------------------------------------------------------------------
# Request execution
# ------------------------------------------------------------------
def generate_content(
self,
@@ -298,6 +253,10 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
**kwargs,
)
# ------------------------------------------------------------------
# Provider preferences
# ------------------------------------------------------------------
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
"""Get OpenAI's preferred model for a given category from allowed models.

View File

@@ -61,108 +61,52 @@ class OpenRouterProvider(OpenAICompatibleProvider):
aliases = self._registry.list_aliases()
logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases")
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model aliases to OpenRouter model names.
# ------------------------------------------------------------------
# Capability surface
# ------------------------------------------------------------------
Args:
model_name: Input model name or alias
Returns:
Resolved OpenRouter model name
"""
# Try to resolve through registry
config = self._registry.resolve(model_name)
if config:
if config.model_name != model_name:
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
return config.model_name
else:
# If not found in registry, return as-is
# This allows using models not in our config file
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
return model_name
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a model.
Args:
model_name: Name of the model (or alias)
Returns:
ModelCapabilities from registry or generic defaults
"""
# Try to get from registry first
capabilities = self._registry.get_capabilities(model_name)
def _lookup_capabilities(
self,
canonical_name: str,
requested_name: Optional[str] = None,
) -> Optional[ModelCapabilities]:
"""Fetch OpenRouter capabilities from the registry or build a generic fallback."""
capabilities = self._registry.get_capabilities(canonical_name)
if capabilities:
return capabilities
else:
# Resolve any potential aliases and create generic capabilities
resolved_name = self._resolve_model_name(model_name)
logging.debug(
f"Using generic capabilities for '{resolved_name}' via OpenRouter. "
"Consider adding to custom_models.json for specific capabilities."
)
logging.debug(
f"Using generic capabilities for '{canonical_name}' via OpenRouter. "
"Consider adding to custom_models.json for specific capabilities."
)
# Create generic capabilities with conservative defaults
capabilities = ModelCapabilities(
provider=ProviderType.OPENROUTER,
model_name=resolved_name,
friendly_name=self.FRIENDLY_NAME,
context_window=32_768, # Conservative default context window
max_output_tokens=32_768,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False,
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
)
generic = ModelCapabilities(
provider=ProviderType.OPENROUTER,
model_name=canonical_name,
friendly_name=self.FRIENDLY_NAME,
context_window=32_768,
max_output_tokens=32_768,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False,
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
)
generic._is_generic = True
return generic
# Mark as generic for validation purposes
capabilities._is_generic = True
return capabilities
# ------------------------------------------------------------------
# Provider identity
# ------------------------------------------------------------------
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
"""Identify this provider for restrictions and logging."""
return ProviderType.OPENROUTER
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is allowed.
As the catch-all provider, OpenRouter accepts any model name that wasn't
handled by higher-priority providers. OpenRouter will validate based on
the API key's permissions and local restrictions.
Args:
model_name: Model name to validate
Returns:
True if model is allowed, False if restricted
"""
# Check model restrictions if configured
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if restriction_service:
# Check if model name itself is allowed
if restriction_service.is_allowed(self.get_provider_type(), model_name):
return True
# Also check aliases - model_name might be an alias
model_config = self._registry.resolve(model_name)
if model_config and model_config.aliases:
for alias in model_config.aliases:
if restriction_service.is_allowed(self.get_provider_type(), alias):
return True
# If restrictions are configured and model/alias not in allowed list, reject
return False
# No restrictions configured - accept any model name as the fallback provider
return True
# ------------------------------------------------------------------
# Request execution
# ------------------------------------------------------------------
def generate_content(
self,
@@ -204,6 +148,10 @@ class OpenRouterProvider(OpenAICompatibleProvider):
**kwargs,
)
# ------------------------------------------------------------------
# Registry helpers
# ------------------------------------------------------------------
def list_models(
self,
*,
@@ -227,6 +175,12 @@ class OpenRouterProvider(OpenAICompatibleProvider):
if not config:
continue
# Custom models belong to CustomProvider; skip them here so the two
# providers don't race over the same registrations (important for tests
# that stub the registry with minimal objects lacking attrs).
if hasattr(config, "is_custom") and config.is_custom is True:
continue
if restriction_service:
allowed = restriction_service.is_allowed(self.get_provider_type(), model_name)
@@ -255,24 +209,37 @@ class OpenRouterProvider(OpenAICompatibleProvider):
unique=unique,
)
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
"""Get model configurations from the registry.
# ------------------------------------------------------------------
# Registry helpers
# ------------------------------------------------------------------
For OpenRouter, we convert registry configurations to ModelCapabilities objects.
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve aliases defined in the OpenRouter registry."""
Returns:
Dictionary mapping model names to their ModelCapabilities objects
"""
configs = {}
config = self._registry.resolve(model_name)
if config:
if config.model_name != model_name:
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
return config.model_name
if self._registry:
# Get all models from registry
for model_name in self._registry.list_models():
# Only include models that this provider validates
if self.validate_model_name(model_name):
config = self._registry.resolve(model_name)
if config and not config.is_custom: # Only OpenRouter models, not custom ones
# Use ModelCapabilities directly from registry
configs[model_name] = config
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
return model_name
return configs
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
"""Expose registry-backed OpenRouter capabilities."""
if not self._registry:
return {}
capabilities: dict[str, ModelCapabilities] = {}
for model_name in self._registry.list_models():
config = self._registry.resolve(model_name)
if not config:
continue
# See note in list_models: respect the CustomProvider boundary.
if hasattr(config, "is_custom") and config.is_custom is True:
continue
capabilities[model_name] = config
return capabilities

View File

@@ -64,6 +64,8 @@ class ModelProviderRegistry:
"""
instance = cls()
instance._providers[provider_type] = provider_class
# Invalidate any cached instance so subsequent lookups use the new registration
instance._initialized_providers.pop(provider_type, None)
@classmethod
def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]:

View File

@@ -85,46 +85,10 @@ class XAIModelProvider(OpenAICompatibleProvider):
kwargs.setdefault("base_url", "https://api.x.ai/v1")
super().__init__(api_key, **kwargs)
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific X.AI model."""
# Resolve shorthand
resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.MODEL_CAPABILITIES:
raise ValueError(f"Unsupported X.AI model: {model_name}")
# Check if model is allowed by restrictions
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
raise ValueError(f"X.AI model '{model_name}' is not allowed by restriction policy.")
# Return the ModelCapabilities object directly from MODEL_CAPABILITIES
return self.MODEL_CAPABILITIES[resolved_name]
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.XAI
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported and allowed."""
resolved_name = self._resolve_model_name(model_name)
# First check if model is supported
if resolved_name not in self.MODEL_CAPABILITIES:
return False
# Then check if model is allowed by restrictions
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
logger.debug(f"X.AI model '{model_name}' -> '{resolved_name}' blocked by restrictions")
return False
return True
def generate_content(
self,
prompt: str,

View File

@@ -54,11 +54,9 @@ class TestCustomProvider:
provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")
# Test with a model that should be in the registry (OpenRouter model)
capabilities = provider.get_capabilities("o3") # o3 is an OpenRouter model
assert capabilities.provider == ProviderType.OPENROUTER # o3 is an OpenRouter model (is_custom=false)
assert capabilities.context_window > 0
# OpenRouter-backed models should be handled by the OpenRouter provider
with pytest.raises(ValueError):
provider.get_capabilities("o3")
# Test with a custom model (is_custom=true)
capabilities = provider.get_capabilities("local-llama")
@@ -168,7 +166,13 @@ class TestCustomProviderRegistration:
return CustomProvider(api_key="", base_url="http://localhost:11434/v1")
with patch.dict(
os.environ, {"OPENROUTER_API_KEY": "test-openrouter-key", "CUSTOM_API_PLACEHOLDER": "configured"}
os.environ,
{
"OPENROUTER_API_KEY": "test-openrouter-key",
"CUSTOM_API_PLACEHOLDER": "configured",
"OPENROUTER_ALLOWED_MODELS": "llama,anthropic/claude-opus-4.1",
},
clear=True,
):
# Register both providers
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
@@ -195,18 +199,22 @@ class TestCustomProviderRegistration:
return CustomProvider(api_key="", base_url="http://localhost:11434/v1")
with patch.dict(
os.environ, {"OPENROUTER_API_KEY": "test-openrouter-key", "CUSTOM_API_PLACEHOLDER": "configured"}
os.environ,
{
"OPENROUTER_API_KEY": "test-openrouter-key",
"CUSTOM_API_PLACEHOLDER": "configured",
"OPENROUTER_ALLOWED_MODELS": "",
},
clear=True,
):
# Register OpenRouter first (higher priority)
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, custom_provider_factory)
import utils.model_restrictions
# Test model resolution - OpenRouter should win for shared aliases
provider_for_model = ModelProviderRegistry.get_provider_for_model("llama")
utils.model_restrictions._restriction_service = None
custom_provider = custom_provider_factory()
openrouter_provider = OpenRouterProvider(api_key="test-openrouter-key")
# OpenRouter should be selected first due to registration order
assert provider_for_model is not None
# The exact provider type depends on which validates the model first
assert not custom_provider.validate_model_name("llama")
assert openrouter_provider.validate_model_name("llama")
class TestConfigureProvidersFunction:

View File

@@ -121,7 +121,7 @@ class TestDIALProvider:
"""Test that get_capabilities raises for invalid models."""
provider = DIALModelProvider("test-key")
with pytest.raises(ValueError, match="Unsupported DIAL model"):
with pytest.raises(ValueError, match="Unsupported model 'invalid-model' for provider dial"):
provider.get_capabilities("invalid-model")
@patch("utils.model_restrictions.get_restriction_service")

View File

@@ -356,15 +356,13 @@ class TestCustomProviderOpenRouterRestrictions:
provider = CustomProvider(base_url="http://test.com/v1")
# For OpenRouter models, get_capabilities should still work but mark them as OPENROUTER
# This tests the capabilities lookup, not validation
capabilities = provider.get_capabilities("opus")
assert capabilities.provider == ProviderType.OPENROUTER
# For OpenRouter models, CustomProvider should defer by raising
with pytest.raises(ValueError):
provider.get_capabilities("opus")
# Should raise for disallowed OpenRouter model
with pytest.raises(ValueError) as exc_info:
# Should raise for disallowed OpenRouter model (still defers)
with pytest.raises(ValueError):
provider.get_capabilities("haiku")
assert "not allowed by restriction policy" in str(exc_info.value)
# Should still work for custom models (is_custom=true)
capabilities = provider.get_capabilities("local-llama")

View File

@@ -141,7 +141,7 @@ class TestXAIProvider:
"""Test error handling for unsupported models."""
provider = XAIModelProvider("test-key")
with pytest.raises(ValueError, match="Unsupported X.AI model"):
with pytest.raises(ValueError, match="Unsupported model 'invalid-model' for provider xai"):
provider.get_capabilities("invalid-model")
def test_extended_thinking_flags(self):

View File

@@ -105,13 +105,11 @@ class ListModelsTool(BaseTool):
output_lines.append("**Status**: Configured and available")
output_lines.append("\n**Models**:")
# Get models from the provider's model configurations
for model_name, capabilities in provider.get_model_configurations().items():
# Get description and context from the ModelCapabilities object
aliases = []
for model_name, capabilities in provider.get_all_model_capabilities().items():
description = capabilities.description or "No description available"
context_window = capabilities.context_window
# Format context window
if context_window >= 1_000_000:
context_str = f"{context_window // 1_000_000}M context"
elif context_window >= 1_000:
@@ -120,31 +118,15 @@ class ListModelsTool(BaseTool):
context_str = f"{context_window} context" if context_window > 0 else "unknown context"
output_lines.append(f"- `{model_name}` - {context_str}")
output_lines.append(f" - {description}")
# Extract key capability from description
if "Ultra-fast" in description:
output_lines.append(" - Fast processing, quick iterations")
elif "Deep reasoning" in description:
output_lines.append(" - Extended reasoning with thinking mode")
elif "Strong reasoning" in description:
output_lines.append(" - Logical problems, systematic analysis")
elif "EXTREMELY EXPENSIVE" in description:
output_lines.append(" - ⚠️ Professional grade (very expensive)")
elif "Advanced reasoning" in description:
output_lines.append(" - Advanced reasoning and complex analysis")
# Show aliases for this provider
aliases = []
for model_name, capabilities in provider.get_model_configurations().items():
if capabilities.aliases:
for alias in capabilities.aliases:
# Skip aliases that are the same as the model name to avoid duplicates
if alias != model_name:
aliases.append(f"- `{alias}` → `{model_name}`")
for alias in capabilities.aliases or []:
if alias != model_name:
aliases.append(f"- `{alias}` → `{model_name}`")
if aliases:
output_lines.append("\n**Aliases**:")
output_lines.extend(sorted(aliases)) # Sort for consistent output
output_lines.extend(sorted(aliases))
else:
output_lines.append(f"**Status**: Not configured (set {info['env_key']})")