refactor: renaming to reflect underlying type
docs: updated to reflect new modules
This commit is contained in:
@@ -30,7 +30,7 @@ DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "auto")
|
|||||||
# Auto mode detection - when DEFAULT_MODEL is "auto", Claude picks the model
|
# Auto mode detection - when DEFAULT_MODEL is "auto", Claude picks the model
|
||||||
IS_AUTO_MODE = DEFAULT_MODEL.lower() == "auto"
|
IS_AUTO_MODE = DEFAULT_MODEL.lower() == "auto"
|
||||||
|
|
||||||
# Each provider (gemini.py, openai_provider.py, xai.py) defines its own SUPPORTED_MODELS
|
# Each provider (gemini.py, openai_provider.py, xai.py) defines its own MODEL_CAPABILITIES
|
||||||
# with detailed descriptions. Tools use ModelProviderRegistry.get_available_model_names()
|
# with detailed descriptions. Tools use ModelProviderRegistry.get_available_model_names()
|
||||||
# to get models only from enabled providers (those with valid API keys).
|
# to get models only from enabled providers (those with valid API keys).
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ Each provider:
|
|||||||
|
|
||||||
### 1. Add Provider Type
|
### 1. Add Provider Type
|
||||||
|
|
||||||
Add your provider to `ProviderType` enum in `providers/base.py`:
|
Add your provider to the `ProviderType` enum in `providers/shared/provider_type.py`:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class ProviderType(Enum):
|
class ProviderType(Enum):
|
||||||
@@ -48,15 +48,23 @@ Create `providers/example.py`:
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, RangeTemperatureConstraint
|
|
||||||
|
from .base import ModelProvider
|
||||||
|
from .shared import (
|
||||||
|
ModelCapabilities,
|
||||||
|
ModelResponse,
|
||||||
|
ProviderType,
|
||||||
|
RangeTemperatureConstraint,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ExampleModelProvider(ModelProvider):
|
class ExampleModelProvider(ModelProvider):
|
||||||
"""Example model provider implementation."""
|
"""Example model provider implementation."""
|
||||||
|
|
||||||
# Define models using ModelCapabilities objects (like Gemini provider)
|
# Define models using ModelCapabilities objects (like Gemini provider)
|
||||||
SUPPORTED_MODELS = {
|
MODEL_CAPABILITIES = {
|
||||||
"example-large": ModelCapabilities(
|
"example-large": ModelCapabilities(
|
||||||
provider=ProviderType.EXAMPLE,
|
provider=ProviderType.EXAMPLE,
|
||||||
model_name="example-large",
|
model_name="example-large",
|
||||||
@@ -87,7 +95,7 @@ class ExampleModelProvider(ModelProvider):
|
|||||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
if resolved_name not in self.SUPPORTED_MODELS:
|
if resolved_name not in self.MODEL_CAPABILITIES:
|
||||||
raise ValueError(f"Unsupported model: {model_name}")
|
raise ValueError(f"Unsupported model: {model_name}")
|
||||||
|
|
||||||
# Apply restrictions if needed
|
# Apply restrictions if needed
|
||||||
@@ -96,7 +104,7 @@ class ExampleModelProvider(ModelProvider):
|
|||||||
if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
|
if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
|
||||||
raise ValueError(f"Model '{model_name}' is not allowed.")
|
raise ValueError(f"Model '{model_name}' is not allowed.")
|
||||||
|
|
||||||
return self.SUPPORTED_MODELS[resolved_name]
|
return self.MODEL_CAPABILITIES[resolved_name]
|
||||||
|
|
||||||
def generate_content(self, prompt: str, model_name: str, system_prompt: Optional[str] = None,
|
def generate_content(self, prompt: str, model_name: str, system_prompt: Optional[str] = None,
|
||||||
temperature: float = 0.7, max_output_tokens: Optional[int] = None, **kwargs) -> ModelResponse:
|
temperature: float = 0.7, max_output_tokens: Optional[int] = None, **kwargs) -> ModelResponse:
|
||||||
@@ -121,7 +129,7 @@ class ExampleModelProvider(ModelProvider):
|
|||||||
|
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
def validate_model_name(self, model_name: str) -> bool:
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
return resolved_name in self.SUPPORTED_MODELS
|
return resolved_name in self.MODEL_CAPABILITIES
|
||||||
|
|
||||||
def supports_thinking_mode(self, model_name: str) -> bool:
|
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||||
capabilities = self.get_capabilities(model_name)
|
capabilities = self.get_capabilities(model_name)
|
||||||
@@ -136,8 +144,15 @@ For OpenAI-compatible APIs:
|
|||||||
"""Example OpenAI-compatible provider."""
|
"""Example OpenAI-compatible provider."""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from .base import ModelCapabilities, ModelResponse, ProviderType, RangeTemperatureConstraint
|
|
||||||
from .openai_compatible import OpenAICompatibleProvider
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
|
from .shared import (
|
||||||
|
ModelCapabilities,
|
||||||
|
ModelResponse,
|
||||||
|
ProviderType,
|
||||||
|
RangeTemperatureConstraint,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ExampleProvider(OpenAICompatibleProvider):
|
class ExampleProvider(OpenAICompatibleProvider):
|
||||||
"""Example OpenAI-compatible provider."""
|
"""Example OpenAI-compatible provider."""
|
||||||
@@ -145,7 +160,7 @@ class ExampleProvider(OpenAICompatibleProvider):
|
|||||||
FRIENDLY_NAME = "Example"
|
FRIENDLY_NAME = "Example"
|
||||||
|
|
||||||
# Define models using ModelCapabilities (consistent with other providers)
|
# Define models using ModelCapabilities (consistent with other providers)
|
||||||
SUPPORTED_MODELS = {
|
MODEL_CAPABILITIES = {
|
||||||
"example-model-large": ModelCapabilities(
|
"example-model-large": ModelCapabilities(
|
||||||
provider=ProviderType.EXAMPLE,
|
provider=ProviderType.EXAMPLE,
|
||||||
model_name="example-model-large",
|
model_name="example-model-large",
|
||||||
@@ -163,16 +178,16 @@ class ExampleProvider(OpenAICompatibleProvider):
|
|||||||
|
|
||||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
if resolved_name not in self.SUPPORTED_MODELS:
|
if resolved_name not in self.MODEL_CAPABILITIES:
|
||||||
raise ValueError(f"Unsupported model: {model_name}")
|
raise ValueError(f"Unsupported model: {model_name}")
|
||||||
return self.SUPPORTED_MODELS[resolved_name]
|
return self.MODEL_CAPABILITIES[resolved_name]
|
||||||
|
|
||||||
def get_provider_type(self) -> ProviderType:
|
def get_provider_type(self) -> ProviderType:
|
||||||
return ProviderType.EXAMPLE
|
return ProviderType.EXAMPLE
|
||||||
|
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
def validate_model_name(self, model_name: str) -> bool:
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
return resolved_name in self.SUPPORTED_MODELS
|
return resolved_name in self.MODEL_CAPABILITIES
|
||||||
|
|
||||||
def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
|
def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
|
||||||
# IMPORTANT: Resolve aliases before API call
|
# IMPORTANT: Resolve aliases before API call
|
||||||
@@ -185,12 +200,8 @@ class ExampleProvider(OpenAICompatibleProvider):
|
|||||||
Add environment variable mapping in `providers/registry.py`:
|
Add environment variable mapping in `providers/registry.py`:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# In _get_api_key_for_provider method:
|
# In _get_api_key_for_provider (providers/registry.py), add:
|
||||||
key_mapping = {
|
ProviderType.EXAMPLE: "EXAMPLE_API_KEY",
|
||||||
ProviderType.GOOGLE: "GEMINI_API_KEY",
|
|
||||||
ProviderType.OPENAI: "OPENAI_API_KEY",
|
|
||||||
ProviderType.EXAMPLE: "EXAMPLE_API_KEY", # Add this
|
|
||||||
}
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Add to `server.py`:
|
Add to `server.py`:
|
||||||
@@ -209,16 +220,7 @@ if example_key:
|
|||||||
logger.info("Example API key found - Example models available")
|
logger.info("Example API key found - Example models available")
|
||||||
```
|
```
|
||||||
|
|
||||||
3. **Add to provider priority** (in `providers/registry.py`):
|
3. **Add to provider priority** (edit `ModelProviderRegistry.PROVIDER_PRIORITY_ORDER` in `providers/registry.py`): insert your provider in the list at the appropriate point in the cascade of native → custom → catch-all providers.
|
||||||
```python
|
|
||||||
PROVIDER_PRIORITY_ORDER = [
|
|
||||||
ProviderType.GOOGLE,
|
|
||||||
ProviderType.OPENAI,
|
|
||||||
ProviderType.EXAMPLE, # Add your provider here
|
|
||||||
ProviderType.CUSTOM, # Local models
|
|
||||||
ProviderType.OPENROUTER, # Catch-all (keep last)
|
|
||||||
]
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. Environment Configuration
|
### 4. Environment Configuration
|
||||||
|
|
||||||
@@ -265,7 +267,7 @@ Your `validate_model_name()` should **only** return `True` for models you explic
|
|||||||
```python
|
```python
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
def validate_model_name(self, model_name: str) -> bool:
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
return resolved_name in self.SUPPORTED_MODELS # Be specific!
|
return resolved_name in self.MODEL_CAPABILITIES # Be specific!
|
||||||
```
|
```
|
||||||
|
|
||||||
### Model Aliases
|
### Model Aliases
|
||||||
@@ -296,7 +298,7 @@ Without this, API calls with aliases like `"large"` will fail because your API d
|
|||||||
|
|
||||||
## Quick Checklist
|
## Quick Checklist
|
||||||
|
|
||||||
- [ ] Added to `ProviderType` enum in `providers/base.py`
|
- [ ] Added to `ProviderType` enum in `providers/shared/provider_type.py`
|
||||||
- [ ] Created provider class with all required methods
|
- [ ] Created provider class with all required methods
|
||||||
- [ ] Added API key mapping in `providers/registry.py`
|
- [ ] Added API key mapping in `providers/registry.py`
|
||||||
- [ ] Added to provider priority order in `registry.py`
|
- [ ] Added to provider priority order in `registry.py`
|
||||||
@@ -310,5 +312,3 @@ See existing implementations:
|
|||||||
- **Full provider**: `providers/gemini.py`
|
- **Full provider**: `providers/gemini.py`
|
||||||
- **OpenAI-compatible**: `providers/custom.py`
|
- **OpenAI-compatible**: `providers/custom.py`
|
||||||
- **Base classes**: `providers/base.py`
|
- **Base classes**: `providers/base.py`
|
||||||
|
|
||||||
The modern approach uses `ModelCapabilities` objects directly in `SUPPORTED_MODELS`, making the implementation much cleaner and more consistent.
|
|
||||||
@@ -28,7 +28,7 @@ class ModelProvider(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# All concrete providers must define their supported models
|
# All concrete providers must define their supported models
|
||||||
SUPPORTED_MODELS: dict[str, Any] = {}
|
MODEL_CAPABILITIES: dict[str, Any] = {}
|
||||||
|
|
||||||
# Default maximum image size in MB
|
# Default maximum image size in MB
|
||||||
DEFAULT_MAX_IMAGE_SIZE_MB = 20.0
|
DEFAULT_MAX_IMAGE_SIZE_MB = 20.0
|
||||||
@@ -147,9 +147,9 @@ class ModelProvider(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
Dictionary mapping model names to their ModelCapabilities objects
|
Dictionary mapping model names to their ModelCapabilities objects
|
||||||
"""
|
"""
|
||||||
# Return SUPPORTED_MODELS if it exists (must contain ModelCapabilities objects)
|
model_map = getattr(self, "MODEL_CAPABILITIES", None)
|
||||||
if hasattr(self, "SUPPORTED_MODELS"):
|
if isinstance(model_map, dict) and model_map:
|
||||||
return {k: v for k, v in self.SUPPORTED_MODELS.items() if isinstance(v, ModelCapabilities)}
|
return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)}
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _resolve_model_name(self, model_name: str) -> str:
|
def _resolve_model_name(self, model_name: str) -> str:
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
|||||||
RETRY_DELAYS = [1, 3, 5, 8] # seconds
|
RETRY_DELAYS = [1, 3, 5, 8] # seconds
|
||||||
|
|
||||||
# Model configurations using ModelCapabilities objects
|
# Model configurations using ModelCapabilities objects
|
||||||
SUPPORTED_MODELS = {
|
MODEL_CAPABILITIES = {
|
||||||
"o3-2025-04-16": ModelCapabilities(
|
"o3-2025-04-16": ModelCapabilities(
|
||||||
provider=ProviderType.DIAL,
|
provider=ProviderType.DIAL,
|
||||||
model_name="o3-2025-04-16",
|
model_name="o3-2025-04-16",
|
||||||
@@ -280,7 +280,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
|||||||
"""
|
"""
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
if resolved_name not in self.SUPPORTED_MODELS:
|
if resolved_name not in self.MODEL_CAPABILITIES:
|
||||||
raise ValueError(f"Unsupported DIAL model: {model_name}")
|
raise ValueError(f"Unsupported DIAL model: {model_name}")
|
||||||
|
|
||||||
# Check restrictions
|
# Check restrictions
|
||||||
@@ -290,8 +290,8 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
|||||||
if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
|
if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
|
||||||
raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")
|
raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")
|
||||||
|
|
||||||
# Return the ModelCapabilities object directly from SUPPORTED_MODELS
|
# Return the ModelCapabilities object directly from MODEL_CAPABILITIES
|
||||||
return self.SUPPORTED_MODELS[resolved_name]
|
return self.MODEL_CAPABILITIES[resolved_name]
|
||||||
|
|
||||||
def get_provider_type(self) -> ProviderType:
|
def get_provider_type(self) -> ProviderType:
|
||||||
"""Get the provider type."""
|
"""Get the provider type."""
|
||||||
@@ -308,7 +308,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
|||||||
"""
|
"""
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
if resolved_name not in self.SUPPORTED_MODELS:
|
if resolved_name not in self.MODEL_CAPABILITIES:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check against base class allowed_models if configured
|
# Check against base class allowed_models if configured
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Model configurations using ModelCapabilities objects
|
# Model configurations using ModelCapabilities objects
|
||||||
SUPPORTED_MODELS = {
|
MODEL_CAPABILITIES = {
|
||||||
"gemini-2.5-pro": ModelCapabilities(
|
"gemini-2.5-pro": ModelCapabilities(
|
||||||
provider=ProviderType.GOOGLE,
|
provider=ProviderType.GOOGLE,
|
||||||
model_name="gemini-2.5-pro",
|
model_name="gemini-2.5-pro",
|
||||||
@@ -154,7 +154,7 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
# Resolve shorthand
|
# Resolve shorthand
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
if resolved_name not in self.SUPPORTED_MODELS:
|
if resolved_name not in self.MODEL_CAPABILITIES:
|
||||||
raise ValueError(f"Unsupported Gemini model: {model_name}")
|
raise ValueError(f"Unsupported Gemini model: {model_name}")
|
||||||
|
|
||||||
# Check if model is allowed by restrictions
|
# Check if model is allowed by restrictions
|
||||||
@@ -166,8 +166,8 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
|
if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
|
||||||
raise ValueError(f"Gemini model '{resolved_name}' is not allowed by restriction policy.")
|
raise ValueError(f"Gemini model '{resolved_name}' is not allowed by restriction policy.")
|
||||||
|
|
||||||
# Return the ModelCapabilities object directly from SUPPORTED_MODELS
|
# Return the ModelCapabilities object directly from MODEL_CAPABILITIES
|
||||||
return self.SUPPORTED_MODELS[resolved_name]
|
return self.MODEL_CAPABILITIES[resolved_name]
|
||||||
|
|
||||||
def generate_content(
|
def generate_content(
|
||||||
self,
|
self,
|
||||||
@@ -227,7 +227,7 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
# Add thinking configuration for models that support it
|
# Add thinking configuration for models that support it
|
||||||
if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
|
if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
|
||||||
# Get model's max thinking tokens and calculate actual budget
|
# Get model's max thinking tokens and calculate actual budget
|
||||||
model_config = self.SUPPORTED_MODELS.get(resolved_name)
|
model_config = self.MODEL_CAPABILITIES.get(resolved_name)
|
||||||
if model_config and model_config.max_thinking_tokens > 0:
|
if model_config and model_config.max_thinking_tokens > 0:
|
||||||
max_thinking_tokens = model_config.max_thinking_tokens
|
max_thinking_tokens = model_config.max_thinking_tokens
|
||||||
actual_thinking_budget = int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
|
actual_thinking_budget = int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
|
||||||
@@ -382,7 +382,7 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
# First check if model is supported
|
# First check if model is supported
|
||||||
if resolved_name not in self.SUPPORTED_MODELS:
|
if resolved_name not in self.MODEL_CAPABILITIES:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Then check if model is allowed by restrictions
|
# Then check if model is allowed by restrictions
|
||||||
@@ -405,7 +405,7 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int:
|
def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int:
|
||||||
"""Get actual thinking token budget for a model and thinking mode."""
|
"""Get actual thinking token budget for a model and thinking mode."""
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
model_config = self.SUPPORTED_MODELS.get(resolved_name)
|
model_config = self.MODEL_CAPABILITIES.get(resolved_name)
|
||||||
|
|
||||||
if not model_config or not model_config.supports_extended_thinking:
|
if not model_config or not model_config.supports_extended_thinking:
|
||||||
return 0
|
return 0
|
||||||
@@ -584,7 +584,7 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
pro_thinking = [
|
pro_thinking = [
|
||||||
m
|
m
|
||||||
for m in allowed_models
|
for m in allowed_models
|
||||||
if "pro" in m and m in self.SUPPORTED_MODELS and self.SUPPORTED_MODELS[m].supports_extended_thinking
|
if "pro" in m and m in self.MODEL_CAPABILITIES and self.MODEL_CAPABILITIES[m].supports_extended_thinking
|
||||||
]
|
]
|
||||||
if pro_thinking:
|
if pro_thinking:
|
||||||
return find_best(pro_thinking)
|
return find_best(pro_thinking)
|
||||||
@@ -593,7 +593,7 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
any_thinking = [
|
any_thinking = [
|
||||||
m
|
m
|
||||||
for m in allowed_models
|
for m in allowed_models
|
||||||
if m in self.SUPPORTED_MODELS and self.SUPPORTED_MODELS[m].supports_extended_thinking
|
if m in self.MODEL_CAPABILITIES and self.MODEL_CAPABILITIES[m].supports_extended_thinking
|
||||||
]
|
]
|
||||||
if any_thinking:
|
if any_thinking:
|
||||||
return find_best(any_thinking)
|
return find_best(any_thinking)
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Model configurations using ModelCapabilities objects
|
# Model configurations using ModelCapabilities objects
|
||||||
SUPPORTED_MODELS = {
|
MODEL_CAPABILITIES = {
|
||||||
"gpt-5": ModelCapabilities(
|
"gpt-5": ModelCapabilities(
|
||||||
provider=ProviderType.OPENAI,
|
provider=ProviderType.OPENAI,
|
||||||
model_name="gpt-5",
|
model_name="gpt-5",
|
||||||
@@ -181,21 +181,21 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
|||||||
|
|
||||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||||
"""Get capabilities for a specific OpenAI model."""
|
"""Get capabilities for a specific OpenAI model."""
|
||||||
# First check if it's a key in SUPPORTED_MODELS
|
# First check if it's a key in MODEL_CAPABILITIES
|
||||||
if model_name in self.SUPPORTED_MODELS:
|
if model_name in self.MODEL_CAPABILITIES:
|
||||||
self._check_model_restrictions(model_name, model_name)
|
self._check_model_restrictions(model_name, model_name)
|
||||||
return self.SUPPORTED_MODELS[model_name]
|
return self.MODEL_CAPABILITIES[model_name]
|
||||||
|
|
||||||
# Try resolving as alias
|
# Try resolving as alias
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
# Check if resolved name is a key
|
# Check if resolved name is a key
|
||||||
if resolved_name in self.SUPPORTED_MODELS:
|
if resolved_name in self.MODEL_CAPABILITIES:
|
||||||
self._check_model_restrictions(resolved_name, model_name)
|
self._check_model_restrictions(resolved_name, model_name)
|
||||||
return self.SUPPORTED_MODELS[resolved_name]
|
return self.MODEL_CAPABILITIES[resolved_name]
|
||||||
|
|
||||||
# Finally check if resolved name matches any API model name
|
# Finally check if resolved name matches any API model name
|
||||||
for key, capabilities in self.SUPPORTED_MODELS.items():
|
for key, capabilities in self.MODEL_CAPABILITIES.items():
|
||||||
if resolved_name == capabilities.model_name:
|
if resolved_name == capabilities.model_name:
|
||||||
self._check_model_restrictions(key, model_name)
|
self._check_model_restrictions(key, model_name)
|
||||||
return capabilities
|
return capabilities
|
||||||
@@ -248,7 +248,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
|||||||
model_to_check = None
|
model_to_check = None
|
||||||
is_custom_model = False
|
is_custom_model = False
|
||||||
|
|
||||||
if resolved_name in self.SUPPORTED_MODELS:
|
if resolved_name in self.MODEL_CAPABILITIES:
|
||||||
model_to_check = resolved_name
|
model_to_check = resolved_name
|
||||||
else:
|
else:
|
||||||
# If not a built-in model, check the custom models registry.
|
# If not a built-in model, check the custom models registry.
|
||||||
|
|||||||
@@ -282,11 +282,9 @@ class ModelProviderRegistry:
|
|||||||
# Use list_models to get all supported models (handles both regular and custom providers)
|
# Use list_models to get all supported models (handles both regular and custom providers)
|
||||||
supported_models = provider.list_models(respect_restrictions=False)
|
supported_models = provider.list_models(respect_restrictions=False)
|
||||||
except (NotImplementedError, AttributeError):
|
except (NotImplementedError, AttributeError):
|
||||||
# Fallback to SUPPORTED_MODELS if list_models not implemented
|
# Fallback to provider-declared capability maps if list_models not implemented
|
||||||
try:
|
model_map = getattr(provider, "MODEL_CAPABILITIES", None)
|
||||||
supported_models = list(provider.SUPPORTED_MODELS.keys())
|
supported_models = list(model_map.keys()) if isinstance(model_map, dict) else []
|
||||||
except AttributeError:
|
|
||||||
supported_models = []
|
|
||||||
|
|
||||||
# Filter by restrictions
|
# Filter by restrictions
|
||||||
for model_name in supported_models:
|
for model_name in supported_models:
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
|
|||||||
FRIENDLY_NAME = "X.AI"
|
FRIENDLY_NAME = "X.AI"
|
||||||
|
|
||||||
# Model configurations using ModelCapabilities objects
|
# Model configurations using ModelCapabilities objects
|
||||||
SUPPORTED_MODELS = {
|
MODEL_CAPABILITIES = {
|
||||||
"grok-4": ModelCapabilities(
|
"grok-4": ModelCapabilities(
|
||||||
provider=ProviderType.XAI,
|
provider=ProviderType.XAI,
|
||||||
model_name="grok-4",
|
model_name="grok-4",
|
||||||
@@ -95,7 +95,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
|
|||||||
# Resolve shorthand
|
# Resolve shorthand
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
if resolved_name not in self.SUPPORTED_MODELS:
|
if resolved_name not in self.MODEL_CAPABILITIES:
|
||||||
raise ValueError(f"Unsupported X.AI model: {model_name}")
|
raise ValueError(f"Unsupported X.AI model: {model_name}")
|
||||||
|
|
||||||
# Check if model is allowed by restrictions
|
# Check if model is allowed by restrictions
|
||||||
@@ -105,8 +105,8 @@ class XAIModelProvider(OpenAICompatibleProvider):
|
|||||||
if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
|
if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
|
||||||
raise ValueError(f"X.AI model '{model_name}' is not allowed by restriction policy.")
|
raise ValueError(f"X.AI model '{model_name}' is not allowed by restriction policy.")
|
||||||
|
|
||||||
# Return the ModelCapabilities object directly from SUPPORTED_MODELS
|
# Return the ModelCapabilities object directly from MODEL_CAPABILITIES
|
||||||
return self.SUPPORTED_MODELS[resolved_name]
|
return self.MODEL_CAPABILITIES[resolved_name]
|
||||||
|
|
||||||
def get_provider_type(self) -> ProviderType:
|
def get_provider_type(self) -> ProviderType:
|
||||||
"""Get the provider type."""
|
"""Get the provider type."""
|
||||||
@@ -117,7 +117,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
|
|||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
# First check if model is supported
|
# First check if model is supported
|
||||||
if resolved_name not in self.SUPPORTED_MODELS:
|
if resolved_name not in self.MODEL_CAPABILITIES:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Then check if model is allowed by restrictions
|
# Then check if model is allowed by restrictions
|
||||||
@@ -156,7 +156,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
|
|||||||
def supports_thinking_mode(self, model_name: str) -> bool:
|
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||||
"""Check if the model supports extended thinking mode."""
|
"""Check if the model supports extended thinking mode."""
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
capabilities = self.SUPPORTED_MODELS.get(resolved_name)
|
capabilities = self.MODEL_CAPABILITIES.get(resolved_name)
|
||||||
if capabilities:
|
if capabilities:
|
||||||
return capabilities.supports_extended_thinking
|
return capabilities.supports_extended_thinking
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ class TestAliasTargetRestrictions:
|
|||||||
openai_all_known = openai_provider.list_all_known_models()
|
openai_all_known = openai_provider.list_all_known_models()
|
||||||
|
|
||||||
# Verify that for each alias, its target is also included
|
# Verify that for each alias, its target is also included
|
||||||
for model_name, config in openai_provider.SUPPORTED_MODELS.items():
|
for model_name, config in openai_provider.MODEL_CAPABILITIES.items():
|
||||||
assert model_name.lower() in openai_all_known
|
assert model_name.lower() in openai_all_known
|
||||||
if isinstance(config, str): # This is an alias
|
if isinstance(config, str): # This is an alias
|
||||||
# The target should also be in the known models
|
# The target should also be in the known models
|
||||||
@@ -178,7 +178,7 @@ class TestAliasTargetRestrictions:
|
|||||||
gemini_all_known = gemini_provider.list_all_known_models()
|
gemini_all_known = gemini_provider.list_all_known_models()
|
||||||
|
|
||||||
# Verify that for each alias, its target is also included
|
# Verify that for each alias, its target is also included
|
||||||
for model_name, config in gemini_provider.SUPPORTED_MODELS.items():
|
for model_name, config in gemini_provider.MODEL_CAPABILITIES.items():
|
||||||
assert model_name.lower() in gemini_all_known
|
assert model_name.lower() in gemini_all_known
|
||||||
if isinstance(config, str): # This is an alias
|
if isinstance(config, str): # This is an alias
|
||||||
# The target should also be in the known models
|
# The target should also be in the known models
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ class TestAutoMode:
|
|||||||
for provider_type in enabled_provider_types:
|
for provider_type in enabled_provider_types:
|
||||||
provider = ModelProviderRegistry.get_provider(provider_type)
|
provider = ModelProviderRegistry.get_provider(provider_type)
|
||||||
if provider:
|
if provider:
|
||||||
for model_name, config in provider.SUPPORTED_MODELS.items():
|
for model_name, config in provider.MODEL_CAPABILITIES.items():
|
||||||
# Skip alias entries (string values)
|
# Skip alias entries (string values)
|
||||||
if isinstance(config, str):
|
if isinstance(config, str):
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -176,7 +176,7 @@ class TestBuggyBehaviorPrevention:
|
|||||||
|
|
||||||
# Create a mock provider that simulates the old behavior
|
# Create a mock provider that simulates the old behavior
|
||||||
old_style_provider = MagicMock()
|
old_style_provider = MagicMock()
|
||||||
old_style_provider.SUPPORTED_MODELS = {
|
old_style_provider.MODEL_CAPABILITIES = {
|
||||||
"mini": "o4-mini",
|
"mini": "o4-mini",
|
||||||
"o3mini": "o3-mini",
|
"o3mini": "o3-mini",
|
||||||
"o4-mini": {"context_window": 200000},
|
"o4-mini": {"context_window": 200000},
|
||||||
|
|||||||
@@ -137,7 +137,7 @@ class TestModelRestrictionService:
|
|||||||
|
|
||||||
# Create mock provider with known models
|
# Create mock provider with known models
|
||||||
mock_provider = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_provider.SUPPORTED_MODELS = {
|
mock_provider.MODEL_CAPABILITIES = {
|
||||||
"o3": {"context_window": 200000},
|
"o3": {"context_window": 200000},
|
||||||
"o3-mini": {"context_window": 200000},
|
"o3-mini": {"context_window": 200000},
|
||||||
"o4-mini": {"context_window": 200000},
|
"o4-mini": {"context_window": 200000},
|
||||||
@@ -441,7 +441,7 @@ class TestRegistryIntegration:
|
|||||||
|
|
||||||
# Mock providers
|
# Mock providers
|
||||||
mock_openai = MagicMock()
|
mock_openai = MagicMock()
|
||||||
mock_openai.SUPPORTED_MODELS = {
|
mock_openai.MODEL_CAPABILITIES = {
|
||||||
"o3": {"context_window": 200000},
|
"o3": {"context_window": 200000},
|
||||||
"o3-mini": {"context_window": 200000},
|
"o3-mini": {"context_window": 200000},
|
||||||
}
|
}
|
||||||
@@ -452,7 +452,7 @@ class TestRegistryIntegration:
|
|||||||
|
|
||||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||||
models = []
|
models = []
|
||||||
for model_name, config in mock_openai.SUPPORTED_MODELS.items():
|
for model_name, config in mock_openai.MODEL_CAPABILITIES.items():
|
||||||
if isinstance(config, str):
|
if isinstance(config, str):
|
||||||
target_model = config
|
target_model = config
|
||||||
if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, target_model):
|
if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, target_model):
|
||||||
@@ -468,7 +468,7 @@ class TestRegistryIntegration:
|
|||||||
mock_openai.list_all_known_models.return_value = ["o3", "o3-mini"]
|
mock_openai.list_all_known_models.return_value = ["o3", "o3-mini"]
|
||||||
|
|
||||||
mock_gemini = MagicMock()
|
mock_gemini = MagicMock()
|
||||||
mock_gemini.SUPPORTED_MODELS = {
|
mock_gemini.MODEL_CAPABILITIES = {
|
||||||
"gemini-2.5-pro": {"context_window": 1048576},
|
"gemini-2.5-pro": {"context_window": 1048576},
|
||||||
"gemini-2.5-flash": {"context_window": 1048576},
|
"gemini-2.5-flash": {"context_window": 1048576},
|
||||||
}
|
}
|
||||||
@@ -479,7 +479,7 @@ class TestRegistryIntegration:
|
|||||||
|
|
||||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||||
models = []
|
models = []
|
||||||
for model_name, config in mock_gemini.SUPPORTED_MODELS.items():
|
for model_name, config in mock_gemini.MODEL_CAPABILITIES.items():
|
||||||
if isinstance(config, str):
|
if isinstance(config, str):
|
||||||
target_model = config
|
target_model = config
|
||||||
if restriction_service and not restriction_service.is_allowed(ProviderType.GOOGLE, target_model):
|
if restriction_service and not restriction_service.is_allowed(ProviderType.GOOGLE, target_model):
|
||||||
@@ -608,7 +608,7 @@ class TestAutoModeWithRestrictions:
|
|||||||
|
|
||||||
# Mock providers
|
# Mock providers
|
||||||
mock_openai = MagicMock()
|
mock_openai = MagicMock()
|
||||||
mock_openai.SUPPORTED_MODELS = {
|
mock_openai.MODEL_CAPABILITIES = {
|
||||||
"o3": {"context_window": 200000},
|
"o3": {"context_window": 200000},
|
||||||
"o3-mini": {"context_window": 200000},
|
"o3-mini": {"context_window": 200000},
|
||||||
"o4-mini": {"context_window": 200000},
|
"o4-mini": {"context_window": 200000},
|
||||||
@@ -620,7 +620,7 @@ class TestAutoModeWithRestrictions:
|
|||||||
|
|
||||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||||
models = []
|
models = []
|
||||||
for model_name, config in mock_openai.SUPPORTED_MODELS.items():
|
for model_name, config in mock_openai.MODEL_CAPABILITIES.items():
|
||||||
if isinstance(config, str):
|
if isinstance(config, str):
|
||||||
target_model = config
|
target_model = config
|
||||||
if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, target_model):
|
if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, target_model):
|
||||||
|
|||||||
@@ -205,7 +205,7 @@ class TestO3TemperatureParameterFixSimple:
|
|||||||
), f"Model {model} capabilities should have supports_temperature field"
|
), f"Model {model} capabilities should have supports_temperature field"
|
||||||
assert capabilities.supports_temperature is True, f"Model {model} should have supports_temperature=True"
|
assert capabilities.supports_temperature is True, f"Model {model} should have supports_temperature=True"
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# Skip if model not in SUPPORTED_MODELS (that's okay for this test)
|
# Skip if model not in MODEL_CAPABILITIES (that's okay for this test)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@patch("utils.model_restrictions.get_restriction_service")
|
@patch("utils.model_restrictions.get_restriction_service")
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class TestOldBehaviorSimulation:
|
|||||||
"""
|
"""
|
||||||
# Create a mock provider that simulates the OLD BROKEN BEHAVIOR
|
# Create a mock provider that simulates the OLD BROKEN BEHAVIOR
|
||||||
old_broken_provider = MagicMock()
|
old_broken_provider = MagicMock()
|
||||||
old_broken_provider.SUPPORTED_MODELS = {
|
old_broken_provider.MODEL_CAPABILITIES = {
|
||||||
"mini": "o4-mini", # alias -> target
|
"mini": "o4-mini", # alias -> target
|
||||||
"o3mini": "o3-mini", # alias -> target
|
"o3mini": "o3-mini", # alias -> target
|
||||||
"o4-mini": {"context_window": 200000},
|
"o4-mini": {"context_window": 200000},
|
||||||
@@ -73,7 +73,7 @@ class TestOldBehaviorSimulation:
|
|||||||
"""
|
"""
|
||||||
# Create mock provider with NEW FIXED BEHAVIOR
|
# Create mock provider with NEW FIXED BEHAVIOR
|
||||||
new_fixed_provider = MagicMock()
|
new_fixed_provider = MagicMock()
|
||||||
new_fixed_provider.SUPPORTED_MODELS = {
|
new_fixed_provider.MODEL_CAPABILITIES = {
|
||||||
"mini": "o4-mini",
|
"mini": "o4-mini",
|
||||||
"o3mini": "o3-mini",
|
"o3mini": "o3-mini",
|
||||||
"o4-mini": {"context_window": 200000},
|
"o4-mini": {"context_window": 200000},
|
||||||
@@ -203,14 +203,14 @@ class TestOldBehaviorSimulation:
|
|||||||
for provider in providers:
|
for provider in providers:
|
||||||
all_known = provider.list_all_known_models()
|
all_known = provider.list_all_known_models()
|
||||||
|
|
||||||
# Check that for every alias in SUPPORTED_MODELS, its target is also included
|
# Check that every model and its aliases appear in the comprehensive list
|
||||||
for model_name, config in provider.SUPPORTED_MODELS.items():
|
for model_name, config in provider.MODEL_CAPABILITIES.items():
|
||||||
# Model name itself should be in the list
|
|
||||||
assert model_name.lower() in all_known, f"{provider.__class__.__name__}: Missing model {model_name}"
|
assert model_name.lower() in all_known, f"{provider.__class__.__name__}: Missing model {model_name}"
|
||||||
|
|
||||||
# If it's an alias (config is a string), target should also be in list
|
for alias in getattr(config, "aliases", []):
|
||||||
if isinstance(config, str):
|
|
||||||
target_model = config
|
|
||||||
assert (
|
assert (
|
||||||
target_model.lower() in all_known
|
alias.lower() in all_known
|
||||||
), f"{provider.__class__.__name__}: Missing target {target_model} for alias {model_name}"
|
), f"{provider.__class__.__name__}: Missing alias {alias} for model {model_name}"
|
||||||
|
assert (
|
||||||
|
provider._resolve_model_name(alias) == model_name
|
||||||
|
), f"{provider.__class__.__name__}: Alias {alias} should resolve to {model_name}"
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ class TestOpenAICompatibleTokenUsage(unittest.TestCase):
|
|||||||
# Create a concrete implementation for testing
|
# Create a concrete implementation for testing
|
||||||
class TestProvider(OpenAICompatibleProvider):
|
class TestProvider(OpenAICompatibleProvider):
|
||||||
FRIENDLY_NAME = "Test"
|
FRIENDLY_NAME = "Test"
|
||||||
SUPPORTED_MODELS = {"test-model": {"context_window": 4096}}
|
MODEL_CAPABILITIES = {"test-model": {"context_window": 4096}}
|
||||||
|
|
||||||
def get_capabilities(self, model_name):
|
def get_capabilities(self, model_name):
|
||||||
return Mock()
|
return Mock()
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Test the SUPPORTED_MODELS aliases structure across all providers."""
|
"""Test the MODEL_CAPABILITIES aliases structure across all providers."""
|
||||||
|
|
||||||
from providers.dial import DIALModelProvider
|
from providers.dial import DIALModelProvider
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
@@ -7,24 +7,24 @@ from providers.xai import XAIModelProvider
|
|||||||
|
|
||||||
|
|
||||||
class TestSupportedModelsAliases:
|
class TestSupportedModelsAliases:
|
||||||
"""Test that all providers have correctly structured SUPPORTED_MODELS with aliases."""
|
"""Test that all providers have correctly structured MODEL_CAPABILITIES with aliases."""
|
||||||
|
|
||||||
def test_gemini_provider_aliases(self):
|
def test_gemini_provider_aliases(self):
|
||||||
"""Test Gemini provider's alias structure."""
|
"""Test Gemini provider's alias structure."""
|
||||||
provider = GeminiModelProvider("test-key")
|
provider = GeminiModelProvider("test-key")
|
||||||
|
|
||||||
# Check that all models have ModelCapabilities with aliases
|
# Check that all models have ModelCapabilities with aliases
|
||||||
for model_name, config in provider.SUPPORTED_MODELS.items():
|
for model_name, config in provider.MODEL_CAPABILITIES.items():
|
||||||
assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
|
assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
|
||||||
assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
|
assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
|
||||||
|
|
||||||
# Test specific aliases
|
# Test specific aliases
|
||||||
assert "flash" in provider.SUPPORTED_MODELS["gemini-2.5-flash"].aliases
|
assert "flash" in provider.MODEL_CAPABILITIES["gemini-2.5-flash"].aliases
|
||||||
assert "pro" in provider.SUPPORTED_MODELS["gemini-2.5-pro"].aliases
|
assert "pro" in provider.MODEL_CAPABILITIES["gemini-2.5-pro"].aliases
|
||||||
assert "flash-2.0" in provider.SUPPORTED_MODELS["gemini-2.0-flash"].aliases
|
assert "flash-2.0" in provider.MODEL_CAPABILITIES["gemini-2.0-flash"].aliases
|
||||||
assert "flash2" in provider.SUPPORTED_MODELS["gemini-2.0-flash"].aliases
|
assert "flash2" in provider.MODEL_CAPABILITIES["gemini-2.0-flash"].aliases
|
||||||
assert "flashlite" in provider.SUPPORTED_MODELS["gemini-2.0-flash-lite"].aliases
|
assert "flashlite" in provider.MODEL_CAPABILITIES["gemini-2.0-flash-lite"].aliases
|
||||||
assert "flash-lite" in provider.SUPPORTED_MODELS["gemini-2.0-flash-lite"].aliases
|
assert "flash-lite" in provider.MODEL_CAPABILITIES["gemini-2.0-flash-lite"].aliases
|
||||||
|
|
||||||
# Test alias resolution
|
# Test alias resolution
|
||||||
assert provider._resolve_model_name("flash") == "gemini-2.5-flash"
|
assert provider._resolve_model_name("flash") == "gemini-2.5-flash"
|
||||||
@@ -42,18 +42,18 @@ class TestSupportedModelsAliases:
|
|||||||
provider = OpenAIModelProvider("test-key")
|
provider = OpenAIModelProvider("test-key")
|
||||||
|
|
||||||
# Check that all models have ModelCapabilities with aliases
|
# Check that all models have ModelCapabilities with aliases
|
||||||
for model_name, config in provider.SUPPORTED_MODELS.items():
|
for model_name, config in provider.MODEL_CAPABILITIES.items():
|
||||||
assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
|
assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
|
||||||
assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
|
assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
|
||||||
|
|
||||||
# Test specific aliases
|
# Test specific aliases
|
||||||
# "mini" is now an alias for gpt-5-mini, not o4-mini
|
# "mini" is now an alias for gpt-5-mini, not o4-mini
|
||||||
assert "mini" in provider.SUPPORTED_MODELS["gpt-5-mini"].aliases
|
assert "mini" in provider.MODEL_CAPABILITIES["gpt-5-mini"].aliases
|
||||||
assert "o4mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
|
assert "o4mini" in provider.MODEL_CAPABILITIES["o4-mini"].aliases
|
||||||
# o4-mini is no longer in its own aliases (removed self-reference)
|
# o4-mini is no longer in its own aliases (removed self-reference)
|
||||||
assert "o3mini" in provider.SUPPORTED_MODELS["o3-mini"].aliases
|
assert "o3mini" in provider.MODEL_CAPABILITIES["o3-mini"].aliases
|
||||||
assert "o3pro" in provider.SUPPORTED_MODELS["o3-pro"].aliases
|
assert "o3pro" in provider.MODEL_CAPABILITIES["o3-pro"].aliases
|
||||||
assert "gpt4.1" in provider.SUPPORTED_MODELS["gpt-4.1"].aliases
|
assert "gpt4.1" in provider.MODEL_CAPABILITIES["gpt-4.1"].aliases
|
||||||
|
|
||||||
# Test alias resolution
|
# Test alias resolution
|
||||||
assert provider._resolve_model_name("mini") == "gpt-5-mini" # mini -> gpt-5-mini now
|
assert provider._resolve_model_name("mini") == "gpt-5-mini" # mini -> gpt-5-mini now
|
||||||
@@ -71,16 +71,16 @@ class TestSupportedModelsAliases:
|
|||||||
provider = XAIModelProvider("test-key")
|
provider = XAIModelProvider("test-key")
|
||||||
|
|
||||||
# Check that all models have ModelCapabilities with aliases
|
# Check that all models have ModelCapabilities with aliases
|
||||||
for model_name, config in provider.SUPPORTED_MODELS.items():
|
for model_name, config in provider.MODEL_CAPABILITIES.items():
|
||||||
assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
|
assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
|
||||||
assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
|
assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
|
||||||
|
|
||||||
# Test specific aliases
|
# Test specific aliases
|
||||||
assert "grok" in provider.SUPPORTED_MODELS["grok-4"].aliases
|
assert "grok" in provider.MODEL_CAPABILITIES["grok-4"].aliases
|
||||||
assert "grok4" in provider.SUPPORTED_MODELS["grok-4"].aliases
|
assert "grok4" in provider.MODEL_CAPABILITIES["grok-4"].aliases
|
||||||
assert "grok3" in provider.SUPPORTED_MODELS["grok-3"].aliases
|
assert "grok3" in provider.MODEL_CAPABILITIES["grok-3"].aliases
|
||||||
assert "grok3fast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases
|
assert "grok3fast" in provider.MODEL_CAPABILITIES["grok-3-fast"].aliases
|
||||||
assert "grokfast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases
|
assert "grokfast" in provider.MODEL_CAPABILITIES["grok-3-fast"].aliases
|
||||||
|
|
||||||
# Test alias resolution
|
# Test alias resolution
|
||||||
assert provider._resolve_model_name("grok") == "grok-4"
|
assert provider._resolve_model_name("grok") == "grok-4"
|
||||||
@@ -98,16 +98,16 @@ class TestSupportedModelsAliases:
|
|||||||
provider = DIALModelProvider("test-key")
|
provider = DIALModelProvider("test-key")
|
||||||
|
|
||||||
# Check that all models have ModelCapabilities with aliases
|
# Check that all models have ModelCapabilities with aliases
|
||||||
for model_name, config in provider.SUPPORTED_MODELS.items():
|
for model_name, config in provider.MODEL_CAPABILITIES.items():
|
||||||
assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
|
assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
|
||||||
assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
|
assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
|
||||||
|
|
||||||
# Test specific aliases
|
# Test specific aliases
|
||||||
assert "o3" in provider.SUPPORTED_MODELS["o3-2025-04-16"].aliases
|
assert "o3" in provider.MODEL_CAPABILITIES["o3-2025-04-16"].aliases
|
||||||
assert "o4-mini" in provider.SUPPORTED_MODELS["o4-mini-2025-04-16"].aliases
|
assert "o4-mini" in provider.MODEL_CAPABILITIES["o4-mini-2025-04-16"].aliases
|
||||||
assert "sonnet-4.1" in provider.SUPPORTED_MODELS["anthropic.claude-sonnet-4.1-20250805-v1:0"].aliases
|
assert "sonnet-4.1" in provider.MODEL_CAPABILITIES["anthropic.claude-sonnet-4.1-20250805-v1:0"].aliases
|
||||||
assert "opus-4.1" in provider.SUPPORTED_MODELS["anthropic.claude-opus-4.1-20250805-v1:0"].aliases
|
assert "opus-4.1" in provider.MODEL_CAPABILITIES["anthropic.claude-opus-4.1-20250805-v1:0"].aliases
|
||||||
assert "gemini-2.5-pro" in provider.SUPPORTED_MODELS["gemini-2.5-pro-preview-05-06"].aliases
|
assert "gemini-2.5-pro" in provider.MODEL_CAPABILITIES["gemini-2.5-pro-preview-05-06"].aliases
|
||||||
|
|
||||||
# Test alias resolution
|
# Test alias resolution
|
||||||
assert provider._resolve_model_name("o3") == "o3-2025-04-16"
|
assert provider._resolve_model_name("o3") == "o3-2025-04-16"
|
||||||
@@ -183,12 +183,12 @@ class TestSupportedModelsAliases:
|
|||||||
]
|
]
|
||||||
|
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
for model_name, config in provider.SUPPORTED_MODELS.items():
|
for model_name, config in provider.MODEL_CAPABILITIES.items():
|
||||||
# All values must be ModelCapabilities objects, not strings or dicts
|
# All values must be ModelCapabilities objects, not strings or dicts
|
||||||
from providers.shared import ModelCapabilities
|
from providers.shared import ModelCapabilities
|
||||||
|
|
||||||
assert isinstance(config, ModelCapabilities), (
|
assert isinstance(config, ModelCapabilities), (
|
||||||
f"{provider.__class__.__name__}.SUPPORTED_MODELS['{model_name}'] "
|
f"{provider.__class__.__name__}.MODEL_CAPABILITIES['{model_name}'] "
|
||||||
f"must be a ModelCapabilities object, not {type(config).__name__}"
|
f"must be a ModelCapabilities object, not {type(config).__name__}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -256,18 +256,18 @@ class TestXAIProvider:
|
|||||||
assert capabilities.friendly_name == "X.AI (Grok 3)"
|
assert capabilities.friendly_name == "X.AI (Grok 3)"
|
||||||
|
|
||||||
def test_supported_models_structure(self):
|
def test_supported_models_structure(self):
|
||||||
"""Test that SUPPORTED_MODELS has the correct structure."""
|
"""Test that MODEL_CAPABILITIES has the correct structure."""
|
||||||
provider = XAIModelProvider("test-key")
|
provider = XAIModelProvider("test-key")
|
||||||
|
|
||||||
# Check that all expected base models are present
|
# Check that all expected base models are present
|
||||||
assert "grok-4" in provider.SUPPORTED_MODELS
|
assert "grok-4" in provider.MODEL_CAPABILITIES
|
||||||
assert "grok-3" in provider.SUPPORTED_MODELS
|
assert "grok-3" in provider.MODEL_CAPABILITIES
|
||||||
assert "grok-3-fast" in provider.SUPPORTED_MODELS
|
assert "grok-3-fast" in provider.MODEL_CAPABILITIES
|
||||||
|
|
||||||
# Check model configs have required fields
|
# Check model configs have required fields
|
||||||
from providers.shared import ModelCapabilities
|
from providers.shared import ModelCapabilities
|
||||||
|
|
||||||
grok4_config = provider.SUPPORTED_MODELS["grok-4"]
|
grok4_config = provider.MODEL_CAPABILITIES["grok-4"]
|
||||||
assert isinstance(grok4_config, ModelCapabilities)
|
assert isinstance(grok4_config, ModelCapabilities)
|
||||||
assert hasattr(grok4_config, "context_window")
|
assert hasattr(grok4_config, "context_window")
|
||||||
assert hasattr(grok4_config, "supports_extended_thinking")
|
assert hasattr(grok4_config, "supports_extended_thinking")
|
||||||
@@ -280,18 +280,18 @@ class TestXAIProvider:
|
|||||||
assert "grok-4" in grok4_config.aliases
|
assert "grok-4" in grok4_config.aliases
|
||||||
assert "grok4" in grok4_config.aliases
|
assert "grok4" in grok4_config.aliases
|
||||||
|
|
||||||
grok3_config = provider.SUPPORTED_MODELS["grok-3"]
|
grok3_config = provider.MODEL_CAPABILITIES["grok-3"]
|
||||||
assert grok3_config.context_window == 131_072
|
assert grok3_config.context_window == 131_072
|
||||||
assert grok3_config.supports_extended_thinking is False
|
assert grok3_config.supports_extended_thinking is False
|
||||||
# Check aliases are correctly structured
|
# Check aliases are correctly structured
|
||||||
assert "grok3" in grok3_config.aliases # grok3 resolves to grok-3
|
assert "grok3" in grok3_config.aliases # grok3 resolves to grok-3
|
||||||
|
|
||||||
# Check grok-4 aliases
|
# Check grok-4 aliases
|
||||||
grok4_config = provider.SUPPORTED_MODELS["grok-4"]
|
grok4_config = provider.MODEL_CAPABILITIES["grok-4"]
|
||||||
assert "grok" in grok4_config.aliases # grok resolves to grok-4
|
assert "grok" in grok4_config.aliases # grok resolves to grok-4
|
||||||
assert "grok4" in grok4_config.aliases
|
assert "grok4" in grok4_config.aliases
|
||||||
|
|
||||||
grok3fast_config = provider.SUPPORTED_MODELS["grok-3-fast"]
|
grok3fast_config = provider.MODEL_CAPABILITIES["grok-3-fast"]
|
||||||
assert "grok3fast" in grok3fast_config.aliases
|
assert "grok3fast" in grok3fast_config.aliases
|
||||||
assert "grokfast" in grok3fast_config.aliases
|
assert "grokfast" in grok3fast_config.aliases
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user