refactor: cleanup provider base class; cleanup shared responsibilities; cleanup public contract
docs: document provider base class
refactor: cleanup custom provider, it should only deal with `is_custom` model configurations
fix: make sure openrouter provider does not load `is_custom` models
fix: listmodels tool cleanup
This commit is contained in:
@@ -7,7 +7,7 @@ This guide explains how to add support for a new AI model provider to the Zen MC
|
|||||||
Each provider:
|
Each provider:
|
||||||
- Inherits from `ModelProvider` (base class) or `OpenAICompatibleProvider` (for OpenAI-compatible APIs)
|
- Inherits from `ModelProvider` (base class) or `OpenAICompatibleProvider` (for OpenAI-compatible APIs)
|
||||||
- Defines supported models using `ModelCapabilities` objects
|
- Defines supported models using `ModelCapabilities` objects
|
||||||
- Implements a few core abstract methods
|
- Implements the minimal abstract hooks (`get_provider_type()` and `generate_content()`)
|
||||||
- Gets registered automatically via environment variables
|
- Gets registered automatically via environment variables
|
||||||
|
|
||||||
## Choose Your Implementation Path
|
## Choose Your Implementation Path
|
||||||
@@ -15,11 +15,11 @@ Each provider:
|
|||||||
**Option A: Full Provider (`ModelProvider`)**
|
**Option A: Full Provider (`ModelProvider`)**
|
||||||
- For APIs with unique features or custom authentication
|
- For APIs with unique features or custom authentication
|
||||||
- Complete control over API calls and response handling
|
- Complete control over API calls and response handling
|
||||||
- Required methods: `generate_content()`, `get_capabilities()`, `validate_model_name()`, `get_provider_type()` (override `count_tokens()` only when you have a provider-accurate tokenizer)
|
- Implement `generate_content()` and `get_provider_type()`; override `get_all_model_capabilities()` to expose your catalogue and extend `_lookup_capabilities()` / `_ensure_model_allowed()` only when you need registry lookups or custom restriction rules (override `count_tokens()` only when you have a provider-accurate tokenizer)
|
||||||
|
|
||||||
**Option B: OpenAI-Compatible (`OpenAICompatibleProvider`)**
|
**Option B: OpenAI-Compatible (`OpenAICompatibleProvider`)**
|
||||||
- For APIs that follow OpenAI's chat completion format
|
- For APIs that follow OpenAI's chat completion format
|
||||||
- Only need to define: model configurations, capabilities, and validation
|
- Supply `MODEL_CAPABILITIES`, override `get_provider_type()`, and optionally adjust configuration (the base class handles alias resolution, validation, and request wiring)
|
||||||
- Inherits all API handling automatically
|
- Inherits all API handling automatically
|
||||||
|
|
||||||
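⚠️ **Important**: If using aliases (like `"gpt"` → `"gpt-4"`), declare them on `ModelCapabilities`. A full provider should call `_resolve_model_name()` inside its own `generate_content()` before API calls; `OpenAICompatibleProvider` already resolves aliases for you.
|
⚠️ **Important**: If using aliases (like `"gpt"` → `"gpt-4"`), declare them on `ModelCapabilities`. A full provider should call `_resolve_model_name()` inside its own `generate_content()` before API calls; `OpenAICompatibleProvider` already resolves aliases for you.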
|
||||||
@@ -62,8 +62,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class ExampleModelProvider(ModelProvider):
|
class ExampleModelProvider(ModelProvider):
|
||||||
"""Example model provider implementation."""
|
"""Example model provider implementation."""
|
||||||
|
|
||||||
# Define models using ModelCapabilities objects (like Gemini provider)
|
|
||||||
MODEL_CAPABILITIES = {
|
MODEL_CAPABILITIES = {
|
||||||
"example-large": ModelCapabilities(
|
"example-large": ModelCapabilities(
|
||||||
provider=ProviderType.EXAMPLE,
|
provider=ProviderType.EXAMPLE,
|
||||||
@@ -87,51 +86,47 @@ class ExampleModelProvider(ModelProvider):
|
|||||||
aliases=["small", "fast"],
|
aliases=["small", "fast"],
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, api_key: str, **kwargs):
|
def __init__(self, api_key: str, **kwargs):
|
||||||
super().__init__(api_key, **kwargs)
|
super().__init__(api_key, **kwargs)
|
||||||
# Initialize your API client here
|
# Initialize your API client here
|
||||||
|
|
||||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||||
|
return dict(self.MODEL_CAPABILITIES)
|
||||||
|
|
||||||
|
def get_provider_type(self) -> ProviderType:
|
||||||
|
return ProviderType.EXAMPLE
|
||||||
|
|
||||||
|
def generate_content(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model_name: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_output_tokens: Optional[int] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> ModelResponse:
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
if resolved_name not in self.MODEL_CAPABILITIES:
|
|
||||||
raise ValueError(f"Unsupported model: {model_name}")
|
|
||||||
|
|
||||||
# Apply restrictions if needed
|
|
||||||
from utils.model_restrictions import get_restriction_service
|
|
||||||
restriction_service = get_restriction_service()
|
|
||||||
if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
|
|
||||||
raise ValueError(f"Model '{model_name}' is not allowed.")
|
|
||||||
|
|
||||||
return self.MODEL_CAPABILITIES[resolved_name]
|
|
||||||
|
|
||||||
def generate_content(self, prompt: str, model_name: str, system_prompt: Optional[str] = None,
|
|
||||||
temperature: float = 0.7, max_output_tokens: Optional[int] = None, **kwargs) -> ModelResponse:
|
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
|
||||||
|
|
||||||
# Your API call logic here
|
# Your API call logic here
|
||||||
# response = your_api_client.generate(...)
|
# response = your_api_client.generate(...)
|
||||||
|
|
||||||
return ModelResponse(
|
return ModelResponse(
|
||||||
content="Generated response", # From your API
|
content="Generated response",
|
||||||
usage={"input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
|
usage={"input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
|
||||||
model_name=resolved_name,
|
model_name=resolved_name,
|
||||||
friendly_name="Example",
|
friendly_name="Example",
|
||||||
provider=ProviderType.EXAMPLE,
|
provider=ProviderType.EXAMPLE,
|
||||||
)
|
)
|
||||||
def get_provider_type(self) -> ProviderType:
|
|
||||||
return ProviderType.EXAMPLE
|
|
||||||
|
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
|
||||||
return resolved_name in self.MODEL_CAPABILITIES
|
|
||||||
```
|
```
|
||||||
|
|
||||||
`ModelProvider.count_tokens()` uses a simple 4-characters-per-token estimate so
|
`ModelProvider.get_capabilities()` automatically resolves aliases, enforces the
|
||||||
providers work out of the box. Override the method only when you can call into
|
shared restriction service, and returns the correct `ModelCapabilities`
|
||||||
the provider's real tokenizer (for example, the OpenAI-compatible base class
|
instance. Override `_lookup_capabilities()` only when you source capabilities
|
||||||
already integrates `tiktoken`).
|
from a registry or remote API. `ModelProvider.count_tokens()` uses a simple
|
||||||
|
4-characters-per-token estimate so providers work out of the box—override it
|
||||||
|
only when you can call the provider's real tokenizer (for example, the
|
||||||
|
OpenAI-compatible base class integrates `tiktoken`).
|
||||||
|
|
||||||
#### Option B: OpenAI-Compatible Provider (Simplified)
|
#### Option B: OpenAI-Compatible Provider (Simplified)
|
||||||
|
|
||||||
@@ -172,26 +167,16 @@ class ExampleProvider(OpenAICompatibleProvider):
|
|||||||
def __init__(self, api_key: str, **kwargs):
|
def __init__(self, api_key: str, **kwargs):
|
||||||
kwargs.setdefault("base_url", "https://api.example.com/v1")
|
kwargs.setdefault("base_url", "https://api.example.com/v1")
|
||||||
super().__init__(api_key, **kwargs)
|
super().__init__(api_key, **kwargs)
|
||||||
|
|
||||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
|
||||||
if resolved_name not in self.MODEL_CAPABILITIES:
|
|
||||||
raise ValueError(f"Unsupported model: {model_name}")
|
|
||||||
return self.MODEL_CAPABILITIES[resolved_name]
|
|
||||||
|
|
||||||
def get_provider_type(self) -> ProviderType:
|
def get_provider_type(self) -> ProviderType:
|
||||||
return ProviderType.EXAMPLE
|
return ProviderType.EXAMPLE
|
||||||
|
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
|
||||||
return resolved_name in self.MODEL_CAPABILITIES
|
|
||||||
|
|
||||||
def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
|
|
||||||
# IMPORTANT: Resolve aliases before API call
|
|
||||||
resolved_model_name = self._resolve_model_name(model_name)
|
|
||||||
return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
`OpenAICompatibleProvider` already exposes the declared models via
|
||||||
|
`MODEL_CAPABILITIES`, resolves aliases through the shared base pipeline, and
|
||||||
|
enforces restrictions. Most subclasses only need to provide the class metadata
|
||||||
|
shown above.
|
||||||
|
|
||||||
### 3. Register Your Provider
|
### 3. Register Your Provider
|
||||||
|
|
||||||
Add environment variable mapping in `providers/registry.py`:
|
Add environment variable mapping in `providers/registry.py`:
|
||||||
@@ -237,15 +222,11 @@ DISABLED_TOOLS=debug,tracer
|
|||||||
Create basic tests to verify your implementation:
|
Create basic tests to verify your implementation:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# Test model validation
|
|
||||||
provider = ExampleModelProvider("test-key")
|
|
||||||
assert provider.validate_model_name("large") == True
|
|
||||||
assert provider.validate_model_name("unknown") == False
|
|
||||||
|
|
||||||
# Test capabilities
|
# Test capabilities
|
||||||
caps = provider.get_capabilities("large")
|
provider = ExampleModelProvider("test-key")
|
||||||
assert caps.context_window > 0
|
capabilities = provider.get_capabilities("large")
|
||||||
assert caps.provider == ProviderType.EXAMPLE
|
assert capabilities.context_window > 0
|
||||||
|
assert capabilities.provider == ProviderType.EXAMPLE
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
@@ -259,31 +240,19 @@ When a user requests a model, providers are checked in priority order:
|
|||||||
3. **OpenRouter** - catch-all for everything else
|
3. **OpenRouter** - catch-all for everything else
|
||||||
|
|
||||||
### Model Validation
|
### Model Validation
|
||||||
Your `validate_model_name()` should **only** return `True` for models you explicitly support:
|
`ModelProvider.validate_model_name()` delegates to `get_capabilities()` so most
|
||||||
|
providers can rely on the shared implementation. Override it only when you need
|
||||||
```python
|
to opt out of that pipeline—for example, `CustomProvider` declines OpenRouter
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
models so they fall through to the dedicated OpenRouter provider.
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
|
||||||
return resolved_name in self.MODEL_CAPABILITIES # Be specific!
|
|
||||||
```
|
|
||||||
|
|
||||||
### Model Aliases
|
### Model Aliases
|
||||||
The base class handles alias resolution automatically via the `aliases` field in `ModelCapabilities`.
|
Aliases declared on `ModelCapabilities` are applied automatically via
|
||||||
|
`_resolve_model_name()`, and both the validation and request flows call it
|
||||||
|
before touching your SDK. Override `generate_content()` only when your provider
|
||||||
|
needs additional alias handling beyond the shared behaviour.
|
||||||
|
|
||||||
## Important Notes
|
## Important Notes
|
||||||
|
|
||||||
### Alias Resolution in OpenAI-Compatible Providers
|
|
||||||
If using `OpenAICompatibleProvider` with aliases, **you must override `generate_content()`** to resolve aliases before API calls:
|
|
||||||
|
|
||||||
```python
|
|
||||||
def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
|
|
||||||
# Resolve alias before API call
|
|
||||||
resolved_model_name = self._resolve_model_name(model_name)
|
|
||||||
return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs)
|
|
||||||
```
|
|
||||||
|
|
||||||
Without this, API calls with aliases like `"large"` will fail because your API doesn't recognize the alias.
|
|
||||||
|
|
||||||
## Best Practices
|
## Best Practices
|
||||||
|
|
||||||
- **Be specific in model validation** - only accept models you actually support
|
- **Be specific in model validation** - only accept models you actually support
|
||||||
|
|||||||
@@ -43,128 +43,37 @@ class ModelProvider(ABC):
|
|||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.config = kwargs
|
self.config = kwargs
|
||||||
|
|
||||||
@abstractmethod
|
# ------------------------------------------------------------------
|
||||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
# Provider identity & capability surface
|
||||||
"""Get capabilities for a specific model."""
|
# ------------------------------------------------------------------
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def generate_content(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
model_name: str,
|
|
||||||
system_prompt: Optional[str] = None,
|
|
||||||
temperature: float = 0.3,
|
|
||||||
max_output_tokens: Optional[int] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> ModelResponse:
|
|
||||||
"""Generate content using the model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt: User prompt to send to the model
|
|
||||||
model_name: Name of the model to use
|
|
||||||
system_prompt: Optional system prompt for model behavior
|
|
||||||
temperature: Sampling temperature (0-2)
|
|
||||||
max_output_tokens: Maximum tokens to generate
|
|
||||||
**kwargs: Provider-specific parameters
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ModelResponse with generated content and metadata
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def count_tokens(self, text: str, model_name: str) -> int:
|
|
||||||
"""Estimate token usage for a piece of text.
|
|
||||||
|
|
||||||
Providers can rely on this shared implementation or override it when
|
|
||||||
they expose a more accurate tokenizer. This default uses a simple
|
|
||||||
character-based heuristic so it works even without provider-specific
|
|
||||||
tooling.
|
|
||||||
"""
|
|
||||||
|
|
||||||
resolved_model = self._resolve_model_name(model_name)
|
|
||||||
|
|
||||||
if not text:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# Rough estimation: ~4 characters per token for English text
|
|
||||||
estimated = max(1, len(text) // 4)
|
|
||||||
logger.debug("Estimating %s tokens for model %s via character heuristic", estimated, resolved_model)
|
|
||||||
return estimated
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_provider_type(self) -> ProviderType:
|
def get_provider_type(self) -> ProviderType:
|
||||||
"""Get the provider type."""
|
"""Return the concrete provider identity."""
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
"""Resolve capability metadata for a model name.
|
||||||
"""Validate if the model name is supported by this provider."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
|
This centralises the alias resolution → lookup → restriction check
|
||||||
"""Validate model parameters against capabilities.
|
pipeline so providers only override the pieces they genuinely need to
|
||||||
|
customise. Subclasses usually only override ``_lookup_capabilities`` to
|
||||||
Raises:
|
integrate a registry or dynamic source, or ``_finalise_capabilities`` to
|
||||||
ValueError: If parameters are invalid
|
tweak the returned object.
|
||||||
"""
|
"""
|
||||||
capabilities = self.get_capabilities(model_name)
|
|
||||||
|
|
||||||
# Validate temperature using constraint
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
if not capabilities.temperature_constraint.validate(temperature):
|
capabilities = self._lookup_capabilities(resolved_name, model_name)
|
||||||
constraint_desc = capabilities.temperature_constraint.get_description()
|
|
||||||
raise ValueError(f"Temperature {temperature} is invalid for model {model_name}. {constraint_desc}")
|
|
||||||
|
|
||||||
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
|
if capabilities is None:
|
||||||
"""Get model configurations for this provider.
|
self._raise_unsupported_model(model_name)
|
||||||
|
|
||||||
This is a hook method that subclasses can override to provide
|
self._ensure_model_allowed(capabilities, resolved_name, model_name)
|
||||||
their model configurations from different sources.
|
return self._finalise_capabilities(capabilities, resolved_name, model_name)
|
||||||
|
|
||||||
|
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||||
|
"""Return the provider's statically declared model capabilities."""
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary mapping model names to their ModelCapabilities objects
|
|
||||||
"""
|
|
||||||
model_map = getattr(self, "MODEL_CAPABILITIES", None)
|
|
||||||
if isinstance(model_map, dict) and model_map:
|
|
||||||
return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)}
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _resolve_model_name(self, model_name: str) -> str:
|
|
||||||
"""Resolve model shorthand to full name.
|
|
||||||
|
|
||||||
This implementation uses the hook methods to support different
|
|
||||||
model configuration sources.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: Model name that may be an alias
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Resolved model name
|
|
||||||
"""
|
|
||||||
# Get model configurations from the hook method
|
|
||||||
model_configs = self.get_model_configurations()
|
|
||||||
|
|
||||||
# First check if it's already a base model name (case-sensitive exact match)
|
|
||||||
if model_name in model_configs:
|
|
||||||
return model_name
|
|
||||||
|
|
||||||
# Check case-insensitively for both base models and aliases
|
|
||||||
model_name_lower = model_name.lower()
|
|
||||||
|
|
||||||
# Check base model names case-insensitively
|
|
||||||
for base_model in model_configs:
|
|
||||||
if base_model.lower() == model_name_lower:
|
|
||||||
return base_model
|
|
||||||
|
|
||||||
# Check aliases from the model configurations
|
|
||||||
alias_map = ModelCapabilities.collect_aliases(model_configs)
|
|
||||||
for base_model, aliases in alias_map.items():
|
|
||||||
if any(alias.lower() == model_name_lower for alias in aliases):
|
|
||||||
return base_model
|
|
||||||
|
|
||||||
# If not found, return as-is
|
|
||||||
return model_name
|
|
||||||
|
|
||||||
def list_models(
|
def list_models(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -175,7 +84,7 @@ class ModelProvider(ABC):
|
|||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""Return formatted model names supported by this provider."""
|
"""Return formatted model names supported by this provider."""
|
||||||
|
|
||||||
model_configs = self.get_model_configurations()
|
model_configs = self.get_all_model_capabilities()
|
||||||
if not model_configs:
|
if not model_configs:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -202,36 +111,155 @@ class ModelProvider(ABC):
|
|||||||
unique=unique,
|
unique=unique,
|
||||||
)
|
)
|
||||||
|
|
||||||
def close(self):
|
# ------------------------------------------------------------------
|
||||||
"""Clean up any resources held by the provider.
|
# Request execution
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
@abstractmethod
|
||||||
|
def generate_content(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model_name: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
temperature: float = 0.3,
|
||||||
|
max_output_tokens: Optional[int] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> ModelResponse:
|
||||||
|
"""Generate content using the model."""
|
||||||
|
|
||||||
|
def count_tokens(self, text: str, model_name: str) -> int:
|
||||||
|
"""Estimate token usage for a piece of text."""
|
||||||
|
|
||||||
|
resolved_model = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
estimated = max(1, len(text) // 4)
|
||||||
|
logger.debug("Estimating %s tokens for model %s via character heuristic", estimated, resolved_model)
|
||||||
|
return estimated
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Clean up any resources held by the provider."""
|
||||||
|
|
||||||
Default implementation does nothing.
|
|
||||||
Subclasses should override if they hold resources that need cleanup.
|
|
||||||
"""
|
|
||||||
# Base implementation: no resources to clean up
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Validation hooks
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
def validate_model_name(self, model_name: str) -> bool:
|
||||||
|
"""Return ``True`` when the model resolves to an allowed capability."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.get_capabilities(model_name)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
|
||||||
|
"""Validate model parameters against capabilities."""
|
||||||
|
|
||||||
|
capabilities = self.get_capabilities(model_name)
|
||||||
|
|
||||||
|
if not capabilities.temperature_constraint.validate(temperature):
|
||||||
|
constraint_desc = capabilities.temperature_constraint.get_description()
|
||||||
|
raise ValueError(f"Temperature {temperature} is invalid for model {model_name}. {constraint_desc}")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Preference / registry hooks
|
||||||
|
# ------------------------------------------------------------------
|
||||||
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
|
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
|
||||||
"""Get the preferred model from this provider for a given category.
|
"""Get the preferred model from this provider for a given category."""
|
||||||
|
|
||||||
Args:
|
|
||||||
category: The tool category requiring a model
|
|
||||||
allowed_models: Pre-filtered list of model names that are allowed by restrictions
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Model name if this provider has a preference, None otherwise
|
|
||||||
"""
|
|
||||||
# Default implementation - providers can override with specific logic
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_model_registry(self) -> Optional[dict[str, Any]]:
|
def get_model_registry(self) -> Optional[dict[str, Any]]:
|
||||||
"""Get the model registry for providers that maintain one.
|
"""Return the model registry backing this provider, if any."""
|
||||||
|
|
||||||
This is a hook method for providers like CustomProvider that maintain
|
return None
|
||||||
a dynamic model registry.
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Capability lookup pipeline
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
def _lookup_capabilities(
|
||||||
|
self,
|
||||||
|
canonical_name: str,
|
||||||
|
requested_name: Optional[str] = None,
|
||||||
|
) -> Optional[ModelCapabilities]:
|
||||||
|
"""Return ``ModelCapabilities`` for the canonical model name."""
|
||||||
|
|
||||||
|
return self.get_all_model_capabilities().get(canonical_name)
|
||||||
|
|
||||||
|
def _ensure_model_allowed(
|
||||||
|
self,
|
||||||
|
capabilities: ModelCapabilities,
|
||||||
|
canonical_name: str,
|
||||||
|
requested_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Raise ``ValueError`` if the model violates restriction policy."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
except Exception: # pragma: no cover - only triggered if service import breaks
|
||||||
|
return
|
||||||
|
|
||||||
|
restriction_service = get_restriction_service()
|
||||||
|
if not restriction_service:
|
||||||
|
return
|
||||||
|
|
||||||
|
if restriction_service.is_allowed(self.get_provider_type(), canonical_name, requested_name):
|
||||||
|
return
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.get_provider_type().value} model '{canonical_name}' is not allowed by restriction policy."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _finalise_capabilities(
|
||||||
|
self,
|
||||||
|
capabilities: ModelCapabilities,
|
||||||
|
canonical_name: str,
|
||||||
|
requested_name: str,
|
||||||
|
) -> ModelCapabilities:
|
||||||
|
"""Allow subclasses to adjust capability metadata before returning."""
|
||||||
|
|
||||||
|
return capabilities
|
||||||
|
|
||||||
|
def _raise_unsupported_model(self, model_name: str) -> None:
|
||||||
|
"""Raise the canonical unsupported-model error."""
|
||||||
|
|
||||||
|
raise ValueError(f"Unsupported model '{model_name}' for provider {self.get_provider_type().value}.")
|
||||||
|
|
||||||
|
def _resolve_model_name(self, model_name: str) -> str:
|
||||||
|
"""Resolve model shorthand to full name.
|
||||||
|
|
||||||
|
This implementation uses the hook methods to support different
|
||||||
|
model configuration sources.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Model name that may be an alias
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Model registry dict or None if not applicable
|
Resolved model name
|
||||||
"""
|
"""
|
||||||
# Default implementation - most providers don't have a registry
|
# Get model configurations from the hook method
|
||||||
return None
|
model_configs = self.get_all_model_capabilities()
|
||||||
|
|
||||||
|
# First check if it's already a base model name (case-sensitive exact match)
|
||||||
|
if model_name in model_configs:
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
# Check case-insensitively for both base models and aliases
|
||||||
|
model_name_lower = model_name.lower()
|
||||||
|
|
||||||
|
# Check base model names case-insensitively
|
||||||
|
for base_model in model_configs:
|
||||||
|
if base_model.lower() == model_name_lower:
|
||||||
|
return base_model
|
||||||
|
|
||||||
|
# Check aliases from the model configurations
|
||||||
|
alias_map = ModelCapabilities.collect_aliases(model_configs)
|
||||||
|
for base_model, aliases in alias_map.items():
|
||||||
|
if any(alias.lower() == model_name_lower for alias in aliases):
|
||||||
|
return base_model
|
||||||
|
|
||||||
|
# If not found, return as-is
|
||||||
|
return model_name
|
||||||
|
|||||||
@@ -83,117 +83,69 @@ class CustomProvider(OpenAICompatibleProvider):
|
|||||||
aliases = self._registry.list_aliases()
|
aliases = self._registry.list_aliases()
|
||||||
logging.info(f"Custom provider loaded {len(models)} models with {len(aliases)} aliases")
|
logging.info(f"Custom provider loaded {len(models)} models with {len(aliases)} aliases")
|
||||||
|
|
||||||
def _resolve_model_name(self, model_name: str) -> str:
|
# ------------------------------------------------------------------
|
||||||
"""Resolve model aliases to actual model names.
|
# Capability surface
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
def _lookup_capabilities(
|
||||||
|
self,
|
||||||
|
canonical_name: str,
|
||||||
|
requested_name: Optional[str] = None,
|
||||||
|
) -> Optional[ModelCapabilities]:
|
||||||
|
"""Return custom capabilities from the registry or generic defaults."""
|
||||||
|
|
||||||
For Ollama-style models, strips version tags (e.g., 'llama3.2:latest' -> 'llama3.2')
|
builtin = super()._lookup_capabilities(canonical_name, requested_name)
|
||||||
since the base model name is what's typically used in API calls.
|
if builtin is not None:
|
||||||
|
return builtin
|
||||||
Args:
|
|
||||||
model_name: Input model name or alias
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Resolved model name with version tags stripped if applicable
|
|
||||||
"""
|
|
||||||
# First, try to resolve through registry as-is
|
|
||||||
config = self._registry.resolve(model_name)
|
|
||||||
|
|
||||||
if config:
|
|
||||||
if config.model_name != model_name:
|
|
||||||
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
|
|
||||||
return config.model_name
|
|
||||||
else:
|
|
||||||
# If not found in registry, handle version tags for local models
|
|
||||||
# Strip version tags (anything after ':') for Ollama-style models
|
|
||||||
if ":" in model_name:
|
|
||||||
base_model = model_name.split(":")[0]
|
|
||||||
logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'")
|
|
||||||
|
|
||||||
# Try to resolve the base model through registry
|
|
||||||
base_config = self._registry.resolve(base_model)
|
|
||||||
if base_config:
|
|
||||||
logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
|
|
||||||
return base_config.model_name
|
|
||||||
else:
|
|
||||||
return base_model
|
|
||||||
else:
|
|
||||||
# If not found in registry and no version tag, return as-is
|
|
||||||
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
|
||||||
return model_name
|
|
||||||
|
|
||||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
|
||||||
"""Get capabilities for a custom model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: Name of the model (or alias)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ModelCapabilities from registry or generic defaults
|
|
||||||
"""
|
|
||||||
# Try to get from registry first
|
|
||||||
capabilities = self._registry.get_capabilities(model_name)
|
|
||||||
|
|
||||||
|
capabilities = self._registry.get_capabilities(canonical_name)
|
||||||
if capabilities:
|
if capabilities:
|
||||||
# Check if this is an OpenRouter model and apply restrictions
|
config = self._registry.resolve(canonical_name)
|
||||||
config = self._registry.resolve(model_name)
|
if config and getattr(config, "is_custom", False):
|
||||||
if config and not config.is_custom:
|
|
||||||
# This is an OpenRouter model, check restrictions
|
|
||||||
from utils.model_restrictions import get_restriction_service
|
|
||||||
|
|
||||||
restriction_service = get_restriction_service()
|
|
||||||
if not restriction_service.is_allowed(ProviderType.OPENROUTER, config.model_name, model_name):
|
|
||||||
raise ValueError(f"OpenRouter model '{model_name}' is not allowed by restriction policy.")
|
|
||||||
|
|
||||||
# Update provider type to OPENROUTER for OpenRouter models
|
|
||||||
capabilities.provider = ProviderType.OPENROUTER
|
|
||||||
else:
|
|
||||||
# Update provider type to CUSTOM for local custom models
|
|
||||||
capabilities.provider = ProviderType.CUSTOM
|
capabilities.provider = ProviderType.CUSTOM
|
||||||
return capabilities
|
return capabilities
|
||||||
else:
|
# Non-custom models should fall through so OpenRouter handles them
|
||||||
# Resolve any potential aliases and create generic capabilities
|
return None
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
|
||||||
|
|
||||||
logging.debug(
|
logging.debug(
|
||||||
f"Using generic capabilities for '{resolved_name}' via Custom API. "
|
f"Using generic capabilities for '{canonical_name}' via Custom API. "
|
||||||
"Consider adding to custom_models.json for specific capabilities."
|
"Consider adding to custom_models.json for specific capabilities."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Infer temperature behaviour for generic capability fallback
|
supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings(
|
||||||
supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings(
|
canonical_name
|
||||||
resolved_name
|
)
|
||||||
)
|
|
||||||
|
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Model '{resolved_name}' not found in custom_models.json. Using generic capabilities with inferred settings. "
|
f"Model '{canonical_name}' not found in custom_models.json. Using generic capabilities with inferred settings. "
|
||||||
f"Temperature support: {supports_temperature} ({temperature_reason}). "
|
f"Temperature support: {supports_temperature} ({temperature_reason}). "
|
||||||
"For better accuracy, add this model to your custom_models.json configuration."
|
"For better accuracy, add this model to your custom_models.json configuration."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create generic capabilities with inferred defaults
|
generic = ModelCapabilities(
|
||||||
capabilities = ModelCapabilities(
|
provider=ProviderType.CUSTOM,
|
||||||
provider=ProviderType.CUSTOM,
|
model_name=canonical_name,
|
||||||
model_name=resolved_name,
|
friendly_name=f"{self.FRIENDLY_NAME} ({canonical_name})",
|
||||||
friendly_name=f"{self.FRIENDLY_NAME} ({resolved_name})",
|
context_window=32_768,
|
||||||
context_window=32_768, # Conservative default
|
max_output_tokens=32_768,
|
||||||
max_output_tokens=32_768, # Conservative default max output
|
supports_extended_thinking=False,
|
||||||
supports_extended_thinking=False, # Most custom models don't support this
|
supports_system_prompts=True,
|
||||||
supports_system_prompts=True,
|
supports_streaming=True,
|
||||||
supports_streaming=True,
|
supports_function_calling=False,
|
||||||
supports_function_calling=False, # Conservative default
|
supports_temperature=supports_temperature,
|
||||||
supports_temperature=supports_temperature,
|
temperature_constraint=temperature_constraint,
|
||||||
temperature_constraint=temperature_constraint,
|
)
|
||||||
)
|
generic._is_generic = True
|
||||||
|
return generic
|
||||||
# Mark as generic for validation purposes
|
|
||||||
capabilities._is_generic = True
|
|
||||||
|
|
||||||
return capabilities
|
|
||||||
|
|
||||||
def get_provider_type(self) -> ProviderType:
|
def get_provider_type(self) -> ProviderType:
|
||||||
"""Get the provider type."""
|
"""Identify this provider for restriction and logging logic."""
|
||||||
|
|
||||||
return ProviderType.CUSTOM
|
return ProviderType.CUSTOM
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Validation
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
def validate_model_name(self, model_name: str) -> bool:
|
||||||
"""Validate if the model name is allowed.
|
"""Validate if the model name is allowed.
|
||||||
|
|
||||||
@@ -206,49 +158,41 @@ class CustomProvider(OpenAICompatibleProvider):
|
|||||||
Returns:
|
Returns:
|
||||||
True if model is intended for custom/local endpoint
|
True if model is intended for custom/local endpoint
|
||||||
"""
|
"""
|
||||||
# logging.debug(f"Custom provider validating model: '{model_name}'")
|
if super().validate_model_name(model_name):
|
||||||
|
return True
|
||||||
|
|
||||||
# Try to resolve through registry first
|
|
||||||
config = self._registry.resolve(model_name)
|
config = self._registry.resolve(model_name)
|
||||||
if config:
|
if config and not getattr(config, "is_custom", False):
|
||||||
model_id = config.model_name
|
return False
|
||||||
# Use explicit is_custom flag for clean validation
|
|
||||||
if config.is_custom:
|
|
||||||
logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' validated via registry")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
# This is a cloud/OpenRouter model - CustomProvider should NOT handle these
|
|
||||||
# Let OpenRouter provider handle them instead
|
|
||||||
# logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' not custom (defer to OpenRouter)")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Handle version tags for unknown models (e.g., "my-model:latest")
|
|
||||||
clean_model_name = model_name
|
clean_model_name = model_name
|
||||||
if ":" in model_name:
|
if ":" in model_name:
|
||||||
clean_model_name = model_name.split(":")[0]
|
clean_model_name = model_name.split(":", 1)[0]
|
||||||
logging.debug(f"Stripped version tag from '{model_name}' -> '{clean_model_name}'")
|
logging.debug(f"Stripped version tag from '{model_name}' -> '{clean_model_name}'")
|
||||||
# Try to resolve the clean name
|
|
||||||
|
if super().validate_model_name(clean_model_name):
|
||||||
|
return True
|
||||||
|
|
||||||
config = self._registry.resolve(clean_model_name)
|
config = self._registry.resolve(clean_model_name)
|
||||||
if config:
|
if config and not getattr(config, "is_custom", False):
|
||||||
return self.validate_model_name(clean_model_name) # Recursively validate clean name
|
return False
|
||||||
|
|
||||||
# For unknown models (not in registry), only accept if they look like local models
|
lowered = clean_model_name.lower()
|
||||||
# This maintains backward compatibility for custom models not yet in the registry
|
if any(indicator in lowered for indicator in ["local", "ollama", "vllm", "lmstudio"]):
|
||||||
|
|
||||||
# Accept models with explicit local indicators in the name
|
|
||||||
if any(indicator in clean_model_name.lower() for indicator in ["local", "ollama", "vllm", "lmstudio"]):
|
|
||||||
logging.debug(f"Model '{clean_model_name}' validated via local indicators")
|
logging.debug(f"Model '{clean_model_name}' validated via local indicators")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Accept simple model names without vendor prefix (likely local/custom models)
|
|
||||||
if "/" not in clean_model_name:
|
if "/" not in clean_model_name:
|
||||||
logging.debug(f"Model '{clean_model_name}' validated as potential local model (no vendor prefix)")
|
logging.debug(f"Model '{clean_model_name}' validated as potential local model (no vendor prefix)")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Reject everything else (likely cloud models not in registry)
|
|
||||||
logging.debug(f"Model '{model_name}' rejected by custom provider (appears to be cloud model)")
|
logging.debug(f"Model '{model_name}' rejected by custom provider (appears to be cloud model)")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Request execution
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def generate_content(
|
def generate_content(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -284,25 +228,41 @@ class CustomProvider(OpenAICompatibleProvider):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
|
# ------------------------------------------------------------------
|
||||||
"""Get model configurations from the registry.
|
# Registry helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
For CustomProvider, we convert registry configurations to ModelCapabilities objects.
|
def _resolve_model_name(self, model_name: str) -> str:
|
||||||
|
"""Resolve registry aliases and strip version tags for local models."""
|
||||||
|
|
||||||
Returns:
|
config = self._registry.resolve(model_name)
|
||||||
Dictionary mapping model names to their ModelCapabilities objects
|
if config:
|
||||||
"""
|
if config.model_name != model_name:
|
||||||
|
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
|
||||||
|
return config.model_name
|
||||||
|
|
||||||
configs = {}
|
if ":" in model_name:
|
||||||
|
base_model = model_name.split(":")[0]
|
||||||
|
logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'")
|
||||||
|
|
||||||
if self._registry:
|
base_config = self._registry.resolve(base_model)
|
||||||
# Get all models from registry
|
if base_config:
|
||||||
for model_name in self._registry.list_models():
|
logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
|
||||||
# Only include custom models that this provider validates
|
return base_config.model_name
|
||||||
if self.validate_model_name(model_name):
|
return base_model
|
||||||
config = self._registry.resolve(model_name)
|
|
||||||
if config and config.is_custom:
|
|
||||||
# Use ModelCapabilities directly from registry
|
|
||||||
configs[model_name] = config
|
|
||||||
|
|
||||||
return configs
|
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||||
|
"""Expose registry capabilities for models marked as custom."""
|
||||||
|
|
||||||
|
if not self._registry:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
capabilities: dict[str, ModelCapabilities] = {}
|
||||||
|
for model_name in self._registry.list_models():
|
||||||
|
config = self._registry.resolve(model_name)
|
||||||
|
if config and getattr(config, "is_custom", False):
|
||||||
|
capabilities[model_name] = config
|
||||||
|
return capabilities
|
||||||
|
|||||||
@@ -261,68 +261,10 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
|||||||
|
|
||||||
logger.info(f"Initialized DIAL provider with host: {dial_host} and api-version: {self.api_version}")
|
logger.info(f"Initialized DIAL provider with host: {dial_host} and api-version: {self.api_version}")
|
||||||
|
|
||||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
|
||||||
"""Get capabilities for a specific model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: Name of the model (can be shorthand)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ModelCapabilities object
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If model is not supported or not allowed
|
|
||||||
"""
|
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
|
||||||
|
|
||||||
if resolved_name not in self.MODEL_CAPABILITIES:
|
|
||||||
raise ValueError(f"Unsupported DIAL model: {model_name}")
|
|
||||||
|
|
||||||
# Check restrictions
|
|
||||||
from utils.model_restrictions import get_restriction_service
|
|
||||||
|
|
||||||
restriction_service = get_restriction_service()
|
|
||||||
if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
|
|
||||||
raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")
|
|
||||||
|
|
||||||
# Return the ModelCapabilities object directly from MODEL_CAPABILITIES
|
|
||||||
return self.MODEL_CAPABILITIES[resolved_name]
|
|
||||||
|
|
||||||
def get_provider_type(self) -> ProviderType:
|
def get_provider_type(self) -> ProviderType:
|
||||||
"""Get the provider type."""
|
"""Get the provider type."""
|
||||||
return ProviderType.DIAL
|
return ProviderType.DIAL
|
||||||
|
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
|
||||||
"""Validate if the model name is supported.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: Model name to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if model is supported and allowed, False otherwise
|
|
||||||
"""
|
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
|
||||||
|
|
||||||
if resolved_name not in self.MODEL_CAPABILITIES:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check against base class allowed_models if configured
|
|
||||||
if self.allowed_models is not None:
|
|
||||||
# Check both original and resolved names (case-insensitive)
|
|
||||||
if model_name.lower() not in self.allowed_models and resolved_name.lower() not in self.allowed_models:
|
|
||||||
logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' not in allowed_models list")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Also check restrictions via ModelRestrictionService
|
|
||||||
from utils.model_restrictions import get_restriction_service
|
|
||||||
|
|
||||||
restriction_service = get_restriction_service()
|
|
||||||
if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
|
|
||||||
logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' blocked by restrictions")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _get_deployment_client(self, deployment: str):
|
def _get_deployment_client(self, deployment: str):
|
||||||
"""Get or create a cached client for a specific deployment.
|
"""Get or create a cached client for a specific deployment.
|
||||||
|
|
||||||
@@ -504,7 +446,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
|||||||
f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
|
f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def close(self):
|
def close(self) -> None:
|
||||||
"""Clean up HTTP clients when provider is closed."""
|
"""Clean up HTTP clients when provider is closed."""
|
||||||
logger.info("Closing DIAL provider HTTP clients...")
|
logger.info("Closing DIAL provider HTTP clients...")
|
||||||
|
|
||||||
|
|||||||
@@ -131,6 +131,19 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
self._token_counters = {} # Cache for token counting
|
self._token_counters = {} # Cache for token counting
|
||||||
self._base_url = kwargs.get("base_url", None) # Optional custom endpoint
|
self._base_url = kwargs.get("base_url", None) # Optional custom endpoint
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Capability surface
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||||
|
"""Return statically defined Gemini capabilities."""
|
||||||
|
|
||||||
|
return dict(self.MODEL_CAPABILITIES)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Client access
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def client(self):
|
def client(self):
|
||||||
"""Lazy initialization of Gemini client."""
|
"""Lazy initialization of Gemini client."""
|
||||||
@@ -146,25 +159,9 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
self._client = genai.Client(api_key=self.api_key)
|
self._client = genai.Client(api_key=self.api_key)
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
# ------------------------------------------------------------------
|
||||||
"""Get capabilities for a specific Gemini model."""
|
# Request execution
|
||||||
# Resolve shorthand
|
# ------------------------------------------------------------------
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
|
||||||
|
|
||||||
if resolved_name not in self.MODEL_CAPABILITIES:
|
|
||||||
raise ValueError(f"Unsupported Gemini model: {model_name}")
|
|
||||||
|
|
||||||
# Check if model is allowed by restrictions
|
|
||||||
from utils.model_restrictions import get_restriction_service
|
|
||||||
|
|
||||||
restriction_service = get_restriction_service()
|
|
||||||
# IMPORTANT: Parameter order is (provider_type, model_name, original_name)
|
|
||||||
# resolved_name is the canonical model name, model_name is the user input
|
|
||||||
if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
|
|
||||||
raise ValueError(f"Gemini model '{resolved_name}' is not allowed by restriction policy.")
|
|
||||||
|
|
||||||
# Return the ModelCapabilities object directly from MODEL_CAPABILITIES
|
|
||||||
return self.MODEL_CAPABILITIES[resolved_name]
|
|
||||||
|
|
||||||
def generate_content(
|
def generate_content(
|
||||||
self,
|
self,
|
||||||
@@ -365,26 +362,6 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
"""Get the provider type."""
|
"""Get the provider type."""
|
||||||
return ProviderType.GOOGLE
|
return ProviderType.GOOGLE
|
||||||
|
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
|
||||||
"""Validate if the model name is supported and allowed."""
|
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
|
||||||
|
|
||||||
# First check if model is supported
|
|
||||||
if resolved_name not in self.MODEL_CAPABILITIES:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Then check if model is allowed by restrictions
|
|
||||||
from utils.model_restrictions import get_restriction_service
|
|
||||||
|
|
||||||
restriction_service = get_restriction_service()
|
|
||||||
# IMPORTANT: Parameter order is (provider_type, model_name, original_name)
|
|
||||||
# resolved_name is the canonical model name, model_name is the user input
|
|
||||||
if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
|
|
||||||
logger.debug(f"Gemini model '{model_name}' -> '{resolved_name}' blocked by restrictions")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int:
|
def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int:
|
||||||
"""Get actual thinking token budget for a model and thinking mode."""
|
"""Get actual thinking token budget for a model and thinking mode."""
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import ipaddress
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from abc import abstractmethod
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
@@ -61,6 +60,33 @@ class OpenAICompatibleProvider(ModelProvider):
|
|||||||
"This may be insecure. Consider setting an API key for authentication."
|
"This may be insecure. Consider setting an API key for authentication."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _ensure_model_allowed(
|
||||||
|
self,
|
||||||
|
capabilities: ModelCapabilities,
|
||||||
|
canonical_name: str,
|
||||||
|
requested_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Respect provider-specific allowlists before default restriction checks."""
|
||||||
|
|
||||||
|
super()._ensure_model_allowed(capabilities, canonical_name, requested_name)
|
||||||
|
|
||||||
|
if self.allowed_models is not None:
|
||||||
|
requested = requested_name.lower()
|
||||||
|
canonical = canonical_name.lower()
|
||||||
|
|
||||||
|
if requested not in self.allowed_models and canonical not in self.allowed_models:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||||
|
"""Return statically declared capabilities for OpenAI-compatible providers."""
|
||||||
|
|
||||||
|
model_map = getattr(self, "MODEL_CAPABILITIES", None)
|
||||||
|
if isinstance(model_map, dict):
|
||||||
|
return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)}
|
||||||
|
return {}
|
||||||
|
|
||||||
def _parse_allowed_models(self) -> Optional[set[str]]:
|
def _parse_allowed_models(self) -> Optional[set[str]]:
|
||||||
"""Parse allowed models from environment variable.
|
"""Parse allowed models from environment variable.
|
||||||
|
|
||||||
@@ -686,30 +712,6 @@ class OpenAICompatibleProvider(ModelProvider):
|
|||||||
|
|
||||||
return super().count_tokens(text, model_name)
|
return super().count_tokens(text, model_name)
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
|
||||||
"""Get capabilities for a specific model.
|
|
||||||
|
|
||||||
Must be implemented by subclasses.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_provider_type(self) -> ProviderType:
|
|
||||||
"""Get the provider type.
|
|
||||||
|
|
||||||
Must be implemented by subclasses.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
|
||||||
"""Validate if the model name is supported.
|
|
||||||
|
|
||||||
Must be implemented by subclasses.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _is_error_retryable(self, error: Exception) -> bool:
|
def _is_error_retryable(self, error: Exception) -> bool:
|
||||||
"""Determine if an error should be retried based on structured error codes.
|
"""Determine if an error should be retried based on structured error codes.
|
||||||
|
|
||||||
|
|||||||
@@ -174,106 +174,61 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
|||||||
kwargs.setdefault("base_url", "https://api.openai.com/v1")
|
kwargs.setdefault("base_url", "https://api.openai.com/v1")
|
||||||
super().__init__(api_key, **kwargs)
|
super().__init__(api_key, **kwargs)
|
||||||
|
|
||||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
# ------------------------------------------------------------------
|
||||||
"""Get capabilities for a specific OpenAI model."""
|
# Capability surface
|
||||||
# First check if it's a key in MODEL_CAPABILITIES
|
# ------------------------------------------------------------------
|
||||||
if model_name in self.MODEL_CAPABILITIES:
|
|
||||||
self._check_model_restrictions(model_name, model_name)
|
|
||||||
return self.MODEL_CAPABILITIES[model_name]
|
|
||||||
|
|
||||||
# Try resolving as alias
|
def _lookup_capabilities(
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
self,
|
||||||
|
canonical_name: str,
|
||||||
|
requested_name: Optional[str] = None,
|
||||||
|
) -> Optional[ModelCapabilities]:
|
||||||
|
"""Look up OpenAI capabilities from built-ins or the custom registry."""
|
||||||
|
|
||||||
# Check if resolved name is a key
|
builtin = super()._lookup_capabilities(canonical_name, requested_name)
|
||||||
if resolved_name in self.MODEL_CAPABILITIES:
|
if builtin is not None:
|
||||||
self._check_model_restrictions(resolved_name, model_name)
|
return builtin
|
||||||
return self.MODEL_CAPABILITIES[resolved_name]
|
|
||||||
|
|
||||||
# Finally check if resolved name matches any API model name
|
|
||||||
for key, capabilities in self.MODEL_CAPABILITIES.items():
|
|
||||||
if resolved_name == capabilities.model_name:
|
|
||||||
self._check_model_restrictions(key, model_name)
|
|
||||||
return capabilities
|
|
||||||
|
|
||||||
# Check custom models registry for user-configured OpenAI models
|
|
||||||
try:
|
try:
|
||||||
from .openrouter_registry import OpenRouterModelRegistry
|
from .openrouter_registry import OpenRouterModelRegistry
|
||||||
|
|
||||||
registry = OpenRouterModelRegistry()
|
registry = OpenRouterModelRegistry()
|
||||||
config = registry.get_model_config(resolved_name)
|
config = registry.get_model_config(canonical_name)
|
||||||
|
|
||||||
if config and config.provider == ProviderType.OPENAI:
|
if config and config.provider == ProviderType.OPENAI:
|
||||||
self._check_model_restrictions(config.model_name, model_name)
|
|
||||||
|
|
||||||
# Update provider type to ensure consistency
|
|
||||||
config.provider = ProviderType.OPENAI
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as exc: # pragma: no cover - registry failures are non-critical
|
||||||
# Log but don't fail - registry might not be available
|
logger.debug(f"Could not resolve custom OpenAI model '{canonical_name}': {exc}")
|
||||||
logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}")
|
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _finalise_capabilities(
|
||||||
|
self,
|
||||||
|
capabilities: ModelCapabilities,
|
||||||
|
canonical_name: str,
|
||||||
|
requested_name: str,
|
||||||
|
) -> ModelCapabilities:
|
||||||
|
"""Ensure registry-sourced models report the correct provider type."""
|
||||||
|
|
||||||
|
if capabilities.provider != ProviderType.OPENAI:
|
||||||
|
capabilities.provider = ProviderType.OPENAI
|
||||||
|
return capabilities
|
||||||
|
|
||||||
|
def _raise_unsupported_model(self, model_name: str) -> None:
|
||||||
raise ValueError(f"Unsupported OpenAI model: {model_name}")
|
raise ValueError(f"Unsupported OpenAI model: {model_name}")
|
||||||
|
|
||||||
def _check_model_restrictions(self, provider_model_name: str, user_model_name: str) -> None:
|
# ------------------------------------------------------------------
|
||||||
"""Check if a model is allowed by restriction policy.
|
# Provider identity
|
||||||
|
# ------------------------------------------------------------------
|
||||||
Args:
|
|
||||||
provider_model_name: The model name used by the provider
|
|
||||||
user_model_name: The model name requested by the user
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the model is not allowed by restriction policy
|
|
||||||
"""
|
|
||||||
from utils.model_restrictions import get_restriction_service
|
|
||||||
|
|
||||||
restriction_service = get_restriction_service()
|
|
||||||
if not restriction_service.is_allowed(ProviderType.OPENAI, provider_model_name, user_model_name):
|
|
||||||
raise ValueError(f"OpenAI model '{user_model_name}' is not allowed by restriction policy.")
|
|
||||||
|
|
||||||
def get_provider_type(self) -> ProviderType:
|
def get_provider_type(self) -> ProviderType:
|
||||||
"""Get the provider type."""
|
"""Get the provider type."""
|
||||||
return ProviderType.OPENAI
|
return ProviderType.OPENAI
|
||||||
|
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
# ------------------------------------------------------------------
|
||||||
"""Validate if the model name is supported and allowed."""
|
# Request execution
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
# First, determine which model name to check against restrictions.
|
|
||||||
model_to_check = None
|
|
||||||
is_custom_model = False
|
|
||||||
|
|
||||||
if resolved_name in self.MODEL_CAPABILITIES:
|
|
||||||
model_to_check = resolved_name
|
|
||||||
else:
|
|
||||||
# If not a built-in model, check the custom models registry.
|
|
||||||
try:
|
|
||||||
from .openrouter_registry import OpenRouterModelRegistry
|
|
||||||
|
|
||||||
registry = OpenRouterModelRegistry()
|
|
||||||
config = registry.get_model_config(resolved_name)
|
|
||||||
|
|
||||||
if config and config.provider == ProviderType.OPENAI:
|
|
||||||
model_to_check = config.model_name
|
|
||||||
is_custom_model = True
|
|
||||||
except Exception as e:
|
|
||||||
# Log but don't fail - registry might not be available.
|
|
||||||
logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}")
|
|
||||||
|
|
||||||
# If no model was found (neither built-in nor custom), it's invalid.
|
|
||||||
if not model_to_check:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Now, perform the restriction check once.
|
|
||||||
from utils.model_restrictions import get_restriction_service
|
|
||||||
|
|
||||||
restriction_service = get_restriction_service()
|
|
||||||
if not restriction_service.is_allowed(ProviderType.OPENAI, model_to_check, model_name):
|
|
||||||
model_type = "custom " if is_custom_model else ""
|
|
||||||
logger.debug(f"OpenAI {model_type}model '{model_name}' -> '{resolved_name}' blocked by restrictions")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def generate_content(
|
def generate_content(
|
||||||
self,
|
self,
|
||||||
@@ -298,6 +253,10 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Provider preferences
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
|
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
|
||||||
"""Get OpenAI's preferred model for a given category from allowed models.
|
"""Get OpenAI's preferred model for a given category from allowed models.
|
||||||
|
|
||||||
|
|||||||
@@ -61,108 +61,52 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
|||||||
aliases = self._registry.list_aliases()
|
aliases = self._registry.list_aliases()
|
||||||
logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases")
|
logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases")
|
||||||
|
|
||||||
def _resolve_model_name(self, model_name: str) -> str:
|
# ------------------------------------------------------------------
|
||||||
"""Resolve model aliases to OpenRouter model names.
|
# Capability surface
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
Args:
|
def _lookup_capabilities(
|
||||||
model_name: Input model name or alias
|
self,
|
||||||
|
canonical_name: str,
|
||||||
Returns:
|
requested_name: Optional[str] = None,
|
||||||
Resolved OpenRouter model name
|
) -> Optional[ModelCapabilities]:
|
||||||
"""
|
"""Fetch OpenRouter capabilities from the registry or build a generic fallback."""
|
||||||
# Try to resolve through registry
|
|
||||||
config = self._registry.resolve(model_name)
|
|
||||||
|
|
||||||
if config:
|
|
||||||
if config.model_name != model_name:
|
|
||||||
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
|
|
||||||
return config.model_name
|
|
||||||
else:
|
|
||||||
# If not found in registry, return as-is
|
|
||||||
# This allows using models not in our config file
|
|
||||||
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
|
||||||
return model_name
|
|
||||||
|
|
||||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
|
||||||
"""Get capabilities for a model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: Name of the model (or alias)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ModelCapabilities from registry or generic defaults
|
|
||||||
"""
|
|
||||||
# Try to get from registry first
|
|
||||||
capabilities = self._registry.get_capabilities(model_name)
|
|
||||||
|
|
||||||
|
capabilities = self._registry.get_capabilities(canonical_name)
|
||||||
if capabilities:
|
if capabilities:
|
||||||
return capabilities
|
return capabilities
|
||||||
else:
|
|
||||||
# Resolve any potential aliases and create generic capabilities
|
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
|
||||||
|
|
||||||
logging.debug(
|
logging.debug(
|
||||||
f"Using generic capabilities for '{resolved_name}' via OpenRouter. "
|
f"Using generic capabilities for '{canonical_name}' via OpenRouter. "
|
||||||
"Consider adding to custom_models.json for specific capabilities."
|
"Consider adding to custom_models.json for specific capabilities."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create generic capabilities with conservative defaults
|
generic = ModelCapabilities(
|
||||||
capabilities = ModelCapabilities(
|
provider=ProviderType.OPENROUTER,
|
||||||
provider=ProviderType.OPENROUTER,
|
model_name=canonical_name,
|
||||||
model_name=resolved_name,
|
friendly_name=self.FRIENDLY_NAME,
|
||||||
friendly_name=self.FRIENDLY_NAME,
|
context_window=32_768,
|
||||||
context_window=32_768, # Conservative default context window
|
max_output_tokens=32_768,
|
||||||
max_output_tokens=32_768,
|
supports_extended_thinking=False,
|
||||||
supports_extended_thinking=False,
|
supports_system_prompts=True,
|
||||||
supports_system_prompts=True,
|
supports_streaming=True,
|
||||||
supports_streaming=True,
|
supports_function_calling=False,
|
||||||
supports_function_calling=False,
|
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
|
||||||
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
|
)
|
||||||
)
|
generic._is_generic = True
|
||||||
|
return generic
|
||||||
|
|
||||||
# Mark as generic for validation purposes
|
# ------------------------------------------------------------------
|
||||||
capabilities._is_generic = True
|
# Provider identity
|
||||||
|
# ------------------------------------------------------------------
|
||||||
return capabilities
|
|
||||||
|
|
||||||
def get_provider_type(self) -> ProviderType:
|
def get_provider_type(self) -> ProviderType:
|
||||||
"""Get the provider type."""
|
"""Identify this provider for restrictions and logging."""
|
||||||
return ProviderType.OPENROUTER
|
return ProviderType.OPENROUTER
|
||||||
|
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
# ------------------------------------------------------------------
|
||||||
"""Validate if the model name is allowed.
|
# Request execution
|
||||||
|
# ------------------------------------------------------------------
|
||||||
As the catch-all provider, OpenRouter accepts any model name that wasn't
|
|
||||||
handled by higher-priority providers. OpenRouter will validate based on
|
|
||||||
the API key's permissions and local restrictions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: Model name to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if model is allowed, False if restricted
|
|
||||||
"""
|
|
||||||
# Check model restrictions if configured
|
|
||||||
from utils.model_restrictions import get_restriction_service
|
|
||||||
|
|
||||||
restriction_service = get_restriction_service()
|
|
||||||
if restriction_service:
|
|
||||||
# Check if model name itself is allowed
|
|
||||||
if restriction_service.is_allowed(self.get_provider_type(), model_name):
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Also check aliases - model_name might be an alias
|
|
||||||
model_config = self._registry.resolve(model_name)
|
|
||||||
if model_config and model_config.aliases:
|
|
||||||
for alias in model_config.aliases:
|
|
||||||
if restriction_service.is_allowed(self.get_provider_type(), alias):
|
|
||||||
return True
|
|
||||||
|
|
||||||
# If restrictions are configured and model/alias not in allowed list, reject
|
|
||||||
return False
|
|
||||||
|
|
||||||
# No restrictions configured - accept any model name as the fallback provider
|
|
||||||
return True
|
|
||||||
|
|
||||||
def generate_content(
|
def generate_content(
|
||||||
self,
|
self,
|
||||||
@@ -204,6 +148,10 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Registry helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def list_models(
|
def list_models(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -227,6 +175,12 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
|||||||
if not config:
|
if not config:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Custom models belong to CustomProvider; skip them here so the two
|
||||||
|
# providers don't race over the same registrations (important for tests
|
||||||
|
# that stub the registry with minimal objects lacking attrs).
|
||||||
|
if hasattr(config, "is_custom") and config.is_custom is True:
|
||||||
|
continue
|
||||||
|
|
||||||
if restriction_service:
|
if restriction_service:
|
||||||
allowed = restriction_service.is_allowed(self.get_provider_type(), model_name)
|
allowed = restriction_service.is_allowed(self.get_provider_type(), model_name)
|
||||||
|
|
||||||
@@ -255,24 +209,37 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
|||||||
unique=unique,
|
unique=unique,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
|
# ------------------------------------------------------------------
|
||||||
"""Get model configurations from the registry.
|
# Registry helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
For OpenRouter, we convert registry configurations to ModelCapabilities objects.
|
def _resolve_model_name(self, model_name: str) -> str:
|
||||||
|
"""Resolve aliases defined in the OpenRouter registry."""
|
||||||
|
|
||||||
Returns:
|
config = self._registry.resolve(model_name)
|
||||||
Dictionary mapping model names to their ModelCapabilities objects
|
if config:
|
||||||
"""
|
if config.model_name != model_name:
|
||||||
configs = {}
|
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
|
||||||
|
return config.model_name
|
||||||
|
|
||||||
if self._registry:
|
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
||||||
# Get all models from registry
|
return model_name
|
||||||
for model_name in self._registry.list_models():
|
|
||||||
# Only include models that this provider validates
|
|
||||||
if self.validate_model_name(model_name):
|
|
||||||
config = self._registry.resolve(model_name)
|
|
||||||
if config and not config.is_custom: # Only OpenRouter models, not custom ones
|
|
||||||
# Use ModelCapabilities directly from registry
|
|
||||||
configs[model_name] = config
|
|
||||||
|
|
||||||
return configs
|
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||||
|
"""Expose registry-backed OpenRouter capabilities."""
|
||||||
|
|
||||||
|
if not self._registry:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
capabilities: dict[str, ModelCapabilities] = {}
|
||||||
|
for model_name in self._registry.list_models():
|
||||||
|
config = self._registry.resolve(model_name)
|
||||||
|
if not config:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# See note in list_models: respect the CustomProvider boundary.
|
||||||
|
if hasattr(config, "is_custom") and config.is_custom is True:
|
||||||
|
continue
|
||||||
|
|
||||||
|
capabilities[model_name] = config
|
||||||
|
return capabilities
|
||||||
|
|||||||
@@ -64,6 +64,8 @@ class ModelProviderRegistry:
|
|||||||
"""
|
"""
|
||||||
instance = cls()
|
instance = cls()
|
||||||
instance._providers[provider_type] = provider_class
|
instance._providers[provider_type] = provider_class
|
||||||
|
# Invalidate any cached instance so subsequent lookups use the new registration
|
||||||
|
instance._initialized_providers.pop(provider_type, None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]:
|
def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]:
|
||||||
|
|||||||
@@ -85,46 +85,10 @@ class XAIModelProvider(OpenAICompatibleProvider):
|
|||||||
kwargs.setdefault("base_url", "https://api.x.ai/v1")
|
kwargs.setdefault("base_url", "https://api.x.ai/v1")
|
||||||
super().__init__(api_key, **kwargs)
|
super().__init__(api_key, **kwargs)
|
||||||
|
|
||||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
|
||||||
"""Get capabilities for a specific X.AI model."""
|
|
||||||
# Resolve shorthand
|
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
|
||||||
|
|
||||||
if resolved_name not in self.MODEL_CAPABILITIES:
|
|
||||||
raise ValueError(f"Unsupported X.AI model: {model_name}")
|
|
||||||
|
|
||||||
# Check if model is allowed by restrictions
|
|
||||||
from utils.model_restrictions import get_restriction_service
|
|
||||||
|
|
||||||
restriction_service = get_restriction_service()
|
|
||||||
if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
|
|
||||||
raise ValueError(f"X.AI model '{model_name}' is not allowed by restriction policy.")
|
|
||||||
|
|
||||||
# Return the ModelCapabilities object directly from MODEL_CAPABILITIES
|
|
||||||
return self.MODEL_CAPABILITIES[resolved_name]
|
|
||||||
|
|
||||||
def get_provider_type(self) -> ProviderType:
|
def get_provider_type(self) -> ProviderType:
|
||||||
"""Get the provider type."""
|
"""Get the provider type."""
|
||||||
return ProviderType.XAI
|
return ProviderType.XAI
|
||||||
|
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
|
||||||
"""Validate if the model name is supported and allowed."""
|
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
|
||||||
|
|
||||||
# First check if model is supported
|
|
||||||
if resolved_name not in self.MODEL_CAPABILITIES:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Then check if model is allowed by restrictions
|
|
||||||
from utils.model_restrictions import get_restriction_service
|
|
||||||
|
|
||||||
restriction_service = get_restriction_service()
|
|
||||||
if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
|
|
||||||
logger.debug(f"X.AI model '{model_name}' -> '{resolved_name}' blocked by restrictions")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def generate_content(
|
def generate_content(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
|||||||
@@ -54,11 +54,9 @@ class TestCustomProvider:
|
|||||||
|
|
||||||
provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")
|
provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")
|
||||||
|
|
||||||
# Test with a model that should be in the registry (OpenRouter model)
|
# OpenRouter-backed models should be handled by the OpenRouter provider
|
||||||
capabilities = provider.get_capabilities("o3") # o3 is an OpenRouter model
|
with pytest.raises(ValueError):
|
||||||
|
provider.get_capabilities("o3")
|
||||||
assert capabilities.provider == ProviderType.OPENROUTER # o3 is an OpenRouter model (is_custom=false)
|
|
||||||
assert capabilities.context_window > 0
|
|
||||||
|
|
||||||
# Test with a custom model (is_custom=true)
|
# Test with a custom model (is_custom=true)
|
||||||
capabilities = provider.get_capabilities("local-llama")
|
capabilities = provider.get_capabilities("local-llama")
|
||||||
@@ -168,7 +166,13 @@ class TestCustomProviderRegistration:
|
|||||||
return CustomProvider(api_key="", base_url="http://localhost:11434/v1")
|
return CustomProvider(api_key="", base_url="http://localhost:11434/v1")
|
||||||
|
|
||||||
with patch.dict(
|
with patch.dict(
|
||||||
os.environ, {"OPENROUTER_API_KEY": "test-openrouter-key", "CUSTOM_API_PLACEHOLDER": "configured"}
|
os.environ,
|
||||||
|
{
|
||||||
|
"OPENROUTER_API_KEY": "test-openrouter-key",
|
||||||
|
"CUSTOM_API_PLACEHOLDER": "configured",
|
||||||
|
"OPENROUTER_ALLOWED_MODELS": "llama,anthropic/claude-opus-4.1",
|
||||||
|
},
|
||||||
|
clear=True,
|
||||||
):
|
):
|
||||||
# Register both providers
|
# Register both providers
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
||||||
@@ -195,18 +199,22 @@ class TestCustomProviderRegistration:
|
|||||||
return CustomProvider(api_key="", base_url="http://localhost:11434/v1")
|
return CustomProvider(api_key="", base_url="http://localhost:11434/v1")
|
||||||
|
|
||||||
with patch.dict(
|
with patch.dict(
|
||||||
os.environ, {"OPENROUTER_API_KEY": "test-openrouter-key", "CUSTOM_API_PLACEHOLDER": "configured"}
|
os.environ,
|
||||||
|
{
|
||||||
|
"OPENROUTER_API_KEY": "test-openrouter-key",
|
||||||
|
"CUSTOM_API_PLACEHOLDER": "configured",
|
||||||
|
"OPENROUTER_ALLOWED_MODELS": "",
|
||||||
|
},
|
||||||
|
clear=True,
|
||||||
):
|
):
|
||||||
# Register OpenRouter first (higher priority)
|
import utils.model_restrictions
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, custom_provider_factory)
|
|
||||||
|
|
||||||
# Test model resolution - OpenRouter should win for shared aliases
|
utils.model_restrictions._restriction_service = None
|
||||||
provider_for_model = ModelProviderRegistry.get_provider_for_model("llama")
|
custom_provider = custom_provider_factory()
|
||||||
|
openrouter_provider = OpenRouterProvider(api_key="test-openrouter-key")
|
||||||
|
|
||||||
# OpenRouter should be selected first due to registration order
|
assert not custom_provider.validate_model_name("llama")
|
||||||
assert provider_for_model is not None
|
assert openrouter_provider.validate_model_name("llama")
|
||||||
# The exact provider type depends on which validates the model first
|
|
||||||
|
|
||||||
|
|
||||||
class TestConfigureProvidersFunction:
|
class TestConfigureProvidersFunction:
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ class TestDIALProvider:
|
|||||||
"""Test that get_capabilities raises for invalid models."""
|
"""Test that get_capabilities raises for invalid models."""
|
||||||
provider = DIALModelProvider("test-key")
|
provider = DIALModelProvider("test-key")
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Unsupported DIAL model"):
|
with pytest.raises(ValueError, match="Unsupported model 'invalid-model' for provider dial"):
|
||||||
provider.get_capabilities("invalid-model")
|
provider.get_capabilities("invalid-model")
|
||||||
|
|
||||||
@patch("utils.model_restrictions.get_restriction_service")
|
@patch("utils.model_restrictions.get_restriction_service")
|
||||||
|
|||||||
@@ -356,15 +356,13 @@ class TestCustomProviderOpenRouterRestrictions:
|
|||||||
|
|
||||||
provider = CustomProvider(base_url="http://test.com/v1")
|
provider = CustomProvider(base_url="http://test.com/v1")
|
||||||
|
|
||||||
# For OpenRouter models, get_capabilities should still work but mark them as OPENROUTER
|
# For OpenRouter models, CustomProvider should defer by raising
|
||||||
# This tests the capabilities lookup, not validation
|
with pytest.raises(ValueError):
|
||||||
capabilities = provider.get_capabilities("opus")
|
provider.get_capabilities("opus")
|
||||||
assert capabilities.provider == ProviderType.OPENROUTER
|
|
||||||
|
|
||||||
# Should raise for disallowed OpenRouter model
|
# Should raise for disallowed OpenRouter model (still defers)
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(ValueError):
|
||||||
provider.get_capabilities("haiku")
|
provider.get_capabilities("haiku")
|
||||||
assert "not allowed by restriction policy" in str(exc_info.value)
|
|
||||||
|
|
||||||
# Should still work for custom models (is_custom=true)
|
# Should still work for custom models (is_custom=true)
|
||||||
capabilities = provider.get_capabilities("local-llama")
|
capabilities = provider.get_capabilities("local-llama")
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ class TestXAIProvider:
|
|||||||
"""Test error handling for unsupported models."""
|
"""Test error handling for unsupported models."""
|
||||||
provider = XAIModelProvider("test-key")
|
provider = XAIModelProvider("test-key")
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Unsupported X.AI model"):
|
with pytest.raises(ValueError, match="Unsupported model 'invalid-model' for provider xai"):
|
||||||
provider.get_capabilities("invalid-model")
|
provider.get_capabilities("invalid-model")
|
||||||
|
|
||||||
def test_extended_thinking_flags(self):
|
def test_extended_thinking_flags(self):
|
||||||
|
|||||||
@@ -105,13 +105,11 @@ class ListModelsTool(BaseTool):
|
|||||||
output_lines.append("**Status**: Configured and available")
|
output_lines.append("**Status**: Configured and available")
|
||||||
output_lines.append("\n**Models**:")
|
output_lines.append("\n**Models**:")
|
||||||
|
|
||||||
# Get models from the provider's model configurations
|
aliases = []
|
||||||
for model_name, capabilities in provider.get_model_configurations().items():
|
for model_name, capabilities in provider.get_all_model_capabilities().items():
|
||||||
# Get description and context from the ModelCapabilities object
|
|
||||||
description = capabilities.description or "No description available"
|
description = capabilities.description or "No description available"
|
||||||
context_window = capabilities.context_window
|
context_window = capabilities.context_window
|
||||||
|
|
||||||
# Format context window
|
|
||||||
if context_window >= 1_000_000:
|
if context_window >= 1_000_000:
|
||||||
context_str = f"{context_window // 1_000_000}M context"
|
context_str = f"{context_window // 1_000_000}M context"
|
||||||
elif context_window >= 1_000:
|
elif context_window >= 1_000:
|
||||||
@@ -120,31 +118,15 @@ class ListModelsTool(BaseTool):
|
|||||||
context_str = f"{context_window} context" if context_window > 0 else "unknown context"
|
context_str = f"{context_window} context" if context_window > 0 else "unknown context"
|
||||||
|
|
||||||
output_lines.append(f"- `{model_name}` - {context_str}")
|
output_lines.append(f"- `{model_name}` - {context_str}")
|
||||||
|
output_lines.append(f" - {description}")
|
||||||
|
|
||||||
# Extract key capability from description
|
for alias in capabilities.aliases or []:
|
||||||
if "Ultra-fast" in description:
|
if alias != model_name:
|
||||||
output_lines.append(" - Fast processing, quick iterations")
|
aliases.append(f"- `{alias}` → `{model_name}`")
|
||||||
elif "Deep reasoning" in description:
|
|
||||||
output_lines.append(" - Extended reasoning with thinking mode")
|
|
||||||
elif "Strong reasoning" in description:
|
|
||||||
output_lines.append(" - Logical problems, systematic analysis")
|
|
||||||
elif "EXTREMELY EXPENSIVE" in description:
|
|
||||||
output_lines.append(" - ⚠️ Professional grade (very expensive)")
|
|
||||||
elif "Advanced reasoning" in description:
|
|
||||||
output_lines.append(" - Advanced reasoning and complex analysis")
|
|
||||||
|
|
||||||
# Show aliases for this provider
|
|
||||||
aliases = []
|
|
||||||
for model_name, capabilities in provider.get_model_configurations().items():
|
|
||||||
if capabilities.aliases:
|
|
||||||
for alias in capabilities.aliases:
|
|
||||||
# Skip aliases that are the same as the model name to avoid duplicates
|
|
||||||
if alias != model_name:
|
|
||||||
aliases.append(f"- `{alias}` → `{model_name}`")
|
|
||||||
|
|
||||||
if aliases:
|
if aliases:
|
||||||
output_lines.append("\n**Aliases**:")
|
output_lines.append("\n**Aliases**:")
|
||||||
output_lines.extend(sorted(aliases)) # Sort for consistent output
|
output_lines.extend(sorted(aliases))
|
||||||
else:
|
else:
|
||||||
output_lines.append(f"**Status**: Not configured (set {info['env_key']})")
|
output_lines.append(f"**Status**: Not configured (set {info['env_key']})")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user