refactor: clean up provider base class, its shared responsibilities, and its public contract

docs: document provider base class
refactor: clean up custom provider; it should only handle `is_custom` model configurations
fix: make sure openrouter provider does not load `is_custom` models
fix: clean up the `listmodels` tool
Fahad
2025-10-02 12:59:45 +04:00
parent 6ec2033f34
commit 693b84db2b
15 changed files with 509 additions and 751 deletions


@@ -7,7 +7,7 @@ This guide explains how to add support for a new AI model provider to the Zen MC
 Each provider:
 - Inherits from `ModelProvider` (base class) or `OpenAICompatibleProvider` (for OpenAI-compatible APIs)
 - Defines supported models using `ModelCapabilities` objects
-- Implements a few core abstract methods
+- Implements the minimal abstract hooks (`get_provider_type()` and `generate_content()`)
 - Gets registered automatically via environment variables

 ## Choose Your Implementation Path
@@ -15,11 +15,11 @@ Each provider:
 **Option A: Full Provider (`ModelProvider`)**
 - For APIs with unique features or custom authentication
 - Complete control over API calls and response handling
-- Required methods: `generate_content()`, `get_capabilities()`, `validate_model_name()`, `get_provider_type()` (override `count_tokens()` only when you have a provider-accurate tokenizer)
+- Implement `generate_content()` and `get_provider_type()`; override `get_all_model_capabilities()` to expose your catalogue and extend `_lookup_capabilities()` / `_ensure_model_allowed()` only when you need registry lookups or custom restriction rules (override `count_tokens()` only when you have a provider-accurate tokenizer)

 **Option B: OpenAI-Compatible (`OpenAICompatibleProvider`)**
 - For APIs that follow OpenAI's chat completion format
-- Only need to define: model configurations, capabilities, and validation
+- Supply `MODEL_CAPABILITIES`, override `get_provider_type()`, and optionally adjust configuration (the base class handles alias resolution, validation, and request wiring)
 - Inherits all API handling automatically

 ⚠️ **Important**: If using aliases (like `"gpt"` → `"gpt-4"`), override `generate_content()` to resolve them before API calls.
@@ -62,8 +62,7 @@ logger = logging.getLogger(__name__)
 class ExampleModelProvider(ModelProvider):
     """Example model provider implementation."""

-    # Define models using ModelCapabilities objects (like Gemini provider)
     MODEL_CAPABILITIES = {
         "example-large": ModelCapabilities(
             provider=ProviderType.EXAMPLE,
@@ -87,51 +86,47 @@ class ExampleModelProvider(ModelProvider):
aliases=["small", "fast"], aliases=["small", "fast"],
), ),
} }
def __init__(self, api_key: str, **kwargs): def __init__(self, api_key: str, **kwargs):
super().__init__(api_key, **kwargs) super().__init__(api_key, **kwargs)
# Initialize your API client here # Initialize your API client here
def get_capabilities(self, model_name: str) -> ModelCapabilities: def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
return dict(self.MODEL_CAPABILITIES)
def get_provider_type(self) -> ProviderType:
return ProviderType.EXAMPLE
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
resolved_name = self._resolve_model_name(model_name) resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.MODEL_CAPABILITIES:
raise ValueError(f"Unsupported model: {model_name}")
# Apply restrictions if needed
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
raise ValueError(f"Model '{model_name}' is not allowed.")
return self.MODEL_CAPABILITIES[resolved_name]
def generate_content(self, prompt: str, model_name: str, system_prompt: Optional[str] = None,
temperature: float = 0.7, max_output_tokens: Optional[int] = None, **kwargs) -> ModelResponse:
resolved_name = self._resolve_model_name(model_name)
# Your API call logic here # Your API call logic here
# response = your_api_client.generate(...) # response = your_api_client.generate(...)
return ModelResponse( return ModelResponse(
content="Generated response", # From your API content="Generated response",
usage={"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}, usage={"input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
model_name=resolved_name, model_name=resolved_name,
friendly_name="Example", friendly_name="Example",
provider=ProviderType.EXAMPLE, provider=ProviderType.EXAMPLE,
) )
def get_provider_type(self) -> ProviderType:
return ProviderType.EXAMPLE
def validate_model_name(self, model_name: str) -> bool:
resolved_name = self._resolve_model_name(model_name)
return resolved_name in self.MODEL_CAPABILITIES
``` ```
`ModelProvider.count_tokens()` uses a simple 4-characters-per-token estimate so `ModelProvider.get_capabilities()` automatically resolves aliases, enforces the
providers work out of the box. Override the method only when you can call into shared restriction service, and returns the correct `ModelCapabilities`
the provider's real tokenizer (for example, the OpenAI-compatible base class instance. Override `_lookup_capabilities()` only when you source capabilities
already integrates `tiktoken`). from a registry or remote API. `ModelProvider.count_tokens()` uses a simple
4-characters-per-token estimate so providers work out of the box—override it
only when you can call the provider's real tokenizer (for example, the
OpenAI-compatible base class integrates `tiktoken`).
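For providers that source capabilities dynamically, only the lookup hook needs to change. A minimal sketch, assuming the import layout used elsewhere in this guide; `_fetch_from_catalogue` is a hypothetical helper:

```python
from typing import Optional

from providers.base import ModelProvider  # assumed import paths
from providers.shared import ModelCapabilities, ProviderType


class CatalogueBackedProvider(ModelProvider):
    """Sketch: layer a dynamic source behind the shared capability pipeline."""

    def get_provider_type(self) -> ProviderType:
        return ProviderType.EXAMPLE

    def _lookup_capabilities(
        self, canonical_name: str, requested_name: Optional[str] = None
    ) -> Optional[ModelCapabilities]:
        # Statically declared MODEL_CAPABILITIES entries win first.
        builtin = super()._lookup_capabilities(canonical_name, requested_name)
        if builtin is not None:
            return builtin
        # Returning None keeps the shared unsupported-model error intact.
        return self._fetch_from_catalogue(canonical_name)  # hypothetical helper

    # generate_content() omitted for brevity; it stays abstract until implemented.
```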
 #### Option B: OpenAI-Compatible Provider (Simplified)
@@ -172,26 +167,16 @@ class ExampleProvider(OpenAICompatibleProvider):
     def __init__(self, api_key: str, **kwargs):
         kwargs.setdefault("base_url", "https://api.example.com/v1")
         super().__init__(api_key, **kwargs)

-    def get_capabilities(self, model_name: str) -> ModelCapabilities:
-        resolved_name = self._resolve_model_name(model_name)
-        if resolved_name not in self.MODEL_CAPABILITIES:
-            raise ValueError(f"Unsupported model: {model_name}")
-        return self.MODEL_CAPABILITIES[resolved_name]
-
     def get_provider_type(self) -> ProviderType:
         return ProviderType.EXAMPLE
-
-    def validate_model_name(self, model_name: str) -> bool:
-        resolved_name = self._resolve_model_name(model_name)
-        return resolved_name in self.MODEL_CAPABILITIES
-
-    def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
-        # IMPORTANT: Resolve aliases before API call
-        resolved_model_name = self._resolve_model_name(model_name)
-        return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs)
 ```

+`OpenAICompatibleProvider` already exposes the declared models via
+`MODEL_CAPABILITIES`, resolves aliases through the shared base pipeline, and
+enforces restrictions. Most subclasses only need to provide the class metadata
+shown above.
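With that wiring, callers can hand aliases straight to the provider; a usage sketch (the alias and API key are illustrative):

```python
# The base class resolves the alias and enforces restrictions before the
# request is built; no generate_content() override is needed for aliases.
provider = ExampleProvider(api_key="sk-example")

capabilities = provider.get_capabilities("large")
assert capabilities.provider == ProviderType.EXAMPLE

response = provider.generate_content(prompt="Say hello", model_name="large")
print(response.content)
```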
 ### 3. Register Your Provider
 Add environment variable mapping in `providers/registry.py`:
@@ -237,15 +222,11 @@ DISABLED_TOOLS=debug,tracer
 Create basic tests to verify your implementation:

 ```python
-# Test model validation
-provider = ExampleModelProvider("test-key")
-assert provider.validate_model_name("large") == True
-assert provider.validate_model_name("unknown") == False
-
 # Test capabilities
-caps = provider.get_capabilities("large")
-assert caps.context_window > 0
-assert caps.provider == ProviderType.EXAMPLE
+provider = ExampleModelProvider("test-key")
+capabilities = provider.get_capabilities("large")
+assert capabilities.context_window > 0
+assert capabilities.provider == ProviderType.EXAMPLE
 ```
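Since `validate_model_name()` now delegates to `get_capabilities()`, a companion assertion can continue the snippet above; the unknown name is illustrative:

```python
# Validation rides the same pipeline as get_capabilities(): a failed lookup
# raises ValueError internally and surfaces here as False.
assert provider.validate_model_name("large") is True
assert provider.validate_model_name("totally-unknown-model") is False
```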
@@ -259,31 +240,19 @@ When a user requests a model, providers are checked in priority order:
 3. **OpenRouter** - catch-all for everything else

 ### Model Validation

-Your `validate_model_name()` should **only** return `True` for models you explicitly support:
-
-```python
-def validate_model_name(self, model_name: str) -> bool:
-    resolved_name = self._resolve_model_name(model_name)
-    return resolved_name in self.MODEL_CAPABILITIES  # Be specific!
-```
+`ModelProvider.validate_model_name()` delegates to `get_capabilities()` so most
+providers can rely on the shared implementation. Override it only when you need
+to opt out of that pipeline—for example, `CustomProvider` declines OpenRouter
+models so they fall through to the dedicated OpenRouter provider.

 ### Model Aliases

-The base class handles alias resolution automatically via the `aliases` field in `ModelCapabilities`.
+Aliases declared on `ModelCapabilities` are applied automatically via
+`_resolve_model_name()`, and both the validation and request flows call it
+before touching your SDK. Override `generate_content()` only when your provider
+needs additional alias handling beyond the shared behaviour.
 ## Important Notes
-### Alias Resolution in OpenAI-Compatible Providers
-
-If using `OpenAICompatibleProvider` with aliases, **you must override `generate_content()`** to resolve aliases before API calls:
-
-```python
-def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
-    # Resolve alias before API call
-    resolved_model_name = self._resolve_model_name(model_name)
-    return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs)
-```
-
-Without this, API calls with aliases like `"large"` will fail because your API doesn't recognize the alias.
 ## Best Practices
 - **Be specific in model validation** - only accept models you actually support


@@ -43,128 +43,37 @@ class ModelProvider(ABC):
         self.api_key = api_key
         self.config = kwargs

-    @abstractmethod
-    def get_capabilities(self, model_name: str) -> ModelCapabilities:
-        """Get capabilities for a specific model."""
-        pass
-
-    @abstractmethod
-    def generate_content(
-        self,
-        prompt: str,
-        model_name: str,
-        system_prompt: Optional[str] = None,
-        temperature: float = 0.3,
-        max_output_tokens: Optional[int] = None,
-        **kwargs,
-    ) -> ModelResponse:
-        """Generate content using the model.
-
-        Args:
-            prompt: User prompt to send to the model
-            model_name: Name of the model to use
-            system_prompt: Optional system prompt for model behavior
-            temperature: Sampling temperature (0-2)
-            max_output_tokens: Maximum tokens to generate
-            **kwargs: Provider-specific parameters
-
-        Returns:
-            ModelResponse with generated content and metadata
-        """
-        pass
-
-    def count_tokens(self, text: str, model_name: str) -> int:
-        """Estimate token usage for a piece of text.
-
-        Providers can rely on this shared implementation or override it when
-        they expose a more accurate tokenizer. This default uses a simple
-        character-based heuristic so it works even without provider-specific
-        tooling.
-        """
-        resolved_model = self._resolve_model_name(model_name)
-
-        if not text:
-            return 0
-
-        # Rough estimation: ~4 characters per token for English text
-        estimated = max(1, len(text) // 4)
-        logger.debug("Estimating %s tokens for model %s via character heuristic", estimated, resolved_model)
-        return estimated
-
+    # ------------------------------------------------------------------
+    # Provider identity & capability surface
+    # ------------------------------------------------------------------
     @abstractmethod
     def get_provider_type(self) -> ProviderType:
-        """Get the provider type."""
-        pass
+        """Return the concrete provider identity."""

-    @abstractmethod
-    def validate_model_name(self, model_name: str) -> bool:
-        """Validate if the model name is supported by this provider."""
-        pass
-
-    def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
-        """Validate model parameters against capabilities.
-
-        Raises:
-            ValueError: If parameters are invalid
-        """
-        capabilities = self.get_capabilities(model_name)
-
-        # Validate temperature using constraint
-        if not capabilities.temperature_constraint.validate(temperature):
-            constraint_desc = capabilities.temperature_constraint.get_description()
-            raise ValueError(f"Temperature {temperature} is invalid for model {model_name}. {constraint_desc}")
-
-    def get_model_configurations(self) -> dict[str, ModelCapabilities]:
-        """Get model configurations for this provider.
-
-        This is a hook method that subclasses can override to provide
-        their model configurations from different sources.
-
-        Returns:
-            Dictionary mapping model names to their ModelCapabilities objects
-        """
-        model_map = getattr(self, "MODEL_CAPABILITIES", None)
-        if isinstance(model_map, dict) and model_map:
-            return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)}
-        return {}
-
-    def _resolve_model_name(self, model_name: str) -> str:
-        """Resolve model shorthand to full name.
-
-        This implementation uses the hook methods to support different
-        model configuration sources.
-
-        Args:
-            model_name: Model name that may be an alias
-
-        Returns:
-            Resolved model name
-        """
-        # Get model configurations from the hook method
-        model_configs = self.get_model_configurations()
-
-        # First check if it's already a base model name (case-sensitive exact match)
-        if model_name in model_configs:
-            return model_name
-
-        # Check case-insensitively for both base models and aliases
-        model_name_lower = model_name.lower()
-
-        # Check base model names case-insensitively
-        for base_model in model_configs:
-            if base_model.lower() == model_name_lower:
-                return base_model
-
-        # Check aliases from the model configurations
-        alias_map = ModelCapabilities.collect_aliases(model_configs)
-        for base_model, aliases in alias_map.items():
-            if any(alias.lower() == model_name_lower for alias in aliases):
-                return base_model
-
-        # If not found, return as-is
-        return model_name
+    def get_capabilities(self, model_name: str) -> ModelCapabilities:
+        """Resolve capability metadata for a model name.
+
+        This centralises the alias resolution → lookup → restriction check
+        pipeline so providers only override the pieces they genuinely need to
+        customise. Subclasses usually only override ``_lookup_capabilities`` to
+        integrate a registry or dynamic source, or ``_finalise_capabilities`` to
+        tweak the returned object.
+        """
+        resolved_name = self._resolve_model_name(model_name)
+        capabilities = self._lookup_capabilities(resolved_name, model_name)
+        if capabilities is None:
+            self._raise_unsupported_model(model_name)
+        self._ensure_model_allowed(capabilities, resolved_name, model_name)
+        return self._finalise_capabilities(capabilities, resolved_name, model_name)
+
+    def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
+        """Return the provider's statically declared model capabilities."""
+        model_map = getattr(self, "MODEL_CAPABILITIES", None)
+        if isinstance(model_map, dict) and model_map:
+            return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)}
+        return {}

     def list_models(
         self,
         *,
@@ -175,7 +84,7 @@ class ModelProvider(ABC):
     ) -> list[str]:
         """Return formatted model names supported by this provider."""
-        model_configs = self.get_model_configurations()
+        model_configs = self.get_all_model_capabilities()

         if not model_configs:
             return []
@@ -202,36 +111,155 @@ class ModelProvider(ABC):
             unique=unique,
         )

-    def close(self):
-        """Clean up any resources held by the provider.
-
-        Default implementation does nothing.
-        Subclasses should override if they hold resources that need cleanup.
-        """
-        # Base implementation: no resources to clean up
-        return
-
-    def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
-        """Get the preferred model from this provider for a given category.
-
-        Args:
-            category: The tool category requiring a model
-            allowed_models: Pre-filtered list of model names that are allowed by restrictions
-
-        Returns:
-            Model name if this provider has a preference, None otherwise
-        """
-        # Default implementation - providers can override with specific logic
-        return None
-
-    def get_model_registry(self) -> Optional[dict[str, Any]]:
-        """Get the model registry for providers that maintain one.
-
-        This is a hook method for providers like CustomProvider that maintain
-        a dynamic model registry.
-
-        Returns:
-            Model registry dict or None if not applicable
-        """
-        # Default implementation - most providers don't have a registry
-        return None
+    # ------------------------------------------------------------------
+    # Request execution
+    # ------------------------------------------------------------------
+    @abstractmethod
+    def generate_content(
+        self,
+        prompt: str,
+        model_name: str,
+        system_prompt: Optional[str] = None,
+        temperature: float = 0.3,
+        max_output_tokens: Optional[int] = None,
+        **kwargs,
+    ) -> ModelResponse:
+        """Generate content using the model."""
+
+    def count_tokens(self, text: str, model_name: str) -> int:
+        """Estimate token usage for a piece of text."""
+        resolved_model = self._resolve_model_name(model_name)
+        if not text:
+            return 0
+        estimated = max(1, len(text) // 4)
+        logger.debug("Estimating %s tokens for model %s via character heuristic", estimated, resolved_model)
+        return estimated
+
+    def close(self) -> None:
+        """Clean up any resources held by the provider."""
+        return
+
+    # ------------------------------------------------------------------
+    # Validation hooks
+    # ------------------------------------------------------------------
+    def validate_model_name(self, model_name: str) -> bool:
+        """Return ``True`` when the model resolves to an allowed capability."""
+        try:
+            self.get_capabilities(model_name)
+        except ValueError:
+            return False
+        return True
+
+    def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
+        """Validate model parameters against capabilities."""
+        capabilities = self.get_capabilities(model_name)
+        if not capabilities.temperature_constraint.validate(temperature):
+            constraint_desc = capabilities.temperature_constraint.get_description()
+            raise ValueError(f"Temperature {temperature} is invalid for model {model_name}. {constraint_desc}")
+
+    # ------------------------------------------------------------------
+    # Preference / registry hooks
+    # ------------------------------------------------------------------
+    def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
+        """Get the preferred model from this provider for a given category."""
+        return None
+
+    def get_model_registry(self) -> Optional[dict[str, Any]]:
+        """Return the model registry backing this provider, if any."""
+        return None
+
+    # ------------------------------------------------------------------
+    # Capability lookup pipeline
+    # ------------------------------------------------------------------
+    def _lookup_capabilities(
+        self,
+        canonical_name: str,
+        requested_name: Optional[str] = None,
+    ) -> Optional[ModelCapabilities]:
+        """Return ``ModelCapabilities`` for the canonical model name."""
+        return self.get_all_model_capabilities().get(canonical_name)
+
+    def _ensure_model_allowed(
+        self,
+        capabilities: ModelCapabilities,
+        canonical_name: str,
+        requested_name: str,
+    ) -> None:
+        """Raise ``ValueError`` if the model violates restriction policy."""
+        try:
+            from utils.model_restrictions import get_restriction_service
+        except Exception:  # pragma: no cover - only triggered if service import breaks
+            return
+
+        restriction_service = get_restriction_service()
+        if not restriction_service:
+            return
+        if restriction_service.is_allowed(self.get_provider_type(), canonical_name, requested_name):
+            return
+        raise ValueError(
+            f"{self.get_provider_type().value} model '{canonical_name}' is not allowed by restriction policy."
+        )
+
+    def _finalise_capabilities(
+        self,
+        capabilities: ModelCapabilities,
+        canonical_name: str,
+        requested_name: str,
+    ) -> ModelCapabilities:
+        """Allow subclasses to adjust capability metadata before returning."""
+        return capabilities
+
+    def _raise_unsupported_model(self, model_name: str) -> None:
+        """Raise the canonical unsupported-model error."""
+        raise ValueError(f"Unsupported model '{model_name}' for provider {self.get_provider_type().value}.")
+
+    def _resolve_model_name(self, model_name: str) -> str:
+        """Resolve model shorthand to full name.
+
+        This implementation uses the hook methods to support different
+        model configuration sources.
+
+        Args:
+            model_name: Model name that may be an alias
+
+        Returns:
+            Resolved model name
+        """
+        # Get model configurations from the hook method
+        model_configs = self.get_all_model_capabilities()
+
+        # First check if it's already a base model name (case-sensitive exact match)
+        if model_name in model_configs:
+            return model_name
+
+        # Check case-insensitively for both base models and aliases
+        model_name_lower = model_name.lower()
+
+        # Check base model names case-insensitively
+        for base_model in model_configs:
+            if base_model.lower() == model_name_lower:
+                return base_model
+
+        # Check aliases from the model configurations
+        alias_map = ModelCapabilities.collect_aliases(model_configs)
+        for base_model, aliases in alias_map.items():
+            if any(alias.lower() == model_name_lower for alias in aliases):
+                return base_model
+
+        # If not found, return as-is
+        return model_name
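The resolution order `_resolve_model_name()` implements (exact match, then case-insensitive base name, then alias, then passthrough) can be exercised with a self-contained toy; the names are illustrative:

```python
from dataclasses import dataclass, field


@dataclass
class Caps:  # stand-in for ModelCapabilities
    model_name: str
    aliases: list[str] = field(default_factory=list)


def resolve(name: str, configs: dict[str, Caps]) -> str:
    if name in configs:                      # 1. exact, case-sensitive match
        return name
    lowered = name.lower()
    for base in configs:                     # 2. base names, case-insensitive
        if base.lower() == lowered:
            return base
    for base, caps in configs.items():       # 3. aliases, case-insensitive
        if any(alias.lower() == lowered for alias in caps.aliases):
            return base
    return name                              # 4. passthrough for unknown names


configs = {"example-large": Caps("example-large", aliases=["large"])}
assert resolve("large", configs) == "example-large"
assert resolve("EXAMPLE-LARGE", configs) == "example-large"
assert resolve("mystery", configs) == "mystery"
```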


@@ -83,117 +83,69 @@ class CustomProvider(OpenAICompatibleProvider):
         aliases = self._registry.list_aliases()
         logging.info(f"Custom provider loaded {len(models)} models with {len(aliases)} aliases")

-    def _resolve_model_name(self, model_name: str) -> str:
-        """Resolve model aliases to actual model names.
-
-        For Ollama-style models, strips version tags (e.g., 'llama3.2:latest' -> 'llama3.2')
-        since the base model name is what's typically used in API calls.
-
-        Args:
-            model_name: Input model name or alias
-
-        Returns:
-            Resolved model name with version tags stripped if applicable
-        """
-        # First, try to resolve through registry as-is
-        config = self._registry.resolve(model_name)
-        if config:
-            if config.model_name != model_name:
-                logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
-            return config.model_name
-        else:
-            # If not found in registry, handle version tags for local models
-            # Strip version tags (anything after ':') for Ollama-style models
-            if ":" in model_name:
-                base_model = model_name.split(":")[0]
-                logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'")
-                # Try to resolve the base model through registry
-                base_config = self._registry.resolve(base_model)
-                if base_config:
-                    logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
-                    return base_config.model_name
-                else:
-                    return base_model
-            else:
-                # If not found in registry and no version tag, return as-is
-                logging.debug(f"Model '{model_name}' not found in registry, using as-is")
-                return model_name
-
-    def get_capabilities(self, model_name: str) -> ModelCapabilities:
-        """Get capabilities for a custom model.
-
-        Args:
-            model_name: Name of the model (or alias)
-
-        Returns:
-            ModelCapabilities from registry or generic defaults
-        """
-        # Try to get from registry first
-        capabilities = self._registry.get_capabilities(model_name)
-        if capabilities:
-            # Check if this is an OpenRouter model and apply restrictions
-            config = self._registry.resolve(model_name)
-            if config and not config.is_custom:
-                # This is an OpenRouter model, check restrictions
-                from utils.model_restrictions import get_restriction_service
-
-                restriction_service = get_restriction_service()
-                if not restriction_service.is_allowed(ProviderType.OPENROUTER, config.model_name, model_name):
-                    raise ValueError(f"OpenRouter model '{model_name}' is not allowed by restriction policy.")
-                # Update provider type to OPENROUTER for OpenRouter models
-                capabilities.provider = ProviderType.OPENROUTER
-            else:
-                # Update provider type to CUSTOM for local custom models
-                capabilities.provider = ProviderType.CUSTOM
-            return capabilities
-        else:
-            # Resolve any potential aliases and create generic capabilities
-            resolved_name = self._resolve_model_name(model_name)
-            logging.debug(
-                f"Using generic capabilities for '{resolved_name}' via Custom API. "
-                "Consider adding to custom_models.json for specific capabilities."
-            )
-
-            # Infer temperature behaviour for generic capability fallback
-            supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings(
-                resolved_name
-            )
-
-            logging.warning(
-                f"Model '{resolved_name}' not found in custom_models.json. Using generic capabilities with inferred settings. "
-                f"Temperature support: {supports_temperature} ({temperature_reason}). "
-                "For better accuracy, add this model to your custom_models.json configuration."
-            )
-
-            # Create generic capabilities with inferred defaults
-            capabilities = ModelCapabilities(
-                provider=ProviderType.CUSTOM,
-                model_name=resolved_name,
-                friendly_name=f"{self.FRIENDLY_NAME} ({resolved_name})",
-                context_window=32_768,  # Conservative default
-                max_output_tokens=32_768,  # Conservative default max output
-                supports_extended_thinking=False,  # Most custom models don't support this
-                supports_system_prompts=True,
-                supports_streaming=True,
-                supports_function_calling=False,  # Conservative default
-                supports_temperature=supports_temperature,
-                temperature_constraint=temperature_constraint,
-            )
-
-            # Mark as generic for validation purposes
-            capabilities._is_generic = True
-            return capabilities
+    # ------------------------------------------------------------------
+    # Capability surface
+    # ------------------------------------------------------------------
+    def _lookup_capabilities(
+        self,
+        canonical_name: str,
+        requested_name: Optional[str] = None,
+    ) -> Optional[ModelCapabilities]:
+        """Return custom capabilities from the registry or generic defaults."""
+        builtin = super()._lookup_capabilities(canonical_name, requested_name)
+        if builtin is not None:
+            return builtin
+
+        capabilities = self._registry.get_capabilities(canonical_name)
+        if capabilities:
+            config = self._registry.resolve(canonical_name)
+            if config and getattr(config, "is_custom", False):
+                capabilities.provider = ProviderType.CUSTOM
+                return capabilities
+            # Non-custom models should fall through so OpenRouter handles them
+            return None
+
+        logging.debug(
+            f"Using generic capabilities for '{canonical_name}' via Custom API. "
+            "Consider adding to custom_models.json for specific capabilities."
+        )
+
+        supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings(
+            canonical_name
+        )
+
+        logging.warning(
+            f"Model '{canonical_name}' not found in custom_models.json. Using generic capabilities with inferred settings. "
+            f"Temperature support: {supports_temperature} ({temperature_reason}). "
+            "For better accuracy, add this model to your custom_models.json configuration."
+        )
+
+        generic = ModelCapabilities(
+            provider=ProviderType.CUSTOM,
+            model_name=canonical_name,
+            friendly_name=f"{self.FRIENDLY_NAME} ({canonical_name})",
+            context_window=32_768,
+            max_output_tokens=32_768,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,
+            supports_temperature=supports_temperature,
+            temperature_constraint=temperature_constraint,
+        )
+        generic._is_generic = True
+        return generic

     def get_provider_type(self) -> ProviderType:
-        """Get the provider type."""
+        """Identify this provider for restriction and logging logic."""
         return ProviderType.CUSTOM

+    # ------------------------------------------------------------------
+    # Validation
+    # ------------------------------------------------------------------
     def validate_model_name(self, model_name: str) -> bool:
         """Validate if the model name is allowed.
@@ -206,49 +158,41 @@ class CustomProvider(OpenAICompatibleProvider):
         Returns:
             True if model is intended for custom/local endpoint
         """
-        # logging.debug(f"Custom provider validating model: '{model_name}'")
+        if super().validate_model_name(model_name):
+            return True

-        # Try to resolve through registry first
         config = self._registry.resolve(model_name)
-        if config:
-            model_id = config.model_name
-
-            # Use explicit is_custom flag for clean validation
-            if config.is_custom:
-                logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' validated via registry")
-                return True
-            else:
-                # This is a cloud/OpenRouter model - CustomProvider should NOT handle these
-                # Let OpenRouter provider handle them instead
-                # logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' not custom (defer to OpenRouter)")
-                return False
+        if config and not getattr(config, "is_custom", False):
+            return False

-        # Handle version tags for unknown models (e.g., "my-model:latest")
         clean_model_name = model_name
         if ":" in model_name:
-            clean_model_name = model_name.split(":")[0]
+            clean_model_name = model_name.split(":", 1)[0]
             logging.debug(f"Stripped version tag from '{model_name}' -> '{clean_model_name}'")

-        # Try to resolve the clean name
+        if super().validate_model_name(clean_model_name):
+            return True
+
         config = self._registry.resolve(clean_model_name)
-        if config:
-            return self.validate_model_name(clean_model_name)  # Recursively validate clean name
-
-        # For unknown models (not in registry), only accept if they look like local models
-        # This maintains backward compatibility for custom models not yet in the registry
-
-        # Accept models with explicit local indicators in the name
-        if any(indicator in clean_model_name.lower() for indicator in ["local", "ollama", "vllm", "lmstudio"]):
+        if config and not getattr(config, "is_custom", False):
+            return False
+
+        lowered = clean_model_name.lower()
+        if any(indicator in lowered for indicator in ["local", "ollama", "vllm", "lmstudio"]):
             logging.debug(f"Model '{clean_model_name}' validated via local indicators")
             return True

-        # Accept simple model names without vendor prefix (likely local/custom models)
         if "/" not in clean_model_name:
             logging.debug(f"Model '{clean_model_name}' validated as potential local model (no vendor prefix)")
             return True

-        # Reject everything else (likely cloud models not in registry)
         logging.debug(f"Model '{model_name}' rejected by custom provider (appears to be cloud model)")
         return False

+    # ------------------------------------------------------------------
+    # Request execution
+    # ------------------------------------------------------------------
     def generate_content(
         self,
         prompt: str,
@@ -284,25 +228,41 @@ class CustomProvider(OpenAICompatibleProvider):
             **kwargs,
         )

-    def get_model_configurations(self) -> dict[str, ModelCapabilities]:
-        """Get model configurations from the registry.
-
-        For CustomProvider, we convert registry configurations to ModelCapabilities objects.
-
-        Returns:
-            Dictionary mapping model names to their ModelCapabilities objects
-        """
-        configs = {}
-
-        if self._registry:
-            # Get all models from registry
-            for model_name in self._registry.list_models():
-                # Only include custom models that this provider validates
-                if self.validate_model_name(model_name):
-                    config = self._registry.resolve(model_name)
-                    if config and config.is_custom:
-                        # Use ModelCapabilities directly from registry
-                        configs[model_name] = config
-
-        return configs
+    # ------------------------------------------------------------------
+    # Registry helpers
+    # ------------------------------------------------------------------
+    def _resolve_model_name(self, model_name: str) -> str:
+        """Resolve registry aliases and strip version tags for local models."""
+        config = self._registry.resolve(model_name)
+        if config:
+            if config.model_name != model_name:
+                logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
+            return config.model_name
+
+        if ":" in model_name:
+            base_model = model_name.split(":")[0]
+            logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'")
+
+            base_config = self._registry.resolve(base_model)
+            if base_config:
+                logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
+                return base_config.model_name
+            return base_model
+
+        logging.debug(f"Model '{model_name}' not found in registry, using as-is")
+        return model_name
+
+    def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
+        """Expose registry capabilities for models marked as custom."""
+        if not self._registry:
+            return {}
+
+        capabilities: dict[str, ModelCapabilities] = {}
+        for model_name in self._registry.list_models():
+            config = self._registry.resolve(model_name)
+            if config and getattr(config, "is_custom", False):
+                capabilities[model_name] = config
+        return capabilities
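Taken together, the custom provider now draws a clean line: registry entries flagged `is_custom` stay here, everything else falls through to OpenRouter. A sketch mirroring the test changes at the end of this commit (the registry contents are assumed):

```python
# Sketch, assuming custom_models.json contains local-llama (is_custom=true)
# and o3 (an OpenRouter-backed entry, is_custom=false).
provider = CustomProvider(api_key="", base_url="http://localhost:11434/v1")

assert provider.validate_model_name("local-llama") is True
assert provider.validate_model_name("o3") is False  # deferred to OpenRouterProvider

capabilities = provider.get_capabilities("local-llama")
assert capabilities.provider == ProviderType.CUSTOM
```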


@@ -261,68 +261,10 @@ class DIALModelProvider(OpenAICompatibleProvider):
logger.info(f"Initialized DIAL provider with host: {dial_host} and api-version: {self.api_version}") logger.info(f"Initialized DIAL provider with host: {dial_host} and api-version: {self.api_version}")
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific model.
Args:
model_name: Name of the model (can be shorthand)
Returns:
ModelCapabilities object
Raises:
ValueError: If model is not supported or not allowed
"""
resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.MODEL_CAPABILITIES:
raise ValueError(f"Unsupported DIAL model: {model_name}")
# Check restrictions
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")
# Return the ModelCapabilities object directly from MODEL_CAPABILITIES
return self.MODEL_CAPABILITIES[resolved_name]
def get_provider_type(self) -> ProviderType: def get_provider_type(self) -> ProviderType:
"""Get the provider type.""" """Get the provider type."""
return ProviderType.DIAL return ProviderType.DIAL
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported.
Args:
model_name: Model name to validate
Returns:
True if model is supported and allowed, False otherwise
"""
resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.MODEL_CAPABILITIES:
return False
# Check against base class allowed_models if configured
if self.allowed_models is not None:
# Check both original and resolved names (case-insensitive)
if model_name.lower() not in self.allowed_models and resolved_name.lower() not in self.allowed_models:
logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' not in allowed_models list")
return False
# Also check restrictions via ModelRestrictionService
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' blocked by restrictions")
return False
return True
def _get_deployment_client(self, deployment: str): def _get_deployment_client(self, deployment: str):
"""Get or create a cached client for a specific deployment. """Get or create a cached client for a specific deployment.
@@ -504,7 +446,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}" f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
) )
def close(self): def close(self) -> None:
"""Clean up HTTP clients when provider is closed.""" """Clean up HTTP clients when provider is closed."""
logger.info("Closing DIAL provider HTTP clients...") logger.info("Closing DIAL provider HTTP clients...")


@@ -131,6 +131,19 @@ class GeminiModelProvider(ModelProvider):
         self._token_counters = {}  # Cache for token counting
         self._base_url = kwargs.get("base_url", None)  # Optional custom endpoint

+    # ------------------------------------------------------------------
+    # Capability surface
+    # ------------------------------------------------------------------
+    def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
+        """Return statically defined Gemini capabilities."""
+        return dict(self.MODEL_CAPABILITIES)
+
+    # ------------------------------------------------------------------
+    # Client access
+    # ------------------------------------------------------------------
     @property
     def client(self):
         """Lazy initialization of Gemini client."""
@@ -146,25 +159,9 @@ class GeminiModelProvider(ModelProvider):
             self._client = genai.Client(api_key=self.api_key)
         return self._client

-    def get_capabilities(self, model_name: str) -> ModelCapabilities:
-        """Get capabilities for a specific Gemini model."""
-        # Resolve shorthand
-        resolved_name = self._resolve_model_name(model_name)
-
-        if resolved_name not in self.MODEL_CAPABILITIES:
-            raise ValueError(f"Unsupported Gemini model: {model_name}")
-
-        # Check if model is allowed by restrictions
-        from utils.model_restrictions import get_restriction_service
-
-        restriction_service = get_restriction_service()
-        # IMPORTANT: Parameter order is (provider_type, model_name, original_name)
-        # resolved_name is the canonical model name, model_name is the user input
-        if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
-            raise ValueError(f"Gemini model '{resolved_name}' is not allowed by restriction policy.")
-
-        # Return the ModelCapabilities object directly from MODEL_CAPABILITIES
-        return self.MODEL_CAPABILITIES[resolved_name]
+    # ------------------------------------------------------------------
+    # Request execution
+    # ------------------------------------------------------------------

     def generate_content(
         self,
@@ -365,26 +362,6 @@ class GeminiModelProvider(ModelProvider):
"""Get the provider type.""" """Get the provider type."""
return ProviderType.GOOGLE return ProviderType.GOOGLE
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported and allowed."""
resolved_name = self._resolve_model_name(model_name)
# First check if model is supported
if resolved_name not in self.MODEL_CAPABILITIES:
return False
# Then check if model is allowed by restrictions
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
# IMPORTANT: Parameter order is (provider_type, model_name, original_name)
# resolved_name is the canonical model name, model_name is the user input
if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
logger.debug(f"Gemini model '{model_name}' -> '{resolved_name}' blocked by restrictions")
return False
return True
def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int: def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int:
"""Get actual thinking token budget for a model and thinking mode.""" """Get actual thinking token budget for a model and thinking mode."""
resolved_name = self._resolve_model_name(model_name) resolved_name = self._resolve_model_name(model_name)


@@ -5,7 +5,6 @@ import ipaddress
 import logging
 import os
 import time
-from abc import abstractmethod
 from typing import Optional
 from urllib.parse import urlparse
@@ -61,6 +60,33 @@ class OpenAICompatibleProvider(ModelProvider):
"This may be insecure. Consider setting an API key for authentication." "This may be insecure. Consider setting an API key for authentication."
) )
def _ensure_model_allowed(
self,
capabilities: ModelCapabilities,
canonical_name: str,
requested_name: str,
) -> None:
"""Respect provider-specific allowlists before default restriction checks."""
super()._ensure_model_allowed(capabilities, canonical_name, requested_name)
if self.allowed_models is not None:
requested = requested_name.lower()
canonical = canonical_name.lower()
if requested not in self.allowed_models and canonical not in self.allowed_models:
raise ValueError(
f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
)
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
"""Return statically declared capabilities for OpenAI-compatible providers."""
model_map = getattr(self, "MODEL_CAPABILITIES", None)
if isinstance(model_map, dict):
return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)}
return {}
def _parse_allowed_models(self) -> Optional[set[str]]: def _parse_allowed_models(self) -> Optional[set[str]]:
"""Parse allowed models from environment variable. """Parse allowed models from environment variable.
@@ -686,30 +712,6 @@ class OpenAICompatibleProvider(ModelProvider):
         return super().count_tokens(text, model_name)

-    @abstractmethod
-    def get_capabilities(self, model_name: str) -> ModelCapabilities:
-        """Get capabilities for a specific model.
-
-        Must be implemented by subclasses.
-        """
-        pass
-
-    @abstractmethod
-    def get_provider_type(self) -> ProviderType:
-        """Get the provider type.
-
-        Must be implemented by subclasses.
-        """
-        pass
-
-    @abstractmethod
-    def validate_model_name(self, model_name: str) -> bool:
-        """Validate if the model name is supported.
-
-        Must be implemented by subclasses.
-        """
-        pass
-
     def _is_error_retryable(self, error: Exception) -> bool:
         """Determine if an error should be retried based on structured error codes.


@@ -174,106 +174,61 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
kwargs.setdefault("base_url", "https://api.openai.com/v1") kwargs.setdefault("base_url", "https://api.openai.com/v1")
super().__init__(api_key, **kwargs) super().__init__(api_key, **kwargs)
def get_capabilities(self, model_name: str) -> ModelCapabilities: # ------------------------------------------------------------------
"""Get capabilities for a specific OpenAI model.""" # Capability surface
# First check if it's a key in MODEL_CAPABILITIES # ------------------------------------------------------------------
if model_name in self.MODEL_CAPABILITIES:
self._check_model_restrictions(model_name, model_name)
return self.MODEL_CAPABILITIES[model_name]
# Try resolving as alias def _lookup_capabilities(
resolved_name = self._resolve_model_name(model_name) self,
canonical_name: str,
requested_name: Optional[str] = None,
) -> Optional[ModelCapabilities]:
"""Look up OpenAI capabilities from built-ins or the custom registry."""
# Check if resolved name is a key builtin = super()._lookup_capabilities(canonical_name, requested_name)
if resolved_name in self.MODEL_CAPABILITIES: if builtin is not None:
self._check_model_restrictions(resolved_name, model_name) return builtin
return self.MODEL_CAPABILITIES[resolved_name]
# Finally check if resolved name matches any API model name
for key, capabilities in self.MODEL_CAPABILITIES.items():
if resolved_name == capabilities.model_name:
self._check_model_restrictions(key, model_name)
return capabilities
# Check custom models registry for user-configured OpenAI models
try: try:
from .openrouter_registry import OpenRouterModelRegistry from .openrouter_registry import OpenRouterModelRegistry
registry = OpenRouterModelRegistry() registry = OpenRouterModelRegistry()
config = registry.get_model_config(resolved_name) config = registry.get_model_config(canonical_name)
if config and config.provider == ProviderType.OPENAI: if config and config.provider == ProviderType.OPENAI:
self._check_model_restrictions(config.model_name, model_name)
# Update provider type to ensure consistency
config.provider = ProviderType.OPENAI
return config return config
except Exception as e: except Exception as exc: # pragma: no cover - registry failures are non-critical
# Log but don't fail - registry might not be available logger.debug(f"Could not resolve custom OpenAI model '{canonical_name}': {exc}")
logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}")
return None
def _finalise_capabilities(
self,
capabilities: ModelCapabilities,
canonical_name: str,
requested_name: str,
) -> ModelCapabilities:
"""Ensure registry-sourced models report the correct provider type."""
if capabilities.provider != ProviderType.OPENAI:
capabilities.provider = ProviderType.OPENAI
return capabilities
def _raise_unsupported_model(self, model_name: str) -> None:
raise ValueError(f"Unsupported OpenAI model: {model_name}") raise ValueError(f"Unsupported OpenAI model: {model_name}")
def _check_model_restrictions(self, provider_model_name: str, user_model_name: str) -> None: # ------------------------------------------------------------------
"""Check if a model is allowed by restriction policy. # Provider identity
# ------------------------------------------------------------------
Args:
provider_model_name: The model name used by the provider
user_model_name: The model name requested by the user
Raises:
ValueError: If the model is not allowed by restriction policy
"""
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENAI, provider_model_name, user_model_name):
raise ValueError(f"OpenAI model '{user_model_name}' is not allowed by restriction policy.")
def get_provider_type(self) -> ProviderType: def get_provider_type(self) -> ProviderType:
"""Get the provider type.""" """Get the provider type."""
return ProviderType.OPENAI return ProviderType.OPENAI
def validate_model_name(self, model_name: str) -> bool: # ------------------------------------------------------------------
"""Validate if the model name is supported and allowed.""" # Request execution
resolved_name = self._resolve_model_name(model_name) # ------------------------------------------------------------------
# First, determine which model name to check against restrictions.
model_to_check = None
is_custom_model = False
if resolved_name in self.MODEL_CAPABILITIES:
model_to_check = resolved_name
else:
# If not a built-in model, check the custom models registry.
try:
from .openrouter_registry import OpenRouterModelRegistry
registry = OpenRouterModelRegistry()
config = registry.get_model_config(resolved_name)
if config and config.provider == ProviderType.OPENAI:
model_to_check = config.model_name
is_custom_model = True
except Exception as e:
# Log but don't fail - registry might not be available.
logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}")
# If no model was found (neither built-in nor custom), it's invalid.
if not model_to_check:
return False
# Now, perform the restriction check once.
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENAI, model_to_check, model_name):
model_type = "custom " if is_custom_model else ""
logger.debug(f"OpenAI {model_type}model '{model_name}' -> '{resolved_name}' blocked by restrictions")
return False
return True
def generate_content( def generate_content(
self, self,
@@ -298,6 +253,10 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
             **kwargs,
         )

+    # ------------------------------------------------------------------
+    # Provider preferences
+    # ------------------------------------------------------------------
     def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
         """Get OpenAI's preferred model for a given category from allowed models.


@@ -61,108 +61,52 @@ class OpenRouterProvider(OpenAICompatibleProvider):
         aliases = self._registry.list_aliases()
         logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases")

-    def _resolve_model_name(self, model_name: str) -> str:
-        """Resolve model aliases to OpenRouter model names.
-
-        Args:
-            model_name: Input model name or alias
-
-        Returns:
-            Resolved OpenRouter model name
-        """
-        # Try to resolve through registry
-        config = self._registry.resolve(model_name)
-        if config:
-            if config.model_name != model_name:
-                logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
-            return config.model_name
-        else:
-            # If not found in registry, return as-is
-            # This allows using models not in our config file
-            logging.debug(f"Model '{model_name}' not found in registry, using as-is")
-            return model_name
-
-    def get_capabilities(self, model_name: str) -> ModelCapabilities:
-        """Get capabilities for a model.
-
-        Args:
-            model_name: Name of the model (or alias)
-
-        Returns:
-            ModelCapabilities from registry or generic defaults
-        """
-        # Try to get from registry first
-        capabilities = self._registry.get_capabilities(model_name)
+    # ------------------------------------------------------------------
+    # Capability surface
+    # ------------------------------------------------------------------
+    def _lookup_capabilities(
+        self,
+        canonical_name: str,
+        requested_name: Optional[str] = None,
+    ) -> Optional[ModelCapabilities]:
+        """Fetch OpenRouter capabilities from the registry or build a generic fallback."""
+        capabilities = self._registry.get_capabilities(canonical_name)
         if capabilities:
             return capabilities
-        else:
-            # Resolve any potential aliases and create generic capabilities
-            resolved_name = self._resolve_model_name(model_name)
-            logging.debug(
-                f"Using generic capabilities for '{resolved_name}' via OpenRouter. "
-                "Consider adding to custom_models.json for specific capabilities."
-            )
-
-            # Create generic capabilities with conservative defaults
-            capabilities = ModelCapabilities(
-                provider=ProviderType.OPENROUTER,
-                model_name=resolved_name,
-                friendly_name=self.FRIENDLY_NAME,
-                context_window=32_768,  # Conservative default context window
-                max_output_tokens=32_768,
-                supports_extended_thinking=False,
-                supports_system_prompts=True,
-                supports_streaming=True,
-                supports_function_calling=False,
-                temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
-            )
-
-            # Mark as generic for validation purposes
-            capabilities._is_generic = True
-
-            return capabilities
+
+        logging.debug(
+            f"Using generic capabilities for '{canonical_name}' via OpenRouter. "
+            "Consider adding to custom_models.json for specific capabilities."
+        )
+
+        generic = ModelCapabilities(
+            provider=ProviderType.OPENROUTER,
+            model_name=canonical_name,
+            friendly_name=self.FRIENDLY_NAME,
+            context_window=32_768,
+            max_output_tokens=32_768,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,
+            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
+        )
+        generic._is_generic = True
+        return generic

+    # ------------------------------------------------------------------
+    # Provider identity
+    # ------------------------------------------------------------------
     def get_provider_type(self) -> ProviderType:
-        """Get the provider type."""
+        """Identify this provider for restrictions and logging."""
         return ProviderType.OPENROUTER

-    def validate_model_name(self, model_name: str) -> bool:
-        """Validate if the model name is allowed.
-
-        As the catch-all provider, OpenRouter accepts any model name that wasn't
-        handled by higher-priority providers. OpenRouter will validate based on
-        the API key's permissions and local restrictions.
-
-        Args:
-            model_name: Model name to validate
-
-        Returns:
-            True if model is allowed, False if restricted
-        """
-        # Check model restrictions if configured
-        from utils.model_restrictions import get_restriction_service
-
-        restriction_service = get_restriction_service()
-        if restriction_service:
-            # Check if model name itself is allowed
-            if restriction_service.is_allowed(self.get_provider_type(), model_name):
-                return True
-
-            # Also check aliases - model_name might be an alias
-            model_config = self._registry.resolve(model_name)
-            if model_config and model_config.aliases:
-                for alias in model_config.aliases:
-                    if restriction_service.is_allowed(self.get_provider_type(), alias):
-                        return True
-
-            # If restrictions are configured and model/alias not in allowed list, reject
-            return False
-
-        # No restrictions configured - accept any model name as the fallback provider
-        return True
-
+    # ------------------------------------------------------------------
+    # Request execution
+    # ------------------------------------------------------------------
     def generate_content(
         self,
@@ -204,6 +148,10 @@ class OpenRouterProvider(OpenAICompatibleProvider):
             **kwargs,
         )

+    # ------------------------------------------------------------------
+    # Registry helpers
+    # ------------------------------------------------------------------
     def list_models(
         self,
         *,
@@ -227,6 +175,12 @@ class OpenRouterProvider(OpenAICompatibleProvider):
             if not config:
                 continue

+            # Custom models belong to CustomProvider; skip them here so the two
+            # providers don't race over the same registrations (important for tests
+            # that stub the registry with minimal objects lacking attrs).
+            if hasattr(config, "is_custom") and config.is_custom is True:
+                continue
+
             if restriction_service:
                 allowed = restriction_service.is_allowed(self.get_provider_type(), model_name)
@@ -255,24 +209,37 @@ class OpenRouterProvider(OpenAICompatibleProvider):
unique=unique, unique=unique,
) )
def get_model_configurations(self) -> dict[str, ModelCapabilities]: # ------------------------------------------------------------------
"""Get model configurations from the registry. # Registry helpers
# ------------------------------------------------------------------
For OpenRouter, we convert registry configurations to ModelCapabilities objects. def _resolve_model_name(self, model_name: str) -> str:
"""Resolve aliases defined in the OpenRouter registry."""
Returns: config = self._registry.resolve(model_name)
Dictionary mapping model names to their ModelCapabilities objects if config:
""" if config.model_name != model_name:
configs = {} logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
return config.model_name
if self._registry: logging.debug(f"Model '{model_name}' not found in registry, using as-is")
# Get all models from registry return model_name
for model_name in self._registry.list_models():
# Only include models that this provider validates
if self.validate_model_name(model_name):
config = self._registry.resolve(model_name)
if config and not config.is_custom: # Only OpenRouter models, not custom ones
# Use ModelCapabilities directly from registry
configs[model_name] = config
return configs def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
"""Expose registry-backed OpenRouter capabilities."""
if not self._registry:
return {}
capabilities: dict[str, ModelCapabilities] = {}
for model_name in self._registry.list_models():
config = self._registry.resolve(model_name)
if not config:
continue
# See note in list_models: respect the CustomProvider boundary.
if hasattr(config, "is_custom") and config.is_custom is True:
continue
capabilities[model_name] = config
return capabilities
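For orientation, a minimal usage sketch of the two registry helpers added above; the import path, the alias, and the printed format are illustrative assumptions, not part of the diff:

```python
import os

from providers.openrouter import OpenRouterProvider  # assumed import path

provider = OpenRouterProvider(api_key=os.environ["OPENROUTER_API_KEY"])

# Aliases resolve through the registry; unknown names pass through unchanged.
canonical = provider._resolve_model_name("opus")

# Registry-backed catalogue; is_custom entries are left to CustomProvider.
for name, caps in sorted(provider.get_all_model_capabilities().items()):
    print(f"{name}: {caps.context_window} tokens")
```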

View File

@@ -64,6 +64,8 @@ class ModelProviderRegistry:
         """
         instance = cls()
         instance._providers[provider_type] = provider_class
+        # Invalidate any cached instance so subsequent lookups use the new registration
+        instance._initialized_providers.pop(provider_type, None)

     @classmethod
     def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]:
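Why the pop matters: `get_provider` caches constructed instances, so re-registering a provider type (as the tests below do) would otherwise keep serving a stale object. A hedged sketch of the guaranteed behaviour, assuming the registry accepts factory callables the way the tests do and eliding API-key checks:

```python
def factory_a(api_key: str = "") -> CustomProvider:
    return CustomProvider(api_key="", base_url="http://localhost:11434/v1")

def factory_b(api_key: str = "") -> CustomProvider:
    return CustomProvider(api_key="", base_url="http://localhost:8080/v1")

ModelProviderRegistry.register_provider(ProviderType.CUSTOM, factory_a)
first = ModelProviderRegistry.get_provider(ProviderType.CUSTOM)  # built and cached

ModelProviderRegistry.register_provider(ProviderType.CUSTOM, factory_b)
second = ModelProviderRegistry.get_provider(ProviderType.CUSTOM)
assert second is not first  # without the pop, the stale `first` would come back
```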

View File

@@ -85,46 +85,10 @@ class XAIModelProvider(OpenAICompatibleProvider):
         kwargs.setdefault("base_url", "https://api.x.ai/v1")
         super().__init__(api_key, **kwargs)

-    def get_capabilities(self, model_name: str) -> ModelCapabilities:
-        """Get capabilities for a specific X.AI model."""
-        # Resolve shorthand
-        resolved_name = self._resolve_model_name(model_name)
-
-        if resolved_name not in self.MODEL_CAPABILITIES:
-            raise ValueError(f"Unsupported X.AI model: {model_name}")
-
-        # Check if model is allowed by restrictions
-        from utils.model_restrictions import get_restriction_service
-
-        restriction_service = get_restriction_service()
-        if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
-            raise ValueError(f"X.AI model '{model_name}' is not allowed by restriction policy.")
-
-        # Return the ModelCapabilities object directly from MODEL_CAPABILITIES
-        return self.MODEL_CAPABILITIES[resolved_name]
-
     def get_provider_type(self) -> ProviderType:
         """Get the provider type."""
         return ProviderType.XAI

-    def validate_model_name(self, model_name: str) -> bool:
-        """Validate if the model name is supported and allowed."""
-        resolved_name = self._resolve_model_name(model_name)
-
-        # First check if model is supported
-        if resolved_name not in self.MODEL_CAPABILITIES:
-            return False
-
-        # Then check if model is allowed by restrictions
-        from utils.model_restrictions import get_restriction_service
-
-        restriction_service = get_restriction_service()
-        if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
-            logger.debug(f"X.AI model '{model_name}' -> '{resolved_name}' blocked by restrictions")
-            return False
-
-        return True
-
     def generate_content(
         self,
         prompt: str,
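The deleted methods are not lost behaviour; the shared base class now supplies them. A hedged reconstruction of that flow (hook names per the guide, exact signatures and internals assumed; the error format matches the updated tests below):

```python
# Sketch of the inherited base-class logic -- not the verbatim source.
def get_capabilities(self, model_name: str) -> ModelCapabilities:
    resolved_name = self._resolve_model_name(model_name)  # alias -> canonical
    capabilities = self._lookup_capabilities(resolved_name, model_name)
    if capabilities is None:
        raise ValueError(
            f"Unsupported model '{model_name}' for provider {self.get_provider_type().value}"
        )
    # Enforce *_ALLOWED_MODELS restriction policy in one shared place.
    self._ensure_model_allowed(capabilities, resolved_name, model_name)
    return capabilities
```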

View File

@@ -54,11 +54,9 @@ class TestCustomProvider:
         provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

-        # Test with a model that should be in the registry (OpenRouter model)
-        capabilities = provider.get_capabilities("o3")  # o3 is an OpenRouter model
-
-        assert capabilities.provider == ProviderType.OPENROUTER  # o3 is an OpenRouter model (is_custom=false)
-        assert capabilities.context_window > 0
+        # OpenRouter-backed models should be handled by the OpenRouter provider
+        with pytest.raises(ValueError):
+            provider.get_capabilities("o3")

         # Test with a custom model (is_custom=true)
         capabilities = provider.get_capabilities("local-llama")
@@ -168,7 +166,13 @@ class TestCustomProviderRegistration:
             return CustomProvider(api_key="", base_url="http://localhost:11434/v1")

         with patch.dict(
-            os.environ, {"OPENROUTER_API_KEY": "test-openrouter-key", "CUSTOM_API_PLACEHOLDER": "configured"}
+            os.environ,
+            {
+                "OPENROUTER_API_KEY": "test-openrouter-key",
+                "CUSTOM_API_PLACEHOLDER": "configured",
+                "OPENROUTER_ALLOWED_MODELS": "llama,anthropic/claude-opus-4.1",
+            },
+            clear=True,
         ):
             # Register both providers
             ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
@@ -195,18 +199,22 @@ class TestCustomProviderRegistration:
             return CustomProvider(api_key="", base_url="http://localhost:11434/v1")

         with patch.dict(
-            os.environ, {"OPENROUTER_API_KEY": "test-openrouter-key", "CUSTOM_API_PLACEHOLDER": "configured"}
+            os.environ,
+            {
+                "OPENROUTER_API_KEY": "test-openrouter-key",
+                "CUSTOM_API_PLACEHOLDER": "configured",
+                "OPENROUTER_ALLOWED_MODELS": "",
+            },
+            clear=True,
         ):
-            # Register OpenRouter first (higher priority)
-            ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
-            ModelProviderRegistry.register_provider(ProviderType.CUSTOM, custom_provider_factory)
-
-            # Test model resolution - OpenRouter should win for shared aliases
-            provider_for_model = ModelProviderRegistry.get_provider_for_model("llama")
-
-            # OpenRouter should be selected first due to registration order
-            assert provider_for_model is not None
-            # The exact provider type depends on which validates the model first
+            import utils.model_restrictions
+
+            utils.model_restrictions._restriction_service = None
+
+            custom_provider = custom_provider_factory()
+            openrouter_provider = OpenRouterProvider(api_key="test-openrouter-key")
+
+            assert not custom_provider.validate_model_name("llama")
+            assert openrouter_provider.validate_model_name("llama")


 class TestConfigureProvidersFunction:
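The singleton reset used in that test, isolated as a reusable pattern; the env value and API key are illustrative:

```python
import os
from unittest.mock import patch

import utils.model_restrictions

with patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "llama"}):
    # The restriction service is a cached module-level singleton; drop it so
    # the patched environment is actually re-read by the provider under test.
    utils.model_restrictions._restriction_service = None
    provider = OpenRouterProvider(api_key="test-openrouter-key")
    assert provider.validate_model_name("llama")
```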

View File

@@ -121,7 +121,7 @@ class TestDIALProvider:
         """Test that get_capabilities raises for invalid models."""
         provider = DIALModelProvider("test-key")

-        with pytest.raises(ValueError, match="Unsupported DIAL model"):
+        with pytest.raises(ValueError, match="Unsupported model 'invalid-model' for provider dial"):
            provider.get_capabilities("invalid-model")

    @patch("utils.model_restrictions.get_restriction_service")

View File

@@ -356,15 +356,13 @@ class TestCustomProviderOpenRouterRestrictions:
         provider = CustomProvider(base_url="http://test.com/v1")

-        # For OpenRouter models, get_capabilities should still work but mark them as OPENROUTER
-        # This tests the capabilities lookup, not validation
-        capabilities = provider.get_capabilities("opus")
-        assert capabilities.provider == ProviderType.OPENROUTER
+        # For OpenRouter models, CustomProvider should defer by raising
+        with pytest.raises(ValueError):
+            provider.get_capabilities("opus")

-        # Should raise for disallowed OpenRouter model
-        with pytest.raises(ValueError) as exc_info:
+        # Should raise for disallowed OpenRouter model (still defers)
+        with pytest.raises(ValueError):
             provider.get_capabilities("haiku")
-        assert "not allowed by restriction policy" in str(exc_info.value)

         # Should still work for custom models (is_custom=true)
         capabilities = provider.get_capabilities("local-llama")

View File

@@ -141,7 +141,7 @@ class TestXAIProvider:
         """Test error handling for unsupported models."""
         provider = XAIModelProvider("test-key")

-        with pytest.raises(ValueError, match="Unsupported X.AI model"):
+        with pytest.raises(ValueError, match="Unsupported model 'invalid-model' for provider xai"):
            provider.get_capabilities("invalid-model")

    def test_extended_thinking_flags(self):

View File

@@ -105,13 +105,11 @@ class ListModelsTool(BaseTool):
                 output_lines.append("**Status**: Configured and available")
                 output_lines.append("\n**Models**:")

-                # Get models from the provider's model configurations
-                for model_name, capabilities in provider.get_model_configurations().items():
-                    # Get description and context from the ModelCapabilities object
+                aliases = []
+                for model_name, capabilities in provider.get_all_model_capabilities().items():
                     description = capabilities.description or "No description available"
                     context_window = capabilities.context_window

-                    # Format context window
                     if context_window >= 1_000_000:
                         context_str = f"{context_window // 1_000_000}M context"
                     elif context_window >= 1_000:
@@ -120,31 +118,15 @@ class ListModelsTool(BaseTool):
                         context_str = f"{context_window} context" if context_window > 0 else "unknown context"

                     output_lines.append(f"- `{model_name}` - {context_str}")
-                    output_lines.append(f"  - {description}")

-                    # Extract key capability from description
-                    if "Ultra-fast" in description:
-                        output_lines.append("  - Fast processing, quick iterations")
-                    elif "Deep reasoning" in description:
-                        output_lines.append("  - Extended reasoning with thinking mode")
-                    elif "Strong reasoning" in description:
-                        output_lines.append("  - Logical problems, systematic analysis")
-                    elif "EXTREMELY EXPENSIVE" in description:
-                        output_lines.append("  - ⚠️ Professional grade (very expensive)")
-                    elif "Advanced reasoning" in description:
-                        output_lines.append("  - Advanced reasoning and complex analysis")
-
-                # Show aliases for this provider
-                aliases = []
-                for model_name, capabilities in provider.get_model_configurations().items():
-                    if capabilities.aliases:
-                        for alias in capabilities.aliases:
-                            # Skip aliases that are the same as the model name to avoid duplicates
-                            if alias != model_name:
-                                aliases.append(f"- `{alias}` → `{model_name}`")
+                    for alias in capabilities.aliases or []:
+                        if alias != model_name:
+                            aliases.append(f"- `{alias}` → `{model_name}`")

                 if aliases:
                     output_lines.append("\n**Aliases**:")
-                    output_lines.extend(sorted(aliases))  # Sort for consistent output
+                    output_lines.extend(sorted(aliases))
             else:
                 output_lines.append(f"**Status**: Not configured (set {info['env_key']})")