diff --git a/README.md b/README.md
index 6876333..47fbe59 100644
--- a/README.md
+++ b/README.md
@@ -129,7 +129,7 @@ cd zen-mcp-server
 "Plan the migration strategy with zen, get consensus from multiple models"
 ```

-👉 **[Complete Setup Guide](docs/getting-started.md)** with detailed installation, configuration for Gemini / Codex, and troubleshooting
+👉 **[Complete Setup Guide](docs/getting-started.md)** with detailed installation, configuration for Gemini / Codex, and troubleshooting  
 👉 **[Cursor & VS Code Setup](docs/getting-started.md#ide-clients)** for IDE integration instructions

 ## Core Tools
diff --git a/config.py b/config.py
index 0a0162e..9b71eb4 100644
--- a/config.py
+++ b/config.py
@@ -30,7 +30,7 @@ DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "auto")
 # Auto mode detection - when DEFAULT_MODEL is "auto", Claude picks the model
 IS_AUTO_MODE = DEFAULT_MODEL.lower() == "auto"

-# Each provider (gemini.py, openai_provider.py, xai.py) defines its own SUPPORTED_MODELS
+# Each provider (gemini.py, openai_provider.py, xai.py) defines its own MODEL_CAPABILITIES
 # with detailed descriptions. Tools use ModelProviderRegistry.get_available_model_names()
 # to get models only from enabled providers (those with valid API keys).
 #
diff --git a/docs/adding_providers.md b/docs/adding_providers.md
index 1c14571..d0dec90 100644
--- a/docs/adding_providers.md
+++ b/docs/adding_providers.md
@@ -28,7 +28,7 @@ Each provider:

 ### 1. Add Provider Type

-Add your provider to `ProviderType` enum in `providers/base.py`:
+Add your provider to the `ProviderType` enum in `providers/shared/provider_type.py`:

 ```python
 class ProviderType(Enum):
@@ -48,15 +48,23 @@ Create `providers/example.py`:
 import logging
 from typing import Optional

-from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, RangeTemperatureConstraint
+
+from .base import ModelProvider
+from .shared import (
+    ModelCapabilities,
+    ModelResponse,
+    ProviderType,
+    RangeTemperatureConstraint,
+)

 logger = logging.getLogger(__name__)

+
 class ExampleModelProvider(ModelProvider):
     """Example model provider implementation."""

     # Define models using ModelCapabilities objects (like Gemini provider)
-    SUPPORTED_MODELS = {
+    MODEL_CAPABILITIES = {
         "example-large": ModelCapabilities(
             provider=ProviderType.EXAMPLE,
             model_name="example-large",
@@ -87,7 +95,7 @@ class ExampleModelProvider(ModelProvider):
     def get_capabilities(self, model_name: str) -> ModelCapabilities:
         resolved_name = self._resolve_model_name(model_name)

-        if resolved_name not in self.SUPPORTED_MODELS:
+        if resolved_name not in self.MODEL_CAPABILITIES:
             raise ValueError(f"Unsupported model: {model_name}")

         # Apply restrictions if needed
@@ -96,7 +104,7 @@ class ExampleModelProvider(ModelProvider):
         if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
             raise ValueError(f"Model '{model_name}' is not allowed.")

-        return self.SUPPORTED_MODELS[resolved_name]
+        return self.MODEL_CAPABILITIES[resolved_name]

     def generate_content(self, prompt: str, model_name: str, system_prompt: Optional[str] = None,
                          temperature: float = 0.7, max_output_tokens: Optional[int] = None, **kwargs) -> ModelResponse:
@@ -121,7 +129,7 @@ class ExampleModelProvider(ModelProvider):

     def validate_model_name(self, model_name: str) -> bool:
         resolved_name = self._resolve_model_name(model_name)
-        return resolved_name in self.SUPPORTED_MODELS
+        return resolved_name in self.MODEL_CAPABILITIES

     def supports_thinking_mode(self, model_name: str) -> bool:
         capabilities = self.get_capabilities(model_name)
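The hunks above only show the first few fields of each capability entry. For orientation, a complete entry slotting into the `MODEL_CAPABILITIES` dict might look like the sketch below. The fields `provider`, `model_name`, `friendly_name`, `context_window`, `supports_extended_thinking`, and `aliases` are all exercised elsewhere in this diff; the `temperature_constraint` argument and the exact `RangeTemperatureConstraint` signature are assumptions, not confirmed against `providers/shared`:

```python
# Hypothetical complete MODEL_CAPABILITIES entry; temperature_constraint and
# the RangeTemperatureConstraint(min, max, default) signature are assumed.
"example-large": ModelCapabilities(
    provider=ProviderType.EXAMPLE,
    model_name="example-large",
    friendly_name="Example (Large)",
    context_window=200_000,
    supports_extended_thinking=False,
    temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
    aliases=["large"],  # lets callers write "large" instead of "example-large"
),
```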
@@ -136,8 +144,15 @@ For OpenAI-compatible APIs:
 """Example OpenAI-compatible provider."""

 from typing import Optional
-from .base import ModelCapabilities, ModelResponse, ProviderType, RangeTemperatureConstraint
+
 from .openai_compatible import OpenAICompatibleProvider
+from .shared import (
+    ModelCapabilities,
+    ModelResponse,
+    ProviderType,
+    RangeTemperatureConstraint,
+)
+

 class ExampleProvider(OpenAICompatibleProvider):
     """Example OpenAI-compatible provider."""
@@ -145,7 +160,7 @@ class ExampleProvider(OpenAICompatibleProvider):
     FRIENDLY_NAME = "Example"

     # Define models using ModelCapabilities (consistent with other providers)
-    SUPPORTED_MODELS = {
+    MODEL_CAPABILITIES = {
         "example-model-large": ModelCapabilities(
             provider=ProviderType.EXAMPLE,
             model_name="example-model-large",
@@ -163,16 +178,16 @@ class ExampleProvider(OpenAICompatibleProvider):

     def get_capabilities(self, model_name: str) -> ModelCapabilities:
         resolved_name = self._resolve_model_name(model_name)
-        if resolved_name not in self.SUPPORTED_MODELS:
+        if resolved_name not in self.MODEL_CAPABILITIES:
             raise ValueError(f"Unsupported model: {model_name}")
-        return self.SUPPORTED_MODELS[resolved_name]
+        return self.MODEL_CAPABILITIES[resolved_name]

     def get_provider_type(self) -> ProviderType:
         return ProviderType.EXAMPLE

     def validate_model_name(self, model_name: str) -> bool:
         resolved_name = self._resolve_model_name(model_name)
-        return resolved_name in self.SUPPORTED_MODELS
+        return resolved_name in self.MODEL_CAPABILITIES

     def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
         # IMPORTANT: Resolve aliases before API call
@@ -185,12 +200,8 @@ class ExampleProvider(OpenAICompatibleProvider):
 Add environment variable mapping in `providers/registry.py`:

 ```python
-# In _get_api_key_for_provider method:
-key_mapping = {
-    ProviderType.GOOGLE: "GEMINI_API_KEY",
-    ProviderType.OPENAI: "OPENAI_API_KEY",
-    ProviderType.EXAMPLE: "EXAMPLE_API_KEY",  # Add this
-}
+# In _get_api_key_for_provider (providers/registry.py), add:
+    ProviderType.EXAMPLE: "EXAMPLE_API_KEY",
 ```

 Add to `server.py`:
@@ -209,16 +220,7 @@ if example_key:
     logger.info("Example API key found - Example models available")
 ```

-3. **Add to provider priority** (in `providers/registry.py`):
-```python
-PROVIDER_PRIORITY_ORDER = [
-    ProviderType.GOOGLE,
-    ProviderType.OPENAI,
-    ProviderType.EXAMPLE,  # Add your provider here
-    ProviderType.CUSTOM,   # Local models
-    ProviderType.OPENROUTER,  # Catch-all (keep last)
-]
-```
+3. **Add to provider priority** (edit `ModelProviderRegistry.PROVIDER_PRIORITY_ORDER` in `providers/registry.py`): insert your provider in the list at the appropriate point in the cascade of native → custom → catch-all providers.

 ### 4. Environment Configuration

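The hard-coded list was dropped from the doc because it drifts as providers are added. For orientation only, here is the shape of that cascade as it appeared on the removed side of the hunk above; the authoritative ordering is `ModelProviderRegistry.PROVIDER_PRIORITY_ORDER` in `providers/registry.py`:

```python
# Illustrative only; mirrors the list deleted in the hunk above.
PROVIDER_PRIORITY_ORDER = [
    ProviderType.GOOGLE,      # native APIs first
    ProviderType.OPENAI,
    ProviderType.EXAMPLE,     # your provider goes among the native entries
    ProviderType.CUSTOM,      # local models
    ProviderType.OPENROUTER,  # catch-all (keep last)
]
```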
@@ -265,7 +267,7 @@ Your `validate_model_name()` should **only** return `True` for models you explic
 ```python
 def validate_model_name(self, model_name: str) -> bool:
     resolved_name = self._resolve_model_name(model_name)
-    return resolved_name in self.SUPPORTED_MODELS  # Be specific!
+    return resolved_name in self.MODEL_CAPABILITIES  # Be specific!
 ```

 ### Model Aliases
@@ -296,7 +298,7 @@ Without this, API calls with aliases like `"large"` will fail because your API d

 ## Quick Checklist

-- [ ] Added to `ProviderType` enum in `providers/base.py`
+- [ ] Added to `ProviderType` enum in `providers/shared/provider_type.py`
 - [ ] Created provider class with all required methods
 - [ ] Added API key mapping in `providers/registry.py`
 - [ ] Added to provider priority order in `registry.py`
@@ -307,8 +309,6 @@ Without this, API calls with aliases like `"large"` will fail because your API d
 ## Examples

 See existing implementations:
-- **Full provider**: `providers/gemini.py` 
+- **Full provider**: `providers/gemini.py`
 - **OpenAI-compatible**: `providers/custom.py`
 - **Base classes**: `providers/base.py`
-
-The modern approach uses `ModelCapabilities` objects directly in `SUPPORTED_MODELS`, making the implementation much cleaner and more consistent.
\ No newline at end of file
diff --git a/providers/base.py b/providers/base.py
index 4efe9d9..1f69ca0 100644
--- a/providers/base.py
+++ b/providers/base.py
@@ -28,7 +28,7 @@ class ModelProvider(ABC):
     """

     # All concrete providers must define their supported models
-    SUPPORTED_MODELS: dict[str, Any] = {}
+    MODEL_CAPABILITIES: dict[str, Any] = {}

     # Default maximum image size in MB
     DEFAULT_MAX_IMAGE_SIZE_MB = 20.0
@@ -147,9 +147,9 @@ class ModelProvider(ABC):
         Returns:
             Dictionary mapping model names to their ModelCapabilities objects
         """
-        # Return SUPPORTED_MODELS if it exists (must contain ModelCapabilities objects)
-        if hasattr(self, "SUPPORTED_MODELS"):
-            return {k: v for k, v in self.SUPPORTED_MODELS.items() if isinstance(v, ModelCapabilities)}
+        model_map = getattr(self, "MODEL_CAPABILITIES", None)
+        if isinstance(model_map, dict) and model_map:
+            return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)}
         return {}

     def _resolve_model_name(self, model_name: str) -> str:
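One detail in the base.py hunk worth spelling out: `hasattr` is always true on a `MagicMock` because mocks fabricate attributes on access, so the old guard could hand a mock object to the dict comprehension; `getattr(..., None)` plus an `isinstance` check fails closed instead. A minimal standalone demonstration, standard library only:

```python
from unittest.mock import MagicMock

mock_provider = MagicMock()

# Old guard: passes even though the attribute was never defined.
print(hasattr(mock_provider, "MODEL_CAPABILITIES"))  # True (fabricated on access)

# New guard: the fabricated attribute is a MagicMock, not a dict, so we bail out.
model_map = getattr(mock_provider, "MODEL_CAPABILITIES", None)
print(isinstance(model_map, dict) and bool(model_map))  # False
```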
diff --git a/providers/dial.py b/providers/dial.py
index 59910cc..65f80fe 100644
--- a/providers/dial.py
+++ b/providers/dial.py
@@ -33,7 +33,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
     RETRY_DELAYS = [1, 3, 5, 8]  # seconds

     # Model configurations using ModelCapabilities objects
-    SUPPORTED_MODELS = {
+    MODEL_CAPABILITIES = {
         "o3-2025-04-16": ModelCapabilities(
             provider=ProviderType.DIAL,
             model_name="o3-2025-04-16",
@@ -280,7 +280,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
         """
         resolved_name = self._resolve_model_name(model_name)

-        if resolved_name not in self.SUPPORTED_MODELS:
+        if resolved_name not in self.MODEL_CAPABILITIES:
             raise ValueError(f"Unsupported DIAL model: {model_name}")

         # Check restrictions
@@ -290,8 +290,8 @@ class DIALModelProvider(OpenAICompatibleProvider):
         if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
             raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")

-        # Return the ModelCapabilities object directly from SUPPORTED_MODELS
-        return self.SUPPORTED_MODELS[resolved_name]
+        # Return the ModelCapabilities object directly from MODEL_CAPABILITIES
+        return self.MODEL_CAPABILITIES[resolved_name]

     def get_provider_type(self) -> ProviderType:
         """Get the provider type."""
@@ -308,7 +308,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
         """
         resolved_name = self._resolve_model_name(model_name)

-        if resolved_name not in self.SUPPORTED_MODELS:
+        if resolved_name not in self.MODEL_CAPABILITIES:
             return False

         # Check against base class allowed_models if configured
diff --git a/providers/gemini.py b/providers/gemini.py
index 44f947d..2a79fce 100644
--- a/providers/gemini.py
+++ b/providers/gemini.py
@@ -31,7 +31,7 @@ class GeminiModelProvider(ModelProvider):
     """

     # Model configurations using ModelCapabilities objects
-    SUPPORTED_MODELS = {
+    MODEL_CAPABILITIES = {
         "gemini-2.5-pro": ModelCapabilities(
             provider=ProviderType.GOOGLE,
             model_name="gemini-2.5-pro",
@@ -154,7 +154,7 @@ class GeminiModelProvider(ModelProvider):
         # Resolve shorthand
         resolved_name = self._resolve_model_name(model_name)

-        if resolved_name not in self.SUPPORTED_MODELS:
+        if resolved_name not in self.MODEL_CAPABILITIES:
             raise ValueError(f"Unsupported Gemini model: {model_name}")

         # Check if model is allowed by restrictions
@@ -166,8 +166,8 @@ class GeminiModelProvider(ModelProvider):
         if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
             raise ValueError(f"Gemini model '{resolved_name}' is not allowed by restriction policy.")

-        # Return the ModelCapabilities object directly from SUPPORTED_MODELS
-        return self.SUPPORTED_MODELS[resolved_name]
+        # Return the ModelCapabilities object directly from MODEL_CAPABILITIES
+        return self.MODEL_CAPABILITIES[resolved_name]

     def generate_content(
         self,
@@ -227,7 +227,7 @@ class GeminiModelProvider(ModelProvider):
         # Add thinking configuration for models that support it
         if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
             # Get model's max thinking tokens and calculate actual budget
-            model_config = self.SUPPORTED_MODELS.get(resolved_name)
+            model_config = self.MODEL_CAPABILITIES.get(resolved_name)
             if model_config and model_config.max_thinking_tokens > 0:
                 max_thinking_tokens = model_config.max_thinking_tokens
                 actual_thinking_budget = int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
@@ -382,7 +382,7 @@ class GeminiModelProvider(ModelProvider):
         resolved_name = self._resolve_model_name(model_name)

         # First check if model is supported
-        if resolved_name not in self.SUPPORTED_MODELS:
+        if resolved_name not in self.MODEL_CAPABILITIES:
             return False

         # Then check if model is allowed by restrictions
@@ -405,7 +405,7 @@ class GeminiModelProvider(ModelProvider):
     def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int:
         """Get actual thinking token budget for a model and thinking mode."""
         resolved_name = self._resolve_model_name(model_name)
-        model_config = self.SUPPORTED_MODELS.get(resolved_name)
+        model_config = self.MODEL_CAPABILITIES.get(resolved_name)

         if not model_config or not model_config.supports_extended_thinking:
             return 0
@@ -584,7 +584,7 @@ class GeminiModelProvider(ModelProvider):
         pro_thinking = [
             m
             for m in allowed_models
-            if "pro" in m and m in self.SUPPORTED_MODELS and self.SUPPORTED_MODELS[m].supports_extended_thinking
+            if "pro" in m and m in self.MODEL_CAPABILITIES and self.MODEL_CAPABILITIES[m].supports_extended_thinking
         ]
         if pro_thinking:
             return find_best(pro_thinking)
@@ -593,7 +593,7 @@ class GeminiModelProvider(ModelProvider):
         any_thinking = [
             m
             for m in allowed_models
-            if m in self.SUPPORTED_MODELS and self.SUPPORTED_MODELS[m].supports_extended_thinking
+            if m in self.MODEL_CAPABILITIES and self.MODEL_CAPABILITIES[m].supports_extended_thinking
         ]
         if any_thinking:
             return find_best(any_thinking)
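The budget computation in the `@@ -227` hunk above is a straight scaling: the requested thinking mode selects a fraction of the model's `max_thinking_tokens` ceiling. The actual `THINKING_BUDGETS` fractions and per-model ceilings live in `providers/gemini.py` and are not shown in this diff, so the numbers below are invented purely to illustrate the arithmetic:

```python
# Hypothetical values; the real THINKING_BUDGETS mapping and per-model
# max_thinking_tokens are defined in providers/gemini.py and may differ.
THINKING_BUDGETS = {"low": 0.25, "medium": 0.5, "high": 1.0}
max_thinking_tokens = 32_768

actual_thinking_budget = int(max_thinking_tokens * THINKING_BUDGETS["medium"])
print(actual_thinking_budget)  # 16384
```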
diff --git a/providers/openai_provider.py b/providers/openai_provider.py
index 55cb657..8040472 100644
--- a/providers/openai_provider.py
+++ b/providers/openai_provider.py
@@ -26,7 +26,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
     """

     # Model configurations using ModelCapabilities objects
-    SUPPORTED_MODELS = {
+    MODEL_CAPABILITIES = {
         "gpt-5": ModelCapabilities(
             provider=ProviderType.OPENAI,
             model_name="gpt-5",
@@ -181,21 +181,21 @@ class OpenAIModelProvider(OpenAICompatibleProvider):

     def get_capabilities(self, model_name: str) -> ModelCapabilities:
         """Get capabilities for a specific OpenAI model."""
-        # First check if it's a key in SUPPORTED_MODELS
-        if model_name in self.SUPPORTED_MODELS:
+        # First check if it's a key in MODEL_CAPABILITIES
+        if model_name in self.MODEL_CAPABILITIES:
             self._check_model_restrictions(model_name, model_name)
-            return self.SUPPORTED_MODELS[model_name]
+            return self.MODEL_CAPABILITIES[model_name]

         # Try resolving as alias
         resolved_name = self._resolve_model_name(model_name)

         # Check if resolved name is a key
-        if resolved_name in self.SUPPORTED_MODELS:
+        if resolved_name in self.MODEL_CAPABILITIES:
             self._check_model_restrictions(resolved_name, model_name)
-            return self.SUPPORTED_MODELS[resolved_name]
+            return self.MODEL_CAPABILITIES[resolved_name]

         # Finally check if resolved name matches any API model name
-        for key, capabilities in self.SUPPORTED_MODELS.items():
+        for key, capabilities in self.MODEL_CAPABILITIES.items():
             if resolved_name == capabilities.model_name:
                 self._check_model_restrictions(key, model_name)
                 return capabilities
@@ -248,7 +248,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
         model_to_check = None
         is_custom_model = False

-        if resolved_name in self.SUPPORTED_MODELS:
+        if resolved_name in self.MODEL_CAPABILITIES:
             model_to_check = resolved_name
         else:
             # If not a built-in model, check the custom models registry.
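The rewritten `get_capabilities` above tries three lookups in order: the literal name as a `MODEL_CAPABILITIES` key, the alias-resolved name as a key, and finally the resolved name matched against each entry's canonical `model_name` field. A usage sketch — the first two calls are grounded in the tests later in this diff, while the third path is only described, since no dated-snapshot example appears here:

```python
from providers.openai_provider import OpenAIModelProvider

provider = OpenAIModelProvider("test-key")  # constructed the same way the tests do

provider.get_capabilities("gpt-5")  # path 1: literal key hit
provider.get_capabilities("mini")   # path 2: alias resolves to "gpt-5-mini", then key hit
# Path 3 would catch a name that is no dict key but equals some entry's
# canonical model_name (e.g. a dated snapshot registered under a shorter key).
```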
diff --git a/providers/registry.py b/providers/registry.py
index c22cfcf..0783f8f 100644
--- a/providers/registry.py
+++ b/providers/registry.py
@@ -282,11 +282,9 @@ class ModelProviderRegistry:
                 # Use list_models to get all supported models (handles both regular and custom providers)
                 supported_models = provider.list_models(respect_restrictions=False)
             except (NotImplementedError, AttributeError):
-                # Fallback to SUPPORTED_MODELS if list_models not implemented
-                try:
-                    supported_models = list(provider.SUPPORTED_MODELS.keys())
-                except AttributeError:
-                    supported_models = []
+                # Fallback to provider-declared capability maps if list_models not implemented
+                model_map = getattr(provider, "MODEL_CAPABILITIES", None)
+                supported_models = list(model_map.keys()) if isinstance(model_map, dict) else []

             # Filter by restrictions
             for model_name in supported_models:
diff --git a/providers/xai.py b/providers/xai.py
index 1d3e5db..e0daf2f 100644
--- a/providers/xai.py
+++ b/providers/xai.py
@@ -27,7 +27,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
     FRIENDLY_NAME = "X.AI"

     # Model configurations using ModelCapabilities objects
-    SUPPORTED_MODELS = {
+    MODEL_CAPABILITIES = {
         "grok-4": ModelCapabilities(
             provider=ProviderType.XAI,
             model_name="grok-4",
@@ -95,7 +95,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
         # Resolve shorthand
         resolved_name = self._resolve_model_name(model_name)

-        if resolved_name not in self.SUPPORTED_MODELS:
+        if resolved_name not in self.MODEL_CAPABILITIES:
             raise ValueError(f"Unsupported X.AI model: {model_name}")

         # Check if model is allowed by restrictions
@@ -105,8 +105,8 @@ class XAIModelProvider(OpenAICompatibleProvider):
         if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
             raise ValueError(f"X.AI model '{model_name}' is not allowed by restriction policy.")

-        # Return the ModelCapabilities object directly from SUPPORTED_MODELS
-        return self.SUPPORTED_MODELS[resolved_name]
+        # Return the ModelCapabilities object directly from MODEL_CAPABILITIES
+        return self.MODEL_CAPABILITIES[resolved_name]

     def get_provider_type(self) -> ProviderType:
         """Get the provider type."""
@@ -117,7 +117,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
         resolved_name = self._resolve_model_name(model_name)

         # First check if model is supported
-        if resolved_name not in self.SUPPORTED_MODELS:
+        if resolved_name not in self.MODEL_CAPABILITIES:
             return False

         # Then check if model is allowed by restrictions
@@ -156,7 +156,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
     def supports_thinking_mode(self, model_name: str) -> bool:
         """Check if the model supports extended thinking mode."""
         resolved_name = self._resolve_model_name(model_name)
-        capabilities = self.SUPPORTED_MODELS.get(resolved_name)
+        capabilities = self.MODEL_CAPABILITIES.get(resolved_name)
         if capabilities:
             return capabilities.supports_extended_thinking
         return False
{"context_window": 200000}, "o3-mini": {"context_window": 200000}, "o4-mini": {"context_window": 200000}, @@ -441,7 +441,7 @@ class TestRegistryIntegration: # Mock providers mock_openai = MagicMock() - mock_openai.SUPPORTED_MODELS = { + mock_openai.MODEL_CAPABILITIES = { "o3": {"context_window": 200000}, "o3-mini": {"context_window": 200000}, } @@ -452,7 +452,7 @@ class TestRegistryIntegration: restriction_service = get_restriction_service() if respect_restrictions else None models = [] - for model_name, config in mock_openai.SUPPORTED_MODELS.items(): + for model_name, config in mock_openai.MODEL_CAPABILITIES.items(): if isinstance(config, str): target_model = config if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, target_model): @@ -468,7 +468,7 @@ class TestRegistryIntegration: mock_openai.list_all_known_models.return_value = ["o3", "o3-mini"] mock_gemini = MagicMock() - mock_gemini.SUPPORTED_MODELS = { + mock_gemini.MODEL_CAPABILITIES = { "gemini-2.5-pro": {"context_window": 1048576}, "gemini-2.5-flash": {"context_window": 1048576}, } @@ -479,7 +479,7 @@ class TestRegistryIntegration: restriction_service = get_restriction_service() if respect_restrictions else None models = [] - for model_name, config in mock_gemini.SUPPORTED_MODELS.items(): + for model_name, config in mock_gemini.MODEL_CAPABILITIES.items(): if isinstance(config, str): target_model = config if restriction_service and not restriction_service.is_allowed(ProviderType.GOOGLE, target_model): @@ -608,7 +608,7 @@ class TestAutoModeWithRestrictions: # Mock providers mock_openai = MagicMock() - mock_openai.SUPPORTED_MODELS = { + mock_openai.MODEL_CAPABILITIES = { "o3": {"context_window": 200000}, "o3-mini": {"context_window": 200000}, "o4-mini": {"context_window": 200000}, @@ -620,7 +620,7 @@ class TestAutoModeWithRestrictions: restriction_service = get_restriction_service() if respect_restrictions else None models = [] - for model_name, config in mock_openai.SUPPORTED_MODELS.items(): + for model_name, config in mock_openai.MODEL_CAPABILITIES.items(): if isinstance(config, str): target_model = config if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, target_model): diff --git a/tests/test_o3_temperature_fix_simple.py b/tests/test_o3_temperature_fix_simple.py index 4f1820e..6293822 100644 --- a/tests/test_o3_temperature_fix_simple.py +++ b/tests/test_o3_temperature_fix_simple.py @@ -205,7 +205,7 @@ class TestO3TemperatureParameterFixSimple: ), f"Model {model} capabilities should have supports_temperature field" assert capabilities.supports_temperature is True, f"Model {model} should have supports_temperature=True" except ValueError: - # Skip if model not in SUPPORTED_MODELS (that's okay for this test) + # Skip if model not in MODEL_CAPABILITIES (that's okay for this test) pass @patch("utils.model_restrictions.get_restriction_service") diff --git a/tests/test_old_behavior_simulation.py b/tests/test_old_behavior_simulation.py index 2918183..dc4719a 100644 --- a/tests/test_old_behavior_simulation.py +++ b/tests/test_old_behavior_simulation.py @@ -28,7 +28,7 @@ class TestOldBehaviorSimulation: """ # Create a mock provider that simulates the OLD BROKEN BEHAVIOR old_broken_provider = MagicMock() - old_broken_provider.SUPPORTED_MODELS = { + old_broken_provider.MODEL_CAPABILITIES = { "mini": "o4-mini", # alias -> target "o3mini": "o3-mini", # alias -> target "o4-mini": {"context_window": 200000}, @@ -73,7 +73,7 @@ class TestOldBehaviorSimulation: """ # Create mock provider 
diff --git a/tests/test_o3_temperature_fix_simple.py b/tests/test_o3_temperature_fix_simple.py
index 4f1820e..6293822 100644
--- a/tests/test_o3_temperature_fix_simple.py
+++ b/tests/test_o3_temperature_fix_simple.py
@@ -205,7 +205,7 @@ class TestO3TemperatureParameterFixSimple:
                 ), f"Model {model} capabilities should have supports_temperature field"
                 assert capabilities.supports_temperature is True, f"Model {model} should have supports_temperature=True"
             except ValueError:
-                # Skip if model not in SUPPORTED_MODELS (that's okay for this test)
+                # Skip if model not in MODEL_CAPABILITIES (that's okay for this test)
                 pass

     @patch("utils.model_restrictions.get_restriction_service")
diff --git a/tests/test_old_behavior_simulation.py b/tests/test_old_behavior_simulation.py
index 2918183..dc4719a 100644
--- a/tests/test_old_behavior_simulation.py
+++ b/tests/test_old_behavior_simulation.py
@@ -28,7 +28,7 @@ class TestOldBehaviorSimulation:
         """
         # Create a mock provider that simulates the OLD BROKEN BEHAVIOR
         old_broken_provider = MagicMock()
-        old_broken_provider.SUPPORTED_MODELS = {
+        old_broken_provider.MODEL_CAPABILITIES = {
             "mini": "o4-mini",  # alias -> target
             "o3mini": "o3-mini",  # alias -> target
             "o4-mini": {"context_window": 200000},
@@ -73,7 +73,7 @@ class TestOldBehaviorSimulation:
         """
         # Create mock provider with NEW FIXED BEHAVIOR
         new_fixed_provider = MagicMock()
-        new_fixed_provider.SUPPORTED_MODELS = {
+        new_fixed_provider.MODEL_CAPABILITIES = {
             "mini": "o4-mini",
             "o3mini": "o3-mini",
             "o4-mini": {"context_window": 200000},
@@ -203,14 +203,14 @@ class TestOldBehaviorSimulation:
         for provider in providers:
             all_known = provider.list_all_known_models()

-            # Check that for every alias in SUPPORTED_MODELS, its target is also included
-            for model_name, config in provider.SUPPORTED_MODELS.items():
-                # Model name itself should be in the list
+            # Check that every model and its aliases appear in the comprehensive list
+            for model_name, config in provider.MODEL_CAPABILITIES.items():
                 assert model_name.lower() in all_known, f"{provider.__class__.__name__}: Missing model {model_name}"

-                # If it's an alias (config is a string), target should also be in list
-                if isinstance(config, str):
-                    target_model = config
+                for alias in getattr(config, "aliases", []):
                     assert (
-                        target_model.lower() in all_known
-                    ), f"{provider.__class__.__name__}: Missing target {target_model} for alias {model_name}"
+                        alias.lower() in all_known
+                    ), f"{provider.__class__.__name__}: Missing alias {alias} for model {model_name}"
+                    assert (
+                        provider._resolve_model_name(alias) == model_name
+                    ), f"{provider.__class__.__name__}: Alias {alias} should resolve to {model_name}"
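The rewritten assertion loop above makes explicit the contract `list_all_known_models()` must honor now that aliases live on `ModelCapabilities` objects: every canonical name and every alias appears lowercased, and each alias resolves back to its canonical model. A sketch of an implementation that would satisfy those assertions — the real method belongs to the provider base class and may differ:

```python
def list_all_known_models(self) -> list[str]:
    """Sketch only: canonical model names plus all aliases, lowercased."""
    known: list[str] = []
    for name, capabilities in self.MODEL_CAPABILITIES.items():
        known.append(name.lower())
        known.extend(alias.lower() for alias in getattr(capabilities, "aliases", []))
    return known
```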
provider.SUPPORTED_MODELS["gemini-2.0-flash-lite"].aliases - assert "flash-lite" in provider.SUPPORTED_MODELS["gemini-2.0-flash-lite"].aliases + assert "flash" in provider.MODEL_CAPABILITIES["gemini-2.5-flash"].aliases + assert "pro" in provider.MODEL_CAPABILITIES["gemini-2.5-pro"].aliases + assert "flash-2.0" in provider.MODEL_CAPABILITIES["gemini-2.0-flash"].aliases + assert "flash2" in provider.MODEL_CAPABILITIES["gemini-2.0-flash"].aliases + assert "flashlite" in provider.MODEL_CAPABILITIES["gemini-2.0-flash-lite"].aliases + assert "flash-lite" in provider.MODEL_CAPABILITIES["gemini-2.0-flash-lite"].aliases # Test alias resolution assert provider._resolve_model_name("flash") == "gemini-2.5-flash" @@ -42,18 +42,18 @@ class TestSupportedModelsAliases: provider = OpenAIModelProvider("test-key") # Check that all models have ModelCapabilities with aliases - for model_name, config in provider.SUPPORTED_MODELS.items(): + for model_name, config in provider.MODEL_CAPABILITIES.items(): assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute" assert isinstance(config.aliases, list), f"{model_name} aliases must be a list" # Test specific aliases # "mini" is now an alias for gpt-5-mini, not o4-mini - assert "mini" in provider.SUPPORTED_MODELS["gpt-5-mini"].aliases - assert "o4mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases + assert "mini" in provider.MODEL_CAPABILITIES["gpt-5-mini"].aliases + assert "o4mini" in provider.MODEL_CAPABILITIES["o4-mini"].aliases # o4-mini is no longer in its own aliases (removed self-reference) - assert "o3mini" in provider.SUPPORTED_MODELS["o3-mini"].aliases - assert "o3pro" in provider.SUPPORTED_MODELS["o3-pro"].aliases - assert "gpt4.1" in provider.SUPPORTED_MODELS["gpt-4.1"].aliases + assert "o3mini" in provider.MODEL_CAPABILITIES["o3-mini"].aliases + assert "o3pro" in provider.MODEL_CAPABILITIES["o3-pro"].aliases + assert "gpt4.1" in provider.MODEL_CAPABILITIES["gpt-4.1"].aliases # Test alias resolution assert provider._resolve_model_name("mini") == "gpt-5-mini" # mini -> gpt-5-mini now @@ -71,16 +71,16 @@ class TestSupportedModelsAliases: provider = XAIModelProvider("test-key") # Check that all models have ModelCapabilities with aliases - for model_name, config in provider.SUPPORTED_MODELS.items(): + for model_name, config in provider.MODEL_CAPABILITIES.items(): assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute" assert isinstance(config.aliases, list), f"{model_name} aliases must be a list" # Test specific aliases - assert "grok" in provider.SUPPORTED_MODELS["grok-4"].aliases - assert "grok4" in provider.SUPPORTED_MODELS["grok-4"].aliases - assert "grok3" in provider.SUPPORTED_MODELS["grok-3"].aliases - assert "grok3fast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases - assert "grokfast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases + assert "grok" in provider.MODEL_CAPABILITIES["grok-4"].aliases + assert "grok4" in provider.MODEL_CAPABILITIES["grok-4"].aliases + assert "grok3" in provider.MODEL_CAPABILITIES["grok-3"].aliases + assert "grok3fast" in provider.MODEL_CAPABILITIES["grok-3-fast"].aliases + assert "grokfast" in provider.MODEL_CAPABILITIES["grok-3-fast"].aliases # Test alias resolution assert provider._resolve_model_name("grok") == "grok-4" @@ -98,16 +98,16 @@ class TestSupportedModelsAliases: provider = DIALModelProvider("test-key") # Check that all models have ModelCapabilities with aliases - for model_name, config in provider.SUPPORTED_MODELS.items(): + for model_name, config in 
@@ -98,16 +98,16 @@ class TestSupportedModelsAliases:
         provider = DIALModelProvider("test-key")

         # Check that all models have ModelCapabilities with aliases
-        for model_name, config in provider.SUPPORTED_MODELS.items():
+        for model_name, config in provider.MODEL_CAPABILITIES.items():
             assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
             assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"

         # Test specific aliases
-        assert "o3" in provider.SUPPORTED_MODELS["o3-2025-04-16"].aliases
-        assert "o4-mini" in provider.SUPPORTED_MODELS["o4-mini-2025-04-16"].aliases
-        assert "sonnet-4.1" in provider.SUPPORTED_MODELS["anthropic.claude-sonnet-4.1-20250805-v1:0"].aliases
-        assert "opus-4.1" in provider.SUPPORTED_MODELS["anthropic.claude-opus-4.1-20250805-v1:0"].aliases
-        assert "gemini-2.5-pro" in provider.SUPPORTED_MODELS["gemini-2.5-pro-preview-05-06"].aliases
+        assert "o3" in provider.MODEL_CAPABILITIES["o3-2025-04-16"].aliases
+        assert "o4-mini" in provider.MODEL_CAPABILITIES["o4-mini-2025-04-16"].aliases
+        assert "sonnet-4.1" in provider.MODEL_CAPABILITIES["anthropic.claude-sonnet-4.1-20250805-v1:0"].aliases
+        assert "opus-4.1" in provider.MODEL_CAPABILITIES["anthropic.claude-opus-4.1-20250805-v1:0"].aliases
+        assert "gemini-2.5-pro" in provider.MODEL_CAPABILITIES["gemini-2.5-pro-preview-05-06"].aliases

         # Test alias resolution
         assert provider._resolve_model_name("o3") == "o3-2025-04-16"
@@ -183,12 +183,12 @@ class TestSupportedModelsAliases:
         ]

         for provider in providers:
-            for model_name, config in provider.SUPPORTED_MODELS.items():
+            for model_name, config in provider.MODEL_CAPABILITIES.items():
                 # All values must be ModelCapabilities objects, not strings or dicts
                 from providers.shared import ModelCapabilities

                 assert isinstance(config, ModelCapabilities), (
-                    f"{provider.__class__.__name__}.SUPPORTED_MODELS['{model_name}'] "
+                    f"{provider.__class__.__name__}.MODEL_CAPABILITIES['{model_name}'] "
                     f"must be a ModelCapabilities object, not {type(config).__name__}"
                 )
diff --git a/tests/test_xai_provider.py b/tests/test_xai_provider.py
index 5bdc4a0..bb8e97b 100644
--- a/tests/test_xai_provider.py
+++ b/tests/test_xai_provider.py
@@ -256,18 +256,18 @@ class TestXAIProvider:
         assert capabilities.friendly_name == "X.AI (Grok 3)"

     def test_supported_models_structure(self):
-        """Test that SUPPORTED_MODELS has the correct structure."""
+        """Test that MODEL_CAPABILITIES has the correct structure."""
         provider = XAIModelProvider("test-key")

         # Check that all expected base models are present
-        assert "grok-4" in provider.SUPPORTED_MODELS
-        assert "grok-3" in provider.SUPPORTED_MODELS
-        assert "grok-3-fast" in provider.SUPPORTED_MODELS
+        assert "grok-4" in provider.MODEL_CAPABILITIES
+        assert "grok-3" in provider.MODEL_CAPABILITIES
+        assert "grok-3-fast" in provider.MODEL_CAPABILITIES

         # Check model configs have required fields
         from providers.shared import ModelCapabilities

-        grok4_config = provider.SUPPORTED_MODELS["grok-4"]
+        grok4_config = provider.MODEL_CAPABILITIES["grok-4"]
         assert isinstance(grok4_config, ModelCapabilities)
         assert hasattr(grok4_config, "context_window")
         assert hasattr(grok4_config, "supports_extended_thinking")
@@ -280,18 +280,18 @@ class TestXAIProvider:
         assert "grok-4" in grok4_config.aliases
         assert "grok4" in grok4_config.aliases

-        grok3_config = provider.SUPPORTED_MODELS["grok-3"]
+        grok3_config = provider.MODEL_CAPABILITIES["grok-3"]
         assert grok3_config.context_window == 131_072
         assert grok3_config.supports_extended_thinking is False

         # Check aliases are correctly structured
         assert "grok3" in grok3_config.aliases  # grok3 resolves to grok-3

         # Check grok-4 aliases
-        grok4_config = provider.SUPPORTED_MODELS["grok-4"]
+        grok4_config = provider.MODEL_CAPABILITIES["grok-4"]
         assert "grok" in grok4_config.aliases  # grok resolves to grok-4
         assert "grok4" in grok4_config.aliases

-        grok3fast_config = provider.SUPPORTED_MODELS["grok-3-fast"]
+        grok3fast_config = provider.MODEL_CAPABILITIES["grok-3-fast"]
         assert "grok3fast" in grok3fast_config.aliases
         assert "grokfast" in grok3fast_config.aliases
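Taken together, these tests pin down the alias behavior after the rename: aliases are declared on each `ModelCapabilities` entry, and `_resolve_model_name` maps an alias (or a canonical name) back to the dict key. A minimal standalone sketch consistent with those assertions — the real implementation lives in `providers/base.py` and may differ:

```python
def resolve_model_name(model_capabilities: dict, model_name: str) -> str:
    """Sketch of alias resolution over a MODEL_CAPABILITIES-style mapping."""
    wanted = model_name.lower()
    for canonical, capabilities in model_capabilities.items():
        aliases = getattr(capabilities, "aliases", [])
        if wanted == canonical.lower() or wanted in (a.lower() for a in aliases):
            return canonical
    return model_name  # unknown names pass through unchanged
```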