From e94c028a3fef548835e22de30e4a07e80721d227 Mon Sep 17 00:00:00 2001 From: Fahad Date: Mon, 23 Jun 2025 15:23:55 +0400 Subject: [PATCH 1/8] Fixed: https://github.com/BeehiveInnovations/zen-mcp-server/issues/123 Fixed restriction checks for OpenRouter --- config.py | 2 +- providers/openrouter.py | 34 +++++---- providers/registry.py | 1 - server.py | 16 +++++ tests/test_custom_provider.py | 32 ++++++--- tests/test_parse_model_option.py | 79 +++++++++++++++++++++ tests/test_provider_routing_bugs.py | 34 +++++++-- tests/test_workflow_metadata.py | 104 +++++++++++++++++++--------- utils/model_restrictions.py | 4 ++ 9 files changed, 246 insertions(+), 60 deletions(-) create mode 100644 tests/test_parse_model_option.py diff --git a/config.py b/config.py index c824e29..d2005c1 100644 --- a/config.py +++ b/config.py @@ -14,7 +14,7 @@ import os # These values are used in server responses and for tracking releases # IMPORTANT: This is the single source of truth for version and author info # Semantic versioning: MAJOR.MINOR.PATCH -__version__ = "5.6.1" +__version__ = "5.6.2" # Last update date in ISO format __updated__ = "2025-06-23" # Primary maintainer diff --git a/providers/openrouter.py b/providers/openrouter.py index e464f4a..1e22b45 100644 --- a/providers/openrouter.py +++ b/providers/openrouter.py @@ -50,14 +50,6 @@ class OpenRouterProvider(OpenAICompatibleProvider): aliases = self._registry.list_aliases() logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases") - def _parse_allowed_models(self) -> None: - """Override to disable environment-based allow-list. - - OpenRouter model access is controlled via the OpenRouter dashboard, - not through environment variables. - """ - return None - def _resolve_model_name(self, model_name: str) -> str: """Resolve model aliases to OpenRouter model names. 
@@ -130,16 +122,34 @@ class OpenRouterProvider(OpenAICompatibleProvider): As the catch-all provider, OpenRouter accepts any model name that wasn't handled by higher-priority providers. OpenRouter will validate based on - the API key's permissions. + the API key's permissions and local restrictions. Args: model_name: Model name to validate Returns: - Always True - OpenRouter is the catch-all provider + True if model is allowed, False if restricted """ - # Accept any model name - OpenRouter is the fallback provider - # Higher priority providers (native APIs, custom endpoints) get first chance + # Check model restrictions if configured + from utils.model_restrictions import get_restriction_service + + restriction_service = get_restriction_service() + if restriction_service: + # Check if model name itself is allowed + if restriction_service.is_allowed(self.get_provider_type(), model_name): + return True + + # Also check aliases - model_name might be an alias + model_config = self._registry.resolve(model_name) + if model_config and model_config.aliases: + for alias in model_config.aliases: + if restriction_service.is_allowed(self.get_provider_type(), alias): + return True + + # If restrictions are configured and model/alias not in allowed list, reject + return False + + # No restrictions configured - accept any model name as the fallback provider return True def generate_content( diff --git a/providers/registry.py b/providers/registry.py index baa9222..da7a9b5 100644 --- a/providers/registry.py +++ b/providers/registry.py @@ -129,7 +129,6 @@ class ModelProviderRegistry: logging.debug(f"Available providers in registry: {list(instance._providers.keys())}") for provider_type in PROVIDER_PRIORITY_ORDER: - logging.debug(f"Checking provider_type: {provider_type}") if provider_type in instance._providers: logging.debug(f"Found {provider_type} in registry") # Get or create provider instance diff --git a/server.py b/server.py index 19904fb..9247aa6 100644 --- a/server.py +++ 
b/server.py @@ -673,6 +673,11 @@ def parse_model_option(model_string: str) -> tuple[str, Optional[str]]: """ Parse model:option format into model name and option. + Handles different formats: + - OpenRouter models: preserve :free, :beta, :preview suffixes as part of model name + - Ollama/Custom models: split on : to extract tags like :latest + - Consensus stance: extract options like :for, :against + Args: model_string: String that may contain "model:option" format @@ -680,6 +685,17 @@ def parse_model_option(model_string: str) -> tuple[str, Optional[str]]: tuple: (model_name, option) where option may be None """ if ":" in model_string and not model_string.startswith("http"): # Avoid parsing URLs + # Check if this looks like an OpenRouter model (contains /) + if "/" in model_string and model_string.count(":") == 1: + # Could be openai/gpt-4:something - check what comes after colon + parts = model_string.split(":", 1) + suffix = parts[1].strip().lower() + + # Known OpenRouter suffixes to preserve + if suffix in ["free", "beta", "preview"]: + return model_string.strip(), None + + # For other patterns (Ollama tags, consensus stances), split normally parts = model_string.split(":", 1) model_name = parts[0].strip() model_option = parts[1].strip() if len(parts) > 1 else None diff --git a/tests/test_custom_provider.py b/tests/test_custom_provider.py index 8708d39..125417d 100644 --- a/tests/test_custom_provider.py +++ b/tests/test_custom_provider.py @@ -45,18 +45,32 @@ class TestCustomProvider: def test_get_capabilities_from_registry(self): """Test get_capabilities returns registry capabilities when available.""" - provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1") + # Save original environment + original_env = os.environ.get("OPENROUTER_ALLOWED_MODELS") - # Test with a model that should be in the registry (OpenRouter model) and is allowed by restrictions - capabilities = provider.get_capabilities("o3") # o3 is in OPENROUTER_ALLOWED_MODELS 
+ try: + # Clear any restrictions + os.environ.pop("OPENROUTER_ALLOWED_MODELS", None) - assert capabilities.provider == ProviderType.OPENROUTER # o3 is an OpenRouter model (is_custom=false) - assert capabilities.context_window > 0 + provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1") - # Test with a custom model (is_custom=true) - capabilities = provider.get_capabilities("local-llama") - assert capabilities.provider == ProviderType.CUSTOM # local-llama has is_custom=true - assert capabilities.context_window > 0 + # Test with a model that should be in the registry (OpenRouter model) + capabilities = provider.get_capabilities("o3") # o3 is an OpenRouter model + + assert capabilities.provider == ProviderType.OPENROUTER # o3 is an OpenRouter model (is_custom=false) + assert capabilities.context_window > 0 + + # Test with a custom model (is_custom=true) + capabilities = provider.get_capabilities("local-llama") + assert capabilities.provider == ProviderType.CUSTOM # local-llama has is_custom=true + assert capabilities.context_window > 0 + + finally: + # Restore original environment + if original_env is None: + os.environ.pop("OPENROUTER_ALLOWED_MODELS", None) + else: + os.environ["OPENROUTER_ALLOWED_MODELS"] = original_env def test_get_capabilities_generic_fallback(self): """Test get_capabilities returns generic capabilities for unknown models.""" diff --git a/tests/test_parse_model_option.py b/tests/test_parse_model_option.py new file mode 100644 index 0000000..5b01c88 --- /dev/null +++ b/tests/test_parse_model_option.py @@ -0,0 +1,79 @@ +"""Tests for parse_model_option function.""" + +from server import parse_model_option + + +class TestParseModelOption: + """Test cases for model option parsing.""" + + def test_openrouter_free_suffix_preserved(self): + """Test that OpenRouter :free suffix is preserved as part of model name.""" + model, option = parse_model_option("openai/gpt-3.5-turbo:free") + assert model == "openai/gpt-3.5-turbo:free" 
+ assert option is None + + def test_openrouter_beta_suffix_preserved(self): + """Test that OpenRouter :beta suffix is preserved as part of model name.""" + model, option = parse_model_option("anthropic/claude-3-opus:beta") + assert model == "anthropic/claude-3-opus:beta" + assert option is None + + def test_openrouter_preview_suffix_preserved(self): + """Test that OpenRouter :preview suffix is preserved as part of model name.""" + model, option = parse_model_option("google/gemini-pro:preview") + assert model == "google/gemini-pro:preview" + assert option is None + + def test_ollama_tag_parsed_as_option(self): + """Test that Ollama tags are parsed as options.""" + model, option = parse_model_option("llama3.2:latest") + assert model == "llama3.2" + assert option == "latest" + + def test_consensus_stance_parsed_as_option(self): + """Test that consensus stances are parsed as options.""" + model, option = parse_model_option("o3:for") + assert model == "o3" + assert option == "for" + + model, option = parse_model_option("gemini-2.5-pro:against") + assert model == "gemini-2.5-pro" + assert option == "against" + + def test_openrouter_unknown_suffix_parsed_as_option(self): + """Test that unknown suffixes on OpenRouter models are parsed as options.""" + model, option = parse_model_option("openai/gpt-4:custom-tag") + assert model == "openai/gpt-4" + assert option == "custom-tag" + + def test_plain_model_name(self): + """Test plain model names without colons.""" + model, option = parse_model_option("gpt-4") + assert model == "gpt-4" + assert option is None + + def test_url_not_parsed(self): + """Test that URLs are not parsed for options.""" + model, option = parse_model_option("http://localhost:8080") + assert model == "http://localhost:8080" + assert option is None + + def test_whitespace_handling(self): + """Test that whitespace is properly stripped.""" + model, option = parse_model_option(" openai/gpt-3.5-turbo:free ") + assert model == "openai/gpt-3.5-turbo:free" + assert 
option is None + + model, option = parse_model_option(" llama3.2 : latest ") + assert model == "llama3.2" + assert option == "latest" + + def test_case_insensitive_suffix_matching(self): + """Test that OpenRouter suffix matching is case-insensitive.""" + model, option = parse_model_option("openai/gpt-3.5-turbo:FREE") + assert model == "openai/gpt-3.5-turbo:FREE" # Original case preserved + assert option is None + + model, option = parse_model_option("openai/gpt-3.5-turbo:Free") + assert model == "openai/gpt-3.5-turbo:Free" # Original case preserved + assert option is None diff --git a/tests/test_provider_routing_bugs.py b/tests/test_provider_routing_bugs.py index 9ed125b..f05e181 100644 --- a/tests/test_provider_routing_bugs.py +++ b/tests/test_provider_routing_bugs.py @@ -58,7 +58,13 @@ class TestProviderRoutingBugs: """ # Save original environment original_env = {} - for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]: + for key in [ + "GEMINI_API_KEY", + "OPENAI_API_KEY", + "XAI_API_KEY", + "OPENROUTER_API_KEY", + "OPENROUTER_ALLOWED_MODELS", + ]: original_env[key] = os.environ.get(key) try: @@ -66,6 +72,7 @@ class TestProviderRoutingBugs: os.environ.pop("GEMINI_API_KEY", None) # No Google API key os.environ.pop("OPENAI_API_KEY", None) os.environ.pop("XAI_API_KEY", None) + os.environ.pop("OPENROUTER_ALLOWED_MODELS", None) # Clear any restrictions os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key" # Register only OpenRouter provider (like in server.py:configure_providers) @@ -113,12 +120,24 @@ class TestProviderRoutingBugs: """ # Save original environment original_env = {} - for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]: + for key in [ + "GEMINI_API_KEY", + "OPENAI_API_KEY", + "XAI_API_KEY", + "OPENROUTER_API_KEY", + "OPENROUTER_ALLOWED_MODELS", + ]: original_env[key] = os.environ.get(key) try: # Set up scenario: NO API keys at all - for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", 
"XAI_API_KEY", "OPENROUTER_API_KEY"]: + for key in [ + "GEMINI_API_KEY", + "OPENAI_API_KEY", + "XAI_API_KEY", + "OPENROUTER_API_KEY", + "OPENROUTER_ALLOWED_MODELS", + ]: os.environ.pop(key, None) # Create tool to test fallback logic @@ -151,7 +170,13 @@ class TestProviderRoutingBugs: """ # Save original environment original_env = {} - for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]: + for key in [ + "GEMINI_API_KEY", + "OPENAI_API_KEY", + "XAI_API_KEY", + "OPENROUTER_API_KEY", + "OPENROUTER_ALLOWED_MODELS", + ]: original_env[key] = os.environ.get(key) try: @@ -160,6 +185,7 @@ class TestProviderRoutingBugs: os.environ["OPENAI_API_KEY"] = "test-openai-key" os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key" os.environ.pop("XAI_API_KEY", None) + os.environ.pop("OPENROUTER_ALLOWED_MODELS", None) # Clear any restrictions # Register providers in priority order (like server.py) from providers.gemini import GeminiModelProvider diff --git a/tests/test_workflow_metadata.py b/tests/test_workflow_metadata.py index 7f1e139..d0a9693 100644 --- a/tests/test_workflow_metadata.py +++ b/tests/test_workflow_metadata.py @@ -48,7 +48,13 @@ class TestWorkflowMetadata: """ # Save original environment original_env = {} - for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]: + for key in [ + "GEMINI_API_KEY", + "OPENAI_API_KEY", + "XAI_API_KEY", + "OPENROUTER_API_KEY", + "OPENROUTER_ALLOWED_MODELS", + ]: original_env[key] = os.environ.get(key) try: @@ -56,6 +62,7 @@ class TestWorkflowMetadata: os.environ.pop("GEMINI_API_KEY", None) os.environ.pop("OPENAI_API_KEY", None) os.environ.pop("XAI_API_KEY", None) + os.environ.pop("OPENROUTER_ALLOWED_MODELS", None) # Clear any restrictions os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key" # Register OpenRouter provider @@ -124,7 +131,13 @@ class TestWorkflowMetadata: """ # Save original environment original_env = {} - for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", 
"XAI_API_KEY", "OPENROUTER_API_KEY"]: + for key in [ + "GEMINI_API_KEY", + "OPENAI_API_KEY", + "XAI_API_KEY", + "OPENROUTER_API_KEY", + "OPENROUTER_ALLOWED_MODELS", + ]: original_env[key] = os.environ.get(key) try: @@ -132,6 +145,7 @@ class TestWorkflowMetadata: os.environ.pop("GEMINI_API_KEY", None) os.environ.pop("OPENAI_API_KEY", None) os.environ.pop("XAI_API_KEY", None) + os.environ.pop("OPENROUTER_ALLOWED_MODELS", None) # Clear any restrictions os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key" # Register OpenRouter provider @@ -182,43 +196,60 @@ class TestWorkflowMetadata: """ Test that workflow tools handle metadata gracefully when model context is missing. """ - # Create debug tool - debug_tool = DebugIssueTool() + # Save original environment + original_env = {} + for key in ["OPENROUTER_ALLOWED_MODELS"]: + original_env[key] = os.environ.get(key) - # Create arguments without model context (fallback scenario) - arguments = { - "step": "Test step without model context", - "step_number": 1, - "total_steps": 1, - "next_step_required": False, - "findings": "Test findings", - "model": "flash", - "confidence": "low", - # No _model_context or _resolved_model_name - } + try: + # Clear any restrictions + os.environ.pop("OPENROUTER_ALLOWED_MODELS", None) - # Execute the workflow tool - import asyncio + # Create debug tool + debug_tool = DebugIssueTool() - result = asyncio.run(debug_tool.execute_workflow(arguments)) + # Create arguments without model context (fallback scenario) + arguments = { + "step": "Test step without model context", + "step_number": 1, + "total_steps": 1, + "next_step_required": False, + "findings": "Test findings", + "model": "flash", + "confidence": "low", + # No _model_context or _resolved_model_name + } - # Parse the JSON response - assert len(result) == 1 - response_text = result[0].text - response_data = json.loads(response_text) + # Execute the workflow tool + import asyncio - # Verify metadata is still present with fallback values - 
assert "metadata" in response_data, "Workflow response should include metadata even in fallback" - metadata = response_data["metadata"] + result = asyncio.run(debug_tool.execute_workflow(arguments)) - # Verify fallback metadata - assert "tool_name" in metadata, "Fallback metadata should include tool_name" - assert "model_used" in metadata, "Fallback metadata should include model_used" - assert "provider_used" in metadata, "Fallback metadata should include provider_used" + # Parse the JSON response + assert len(result) == 1 + response_text = result[0].text + response_data = json.loads(response_text) - assert metadata["tool_name"] == "debug", "tool_name should be 'debug'" - assert metadata["model_used"] == "flash", "model_used should be from request" - assert metadata["provider_used"] == "unknown", "provider_used should be 'unknown' in fallback" + # Verify metadata is still present with fallback values + assert "metadata" in response_data, "Workflow response should include metadata even in fallback" + metadata = response_data["metadata"] + + # Verify fallback metadata + assert "tool_name" in metadata, "Fallback metadata should include tool_name" + assert "model_used" in metadata, "Fallback metadata should include model_used" + assert "provider_used" in metadata, "Fallback metadata should include provider_used" + + assert metadata["tool_name"] == "debug", "tool_name should be 'debug'" + assert metadata["model_used"] == "flash", "model_used should be from request" + assert metadata["provider_used"] == "unknown", "provider_used should be 'unknown' in fallback" + + finally: + # Restore original environment + for key, value in original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value @pytest.mark.no_mock_provider def test_workflow_metadata_preserves_existing_response_fields(self): @@ -227,7 +258,13 @@ class TestWorkflowMetadata: """ # Save original environment original_env = {} - for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", 
"XAI_API_KEY", "OPENROUTER_API_KEY"]: + for key in [ + "GEMINI_API_KEY", + "OPENAI_API_KEY", + "XAI_API_KEY", + "OPENROUTER_API_KEY", + "OPENROUTER_ALLOWED_MODELS", + ]: original_env[key] = os.environ.get(key) try: @@ -235,6 +272,7 @@ class TestWorkflowMetadata: os.environ.pop("GEMINI_API_KEY", None) os.environ.pop("OPENAI_API_KEY", None) os.environ.pop("XAI_API_KEY", None) + os.environ.pop("OPENROUTER_ALLOWED_MODELS", None) # Clear any restrictions os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key" # Register OpenRouter provider diff --git a/utils/model_restrictions.py b/utils/model_restrictions.py index 834c0a2..b10544a 100644 --- a/utils/model_restrictions.py +++ b/utils/model_restrictions.py @@ -128,6 +128,10 @@ class ModelRestrictionService: allowed_set = self.restrictions[provider_type] + if len(allowed_set) == 0: + # Empty set - allowed + return True + # Check both the resolved name and original name (if different) names_to_check = {model_name.lower()} if original_name and original_name.lower() != model_name.lower(): From 498ea88293fea1b5f3f463ad9878aa1e6b516bc3 Mon Sep 17 00:00:00 2001 From: Fahad Date: Mon, 23 Jun 2025 16:58:59 +0400 Subject: [PATCH 2/8] Use ModelCapabilities consistently instead of dictionaries Moved aliases as part of SUPPORTED_MODELS instead of shorthand, more in line with how custom_models are declared Further refactoring to cleanup some code --- config.py | 2 +- providers/base.py | 134 +++++++++-- providers/custom.py | 74 +++--- providers/dial.py | 301 +++++++++++++------------ providers/gemini.py | 213 ++++++++--------- providers/openai_provider.py | 252 +++++++++------------ providers/openrouter.py | 36 +++ providers/xai.py | 134 ++++------- tests/test_auto_mode.py | 2 +- tests/test_auto_mode_comprehensive.py | 13 +- tests/test_dial_provider.py | 2 +- tests/test_openai_provider.py | 6 +- tests/test_supported_models_aliases.py | 206 +++++++++++++++++ tests/test_xai_provider.py | 37 +-- tools/listmodels.py | 25 +- 
tools/shared/base_tool.py | 18 +- 16 files changed, 850 insertions(+), 605 deletions(-) create mode 100644 tests/test_supported_models_aliases.py diff --git a/config.py b/config.py index d2005c1..5e8667a 100644 --- a/config.py +++ b/config.py @@ -14,7 +14,7 @@ import os # These values are used in server responses and for tracking releases # IMPORTANT: This is the single source of truth for version and author info # Semantic versioning: MAJOR.MINOR.PATCH -__version__ = "5.6.2" +__version__ = "5.7.0" # Last update date in ISO format __updated__ = "2025-06-23" # Primary maintainer diff --git a/providers/base.py b/providers/base.py index c8b1ec7..06f60fe 100644 --- a/providers/base.py +++ b/providers/base.py @@ -140,6 +140,19 @@ class ModelCapabilities: max_image_size_mb: float = 0.0 # Maximum total size for all images in MB supports_temperature: bool = True # Whether model accepts temperature parameter in API calls + # Additional fields for comprehensive model information + description: str = "" # Human-readable description of the model + aliases: list[str] = field(default_factory=list) # Alternative names/shortcuts for the model + + # JSON mode support (for providers that support structured output) + supports_json_mode: bool = False + + # Thinking mode support (for models with thinking capabilities) + max_thinking_tokens: int = 0 # Maximum thinking tokens for extended reasoning models + + # Custom model flag (for models that only work with custom endpoints) + is_custom: bool = False # Whether this model requires custom API endpoints + # Temperature constraint object - preferred way to define temperature limits temperature_constraint: TemperatureConstraint = field( default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.7) @@ -251,7 +264,7 @@ class ModelProvider(ABC): capabilities = self.get_capabilities(model_name) # Check if model supports temperature at all - if hasattr(capabilities, "supports_temperature") and not capabilities.supports_temperature: + if not 
capabilities.supports_temperature: return None # Get temperature range @@ -290,19 +303,109 @@ class ModelProvider(ABC): """Check if the model supports extended thinking mode.""" pass - @abstractmethod + def get_model_configurations(self) -> dict[str, ModelCapabilities]: + """Get model configurations for this provider. + + This is a hook method that subclasses can override to provide + their model configurations from different sources. + + Returns: + Dictionary mapping model names to their ModelCapabilities objects + """ + # Return SUPPORTED_MODELS if it exists (must contain ModelCapabilities objects) + if hasattr(self, "SUPPORTED_MODELS"): + return {k: v for k, v in self.SUPPORTED_MODELS.items() if isinstance(v, ModelCapabilities)} + return {} + + def get_all_model_aliases(self) -> dict[str, list[str]]: + """Get all model aliases for this provider. + + This is a hook method that subclasses can override to provide + aliases from different sources. + + Returns: + Dictionary mapping model names to their list of aliases + """ + # Default implementation extracts from ModelCapabilities objects + aliases = {} + for model_name, capabilities in self.get_model_configurations().items(): + if capabilities.aliases: + aliases[model_name] = capabilities.aliases + return aliases + + def _resolve_model_name(self, model_name: str) -> str: + """Resolve model shorthand to full name. + + This implementation uses the hook methods to support different + model configuration sources. 
+ + Args: + model_name: Model name that may be an alias + + Returns: + Resolved model name + """ + # Get model configurations from the hook method + model_configs = self.get_model_configurations() + + # First check if it's already a base model name (case-sensitive exact match) + if model_name in model_configs: + return model_name + + # Check case-insensitively for both base models and aliases + model_name_lower = model_name.lower() + + # Check base model names case-insensitively + for base_model in model_configs: + if base_model.lower() == model_name_lower: + return base_model + + # Check aliases from the hook method + all_aliases = self.get_all_model_aliases() + for base_model, aliases in all_aliases.items(): + if any(alias.lower() == model_name_lower for alias in aliases): + return base_model + + # If not found, return as-is + return model_name + def list_models(self, respect_restrictions: bool = True) -> list[str]: """Return a list of model names supported by this provider. + This implementation uses the get_model_configurations() hook + to support different model configuration sources. + Args: respect_restrictions: Whether to apply provider-specific restriction logic. 
Returns: List of model names available from this provider """ - pass + from utils.model_restrictions import get_restriction_service + + restriction_service = get_restriction_service() if respect_restrictions else None + models = [] + + # Get model configurations from the hook method + model_configs = self.get_model_configurations() + + for model_name in model_configs: + # Check restrictions if enabled + if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name): + continue + + # Add the base model + models.append(model_name) + + # Get aliases from the hook method + all_aliases = self.get_all_model_aliases() + for model_name, aliases in all_aliases.items(): + # Only add aliases for models that passed restriction check + if model_name in models: + models.extend(aliases) + + return models - @abstractmethod def list_all_known_models(self) -> list[str]: """Return all model names known by this provider, including alias targets. @@ -312,21 +415,22 @@ class ModelProvider(ABC): Returns: List of all model names and alias targets known by this provider """ - pass + all_models = set() - def _resolve_model_name(self, model_name: str) -> str: - """Resolve model shorthand to full name. + # Get model configurations from the hook method + model_configs = self.get_model_configurations() - Base implementation returns the model name unchanged. - Subclasses should override to provide alias resolution. + # Add all base model names + for model_name in model_configs: + all_models.add(model_name.lower()) - Args: - model_name: Model name that may be an alias + # Get aliases from the hook method and add them + all_aliases = self.get_all_model_aliases() + for _model_name, aliases in all_aliases.items(): + for alias in aliases: + all_models.add(alias.lower()) - Returns: - Resolved model name - """ - return model_name + return list(all_models) def close(self): """Clean up any resources held by the provider. 
diff --git a/providers/custom.py b/providers/custom.py index bad1062..52d9b94 100644 --- a/providers/custom.py +++ b/providers/custom.py @@ -268,65 +268,55 @@ class CustomProvider(OpenAICompatibleProvider): def supports_thinking_mode(self, model_name: str) -> bool: """Check if the model supports extended thinking mode. - Most custom/local models don't support extended thinking. - Args: model_name: Model to check Returns: - False (custom models generally don't support thinking mode) + True if model supports thinking mode, False otherwise """ + # Check if model is in registry + config = self._registry.resolve(model_name) if self._registry else None + if config and config.is_custom: + # Trust the config from custom_models.json + return config.supports_extended_thinking + + # Default to False for unknown models return False - def list_models(self, respect_restrictions: bool = True) -> list[str]: - """Return a list of model names supported by this provider. + def get_model_configurations(self) -> dict[str, ModelCapabilities]: + """Get model configurations from the registry. - Args: - respect_restrictions: Whether to apply provider-specific restriction logic. + For CustomProvider, we convert registry configurations to ModelCapabilities objects. 
Returns: - List of model names available from this provider + Dictionary mapping model names to their ModelCapabilities objects """ - from utils.model_restrictions import get_restriction_service + from .base import ProviderType - restriction_service = get_restriction_service() if respect_restrictions else None - models = [] + configs = {} if self._registry: - # Get all models from the registry - all_models = self._registry.list_models() - aliases = self._registry.list_aliases() - - # Add models that are validated by the custom provider - for model_name in all_models + aliases: - # Use the provider's validation logic to determine if this model - # is appropriate for the custom endpoint + # Get all models from registry + for model_name in self._registry.list_models(): + # Only include custom models that this provider validates if self.validate_model_name(model_name): - # Check restrictions if enabled - if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name): - continue + config = self._registry.resolve(model_name) + if config and config.is_custom: + # Convert OpenRouterModelConfig to ModelCapabilities + capabilities = config.to_capabilities() + # Override provider type to CUSTOM for local models + capabilities.provider = ProviderType.CUSTOM + capabilities.friendly_name = f"{self.FRIENDLY_NAME} ({config.model_name})" + configs[model_name] = capabilities - models.append(model_name) + return configs - return models - - def list_all_known_models(self) -> list[str]: - """Return all model names known by this provider, including alias targets. + def get_all_model_aliases(self) -> dict[str, list[str]]: + """Get all model aliases from the registry. 
Returns: - List of all model names and alias targets known by this provider + Dictionary mapping model names to their list of aliases """ - all_models = set() - - if self._registry: - # Get all models and aliases from the registry - all_models.update(model.lower() for model in self._registry.list_models()) - all_models.update(alias.lower() for alias in self._registry.list_aliases()) - - # For each alias, also add its target - for alias in self._registry.list_aliases(): - config = self._registry.resolve(alias) - if config: - all_models.add(config.model_name.lower()) - - return list(all_models) + # Since aliases are now included in the configurations, + # we can use the base class implementation + return super().get_all_model_aliases() diff --git a/providers/dial.py b/providers/dial.py index 617858c..f019415 100644 --- a/providers/dial.py +++ b/providers/dial.py @@ -10,7 +10,7 @@ from .base import ( ModelCapabilities, ModelResponse, ProviderType, - RangeTemperatureConstraint, + create_temperature_constraint, ) from .openai_compatible import OpenAICompatibleProvider @@ -30,63 +30,161 @@ class DIALModelProvider(OpenAICompatibleProvider): MAX_RETRIES = 4 RETRY_DELAYS = [1, 3, 5, 8] # seconds - # Supported DIAL models (these can be customized based on your DIAL deployment) + # Model configurations using ModelCapabilities objects SUPPORTED_MODELS = { - "o3-2025-04-16": { - "context_window": 200_000, - "supports_extended_thinking": False, - "supports_vision": True, - }, - "o4-mini-2025-04-16": { - "context_window": 200_000, - "supports_extended_thinking": False, - "supports_vision": True, - }, - "anthropic.claude-sonnet-4-20250514-v1:0": { - "context_window": 200_000, - "supports_extended_thinking": False, - "supports_vision": True, - }, - "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": { - "context_window": 200_000, - "supports_extended_thinking": True, # Thinking mode variant - "supports_vision": True, - }, - "anthropic.claude-opus-4-20250514-v1:0": { - 
"context_window": 200_000, - "supports_extended_thinking": False, - "supports_vision": True, - }, - "anthropic.claude-opus-4-20250514-v1:0-with-thinking": { - "context_window": 200_000, - "supports_extended_thinking": True, # Thinking mode variant - "supports_vision": True, - }, - "gemini-2.5-pro-preview-03-25-google-search": { - "context_window": 1_000_000, - "supports_extended_thinking": False, # DIAL doesn't expose thinking mode - "supports_vision": True, - }, - "gemini-2.5-pro-preview-05-06": { - "context_window": 1_000_000, - "supports_extended_thinking": False, - "supports_vision": True, - }, - "gemini-2.5-flash-preview-05-20": { - "context_window": 1_000_000, - "supports_extended_thinking": False, - "supports_vision": True, - }, - # Shorthands - "o3": "o3-2025-04-16", - "o4-mini": "o4-mini-2025-04-16", - "sonnet-4": "anthropic.claude-sonnet-4-20250514-v1:0", - "sonnet-4-thinking": "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking", - "opus-4": "anthropic.claude-opus-4-20250514-v1:0", - "opus-4-thinking": "anthropic.claude-opus-4-20250514-v1:0-with-thinking", - "gemini-2.5-pro": "gemini-2.5-pro-preview-05-06", - "gemini-2.5-pro-search": "gemini-2.5-pro-preview-03-25-google-search", - "gemini-2.5-flash": "gemini-2.5-flash-preview-05-20", + "o3-2025-04-16": ModelCapabilities( + provider=ProviderType.DIAL, + model_name="o3-2025-04-16", + friendly_name="DIAL (O3)", + context_window=200_000, + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, # DIAL may not expose function calling + supports_json_mode=True, + supports_images=True, + max_image_size_mb=20.0, + supports_temperature=False, # O3 models don't accept temperature + temperature_constraint=create_temperature_constraint("fixed"), + description="OpenAI O3 via DIAL - Strong reasoning model", + aliases=["o3"], + ), + "o4-mini-2025-04-16": ModelCapabilities( + provider=ProviderType.DIAL, + model_name="o4-mini-2025-04-16", + 
friendly_name="DIAL (O4-mini)", + context_window=200_000, + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, # DIAL may not expose function calling + supports_json_mode=True, + supports_images=True, + max_image_size_mb=20.0, + supports_temperature=False, # O4 models don't accept temperature + temperature_constraint=create_temperature_constraint("fixed"), + description="OpenAI O4-mini via DIAL - Fast reasoning model", + aliases=["o4-mini"], + ), + "anthropic.claude-sonnet-4-20250514-v1:0": ModelCapabilities( + provider=ProviderType.DIAL, + model_name="anthropic.claude-sonnet-4-20250514-v1:0", + friendly_name="DIAL (Sonnet 4)", + context_window=200_000, + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, # Claude doesn't have function calling + supports_json_mode=False, # Claude doesn't have JSON mode + supports_images=True, + max_image_size_mb=5.0, + supports_temperature=True, + temperature_constraint=create_temperature_constraint("range"), + description="Claude Sonnet 4 via DIAL - Balanced performance", + aliases=["sonnet-4"], + ), + "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": ModelCapabilities( + provider=ProviderType.DIAL, + model_name="anthropic.claude-sonnet-4-20250514-v1:0-with-thinking", + friendly_name="DIAL (Sonnet 4 Thinking)", + context_window=200_000, + supports_extended_thinking=True, # Thinking mode variant + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, # Claude doesn't have function calling + supports_json_mode=False, # Claude doesn't have JSON mode + supports_images=True, + max_image_size_mb=5.0, + supports_temperature=True, + temperature_constraint=create_temperature_constraint("range"), + description="Claude Sonnet 4 with thinking mode via DIAL", + aliases=["sonnet-4-thinking"], + ), + "anthropic.claude-opus-4-20250514-v1:0": 
ModelCapabilities( + provider=ProviderType.DIAL, + model_name="anthropic.claude-opus-4-20250514-v1:0", + friendly_name="DIAL (Opus 4)", + context_window=200_000, + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, # Claude doesn't have function calling + supports_json_mode=False, # Claude doesn't have JSON mode + supports_images=True, + max_image_size_mb=5.0, + supports_temperature=True, + temperature_constraint=create_temperature_constraint("range"), + description="Claude Opus 4 via DIAL - Most capable Claude model", + aliases=["opus-4"], + ), + "anthropic.claude-opus-4-20250514-v1:0-with-thinking": ModelCapabilities( + provider=ProviderType.DIAL, + model_name="anthropic.claude-opus-4-20250514-v1:0-with-thinking", + friendly_name="DIAL (Opus 4 Thinking)", + context_window=200_000, + supports_extended_thinking=True, # Thinking mode variant + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, # Claude doesn't have function calling + supports_json_mode=False, # Claude doesn't have JSON mode + supports_images=True, + max_image_size_mb=5.0, + supports_temperature=True, + temperature_constraint=create_temperature_constraint("range"), + description="Claude Opus 4 with thinking mode via DIAL", + aliases=["opus-4-thinking"], + ), + "gemini-2.5-pro-preview-03-25-google-search": ModelCapabilities( + provider=ProviderType.DIAL, + model_name="gemini-2.5-pro-preview-03-25-google-search", + friendly_name="DIAL (Gemini 2.5 Pro Search)", + context_window=1_000_000, + supports_extended_thinking=False, # DIAL doesn't expose thinking mode + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, # DIAL may not expose function calling + supports_json_mode=True, + supports_images=True, + max_image_size_mb=20.0, + supports_temperature=True, + temperature_constraint=create_temperature_constraint("range"), + description="Gemini 2.5 Pro 
with Google Search via DIAL", + aliases=["gemini-2.5-pro-search"], + ), + "gemini-2.5-pro-preview-05-06": ModelCapabilities( + provider=ProviderType.DIAL, + model_name="gemini-2.5-pro-preview-05-06", + friendly_name="DIAL (Gemini 2.5 Pro)", + context_window=1_000_000, + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, # DIAL may not expose function calling + supports_json_mode=True, + supports_images=True, + max_image_size_mb=20.0, + supports_temperature=True, + temperature_constraint=create_temperature_constraint("range"), + description="Gemini 2.5 Pro via DIAL - Deep reasoning", + aliases=["gemini-2.5-pro"], + ), + "gemini-2.5-flash-preview-05-20": ModelCapabilities( + provider=ProviderType.DIAL, + model_name="gemini-2.5-flash-preview-05-20", + friendly_name="DIAL (Gemini Flash 2.5)", + context_window=1_000_000, + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, # DIAL may not expose function calling + supports_json_mode=True, + supports_images=True, + max_image_size_mb=20.0, + supports_temperature=True, + temperature_constraint=create_temperature_constraint("range"), + description="Gemini 2.5 Flash via DIAL - Ultra-fast", + aliases=["gemini-2.5-flash"], + ), } def __init__(self, api_key: str, **kwargs): @@ -181,20 +279,8 @@ class DIALModelProvider(OpenAICompatibleProvider): if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name): raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.") - config = self.SUPPORTED_MODELS[resolved_name] - - return ModelCapabilities( - provider=ProviderType.DIAL, - model_name=resolved_name, - friendly_name=self.FRIENDLY_NAME, - context_window=config["context_window"], - supports_extended_thinking=config["supports_extended_thinking"], - supports_system_prompts=True, - supports_streaming=True, - supports_function_calling=True, - 
supports_images=config.get("supports_vision", False), - temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7), - ) + # Return the ModelCapabilities object directly from SUPPORTED_MODELS + return self.SUPPORTED_MODELS[resolved_name] def get_provider_type(self) -> ProviderType: """Get the provider type.""" @@ -211,7 +297,7 @@ class DIALModelProvider(OpenAICompatibleProvider): """ resolved_name = self._resolve_model_name(model_name) - if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict): + if resolved_name not in self.SUPPORTED_MODELS: return False # Check against base class allowed_models if configured @@ -231,20 +317,6 @@ class DIALModelProvider(OpenAICompatibleProvider): return True - def _resolve_model_name(self, model_name: str) -> str: - """Resolve model shorthand to full name. - - Args: - model_name: Model name or shorthand - - Returns: - Full model name - """ - shorthand_value = self.SUPPORTED_MODELS.get(model_name) - if isinstance(shorthand_value, str): - return shorthand_value - return model_name - def _get_deployment_client(self, deployment: str): """Get or create a cached client for a specific deployment. 
@@ -357,7 +429,7 @@ class DIALModelProvider(OpenAICompatibleProvider): # Check model capabilities try: capabilities = self.get_capabilities(model_name) - supports_temperature = getattr(capabilities, "supports_temperature", True) + supports_temperature = capabilities.supports_temperature except Exception as e: logger.debug(f"Failed to check temperature support for {model_name}: {e}") supports_temperature = True @@ -441,63 +513,12 @@ class DIALModelProvider(OpenAICompatibleProvider): """ resolved_name = self._resolve_model_name(model_name) - if resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict): - return self.SUPPORTED_MODELS[resolved_name].get("supports_vision", False) + if resolved_name in self.SUPPORTED_MODELS: + return self.SUPPORTED_MODELS[resolved_name].supports_images # Fall back to parent implementation for unknown models return super()._supports_vision(model_name) - def list_models(self, respect_restrictions: bool = True) -> list[str]: - """Return a list of model names supported by this provider. - - Args: - respect_restrictions: Whether to apply provider-specific restriction logic. - - Returns: - List of model names available from this provider - """ - # Get all model keys (both full names and aliases) - all_models = list(self.SUPPORTED_MODELS.keys()) - - if not respect_restrictions: - return all_models - - # Apply restrictions if configured - from utils.model_restrictions import get_restriction_service - - restriction_service = get_restriction_service() - - # Filter based on restrictions - allowed_models = [] - for model in all_models: - resolved_name = self._resolve_model_name(model) - if restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model): - allowed_models.append(model) - - return allowed_models - - def list_all_known_models(self) -> list[str]: - """Return all model names known by this provider, including alias targets. 
- - This is used for validation purposes to ensure restriction policies - can validate against both aliases and their target model names. - - Returns: - List of all model names and alias targets known by this provider - """ - # Collect all unique model names (both aliases and targets) - all_models = set() - - for key, value in self.SUPPORTED_MODELS.items(): - # Add the key (could be alias or full name) - all_models.add(key) - - # If it's an alias (string value), add the target too - if isinstance(value, str): - all_models.add(value) - - return sorted(all_models) - def close(self): """Clean up HTTP clients when provider is closed.""" logger.info("Closing DIAL provider HTTP clients...") diff --git a/providers/gemini.py b/providers/gemini.py index 074232f..1118699 100644 --- a/providers/gemini.py +++ b/providers/gemini.py @@ -9,7 +9,7 @@ from typing import Optional from google import genai from google.genai import types -from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, RangeTemperatureConstraint +from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, create_temperature_constraint logger = logging.getLogger(__name__) @@ -17,47 +17,79 @@ logger = logging.getLogger(__name__) class GeminiModelProvider(ModelProvider): """Google Gemini model provider implementation.""" - # Model configurations + # Model configurations using ModelCapabilities objects SUPPORTED_MODELS = { - "gemini-2.0-flash": { - "context_window": 1_048_576, # 1M tokens - "supports_extended_thinking": True, # Experimental thinking mode - "max_thinking_tokens": 24576, # Same as 2.5 flash for consistency - "supports_images": True, # Vision capability - "max_image_size_mb": 20.0, # Conservative 20MB limit for reliability - "description": "Gemini 2.0 Flash (1M context) - Latest fast model with experimental thinking, supports audio/video input", - }, - "gemini-2.0-flash-lite": { - "context_window": 1_048_576, # 1M tokens - "supports_extended_thinking": 
False, # Not supported per user request - "max_thinking_tokens": 0, # No thinking support - "supports_images": False, # Does not support images - "max_image_size_mb": 0.0, # No image support - "description": "Gemini 2.0 Flash Lite (1M context) - Lightweight fast model, text-only", - }, - "gemini-2.5-flash": { - "context_window": 1_048_576, # 1M tokens - "supports_extended_thinking": True, - "max_thinking_tokens": 24576, # Flash 2.5 thinking budget limit - "supports_images": True, # Vision capability - "max_image_size_mb": 20.0, # Conservative 20MB limit for reliability - "description": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations", - }, - "gemini-2.5-pro": { - "context_window": 1_048_576, # 1M tokens - "supports_extended_thinking": True, - "max_thinking_tokens": 32768, # Pro 2.5 thinking budget limit - "supports_images": True, # Vision capability - "max_image_size_mb": 32.0, # Higher limit for Pro model - "description": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis", - }, - # Shorthands - "flash": "gemini-2.5-flash", - "flash-2.0": "gemini-2.0-flash", - "flash2": "gemini-2.0-flash", - "flashlite": "gemini-2.0-flash-lite", - "flash-lite": "gemini-2.0-flash-lite", - "pro": "gemini-2.5-pro", + "gemini-2.0-flash": ModelCapabilities( + provider=ProviderType.GOOGLE, + model_name="gemini-2.0-flash", + friendly_name="Gemini (Flash 2.0)", + context_window=1_048_576, # 1M tokens + supports_extended_thinking=True, # Experimental thinking mode + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # Vision capability + max_image_size_mb=20.0, # Conservative 20MB limit for reliability + supports_temperature=True, + temperature_constraint=create_temperature_constraint("range"), + max_thinking_tokens=24576, # Same as 2.5 flash for consistency + description="Gemini 2.0 Flash (1M context) - Latest fast model with 
experimental thinking, supports audio/video input", + aliases=["flash-2.0", "flash2"], + ), + "gemini-2.0-flash-lite": ModelCapabilities( + provider=ProviderType.GOOGLE, + model_name="gemini-2.0-flash-lite", + friendly_name="Gemin (Flash Lite 2.0)", + context_window=1_048_576, # 1M tokens + supports_extended_thinking=False, # Not supported per user request + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=False, # Does not support images + max_image_size_mb=0.0, # No image support + supports_temperature=True, + temperature_constraint=create_temperature_constraint("range"), + description="Gemini 2.0 Flash Lite (1M context) - Lightweight fast model, text-only", + aliases=["flashlite", "flash-lite"], + ), + "gemini-2.5-flash": ModelCapabilities( + provider=ProviderType.GOOGLE, + model_name="gemini-2.5-flash", + friendly_name="Gemini (Flash 2.5)", + context_window=1_048_576, # 1M tokens + supports_extended_thinking=True, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # Vision capability + max_image_size_mb=20.0, # Conservative 20MB limit for reliability + supports_temperature=True, + temperature_constraint=create_temperature_constraint("range"), + max_thinking_tokens=24576, # Flash 2.5 thinking budget limit + description="Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations", + aliases=["flash", "flash2.5"], + ), + "gemini-2.5-pro": ModelCapabilities( + provider=ProviderType.GOOGLE, + model_name="gemini-2.5-pro", + friendly_name="Gemini (Pro 2.5)", + context_window=1_048_576, # 1M tokens + supports_extended_thinking=True, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # Vision capability + max_image_size_mb=32.0, # Higher limit for Pro model + supports_temperature=True, 
+ temperature_constraint=create_temperature_constraint("range"), + max_thinking_tokens=32768, # Max thinking tokens for Pro model + description="Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis", + aliases=["pro", "gemini pro", "gemini-pro"], + ), } # Thinking mode configurations - percentages of model's max_thinking_tokens @@ -70,6 +102,14 @@ class GeminiModelProvider(ModelProvider): "max": 1.0, # 100% of max - full thinking budget } + # Model-specific thinking token limits + MAX_THINKING_TOKENS = { + "gemini-2.0-flash": 24576, # Same as 2.5 flash for consistency + "gemini-2.0-flash-lite": 0, # No thinking support + "gemini-2.5-flash": 24576, # Flash 2.5 thinking budget limit + "gemini-2.5-pro": 32768, # Pro 2.5 thinking budget limit + } + def __init__(self, api_key: str, **kwargs): """Initialize Gemini provider with API key.""" super().__init__(api_key, **kwargs) @@ -100,25 +140,8 @@ class GeminiModelProvider(ModelProvider): if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name): raise ValueError(f"Gemini model '{resolved_name}' is not allowed by restriction policy.") - config = self.SUPPORTED_MODELS[resolved_name] - - # Gemini models support 0.0-2.0 temperature range - temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7) - - return ModelCapabilities( - provider=ProviderType.GOOGLE, - model_name=resolved_name, - friendly_name="Gemini", - context_window=config["context_window"], - supports_extended_thinking=config["supports_extended_thinking"], - supports_system_prompts=True, - supports_streaming=True, - supports_function_calling=True, - supports_images=config.get("supports_images", False), - max_image_size_mb=config.get("max_image_size_mb", 0.0), - supports_temperature=True, # Gemini models accept temperature parameter - temperature_constraint=temp_constraint, - ) + # Return the ModelCapabilities object directly from SUPPORTED_MODELS + return self.SUPPORTED_MODELS[resolved_name] def 
generate_content( self, @@ -179,8 +202,8 @@ class GeminiModelProvider(ModelProvider): if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS: # Get model's max thinking tokens and calculate actual budget model_config = self.SUPPORTED_MODELS.get(resolved_name) - if model_config and "max_thinking_tokens" in model_config: - max_thinking_tokens = model_config["max_thinking_tokens"] + if model_config and model_config.max_thinking_tokens > 0: + max_thinking_tokens = model_config.max_thinking_tokens actual_thinking_budget = int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode]) generation_config.thinking_config = types.ThinkingConfig(thinking_budget=actual_thinking_budget) @@ -258,7 +281,7 @@ class GeminiModelProvider(ModelProvider): resolved_name = self._resolve_model_name(model_name) # First check if model is supported - if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict): + if resolved_name not in self.SUPPORTED_MODELS: return False # Then check if model is allowed by restrictions @@ -281,78 +304,20 @@ class GeminiModelProvider(ModelProvider): def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int: """Get actual thinking token budget for a model and thinking mode.""" resolved_name = self._resolve_model_name(model_name) - model_config = self.SUPPORTED_MODELS.get(resolved_name, {}) + model_config = self.SUPPORTED_MODELS.get(resolved_name) - if not model_config.get("supports_extended_thinking", False): + if not model_config or not model_config.supports_extended_thinking: return 0 if thinking_mode not in self.THINKING_BUDGETS: return 0 - max_thinking_tokens = model_config.get("max_thinking_tokens", 0) + max_thinking_tokens = model_config.max_thinking_tokens if max_thinking_tokens == 0: return 0 return int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode]) - def list_models(self, respect_restrictions: bool = True) -> list[str]: - """Return a list of 
model names supported by this provider. - - Args: - respect_restrictions: Whether to apply provider-specific restriction logic. - - Returns: - List of model names available from this provider - """ - from utils.model_restrictions import get_restriction_service - - restriction_service = get_restriction_service() if respect_restrictions else None - models = [] - - for model_name, config in self.SUPPORTED_MODELS.items(): - # Handle both base models (dict configs) and aliases (string values) - if isinstance(config, str): - # This is an alias - check if the target model would be allowed - target_model = config - if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model): - continue - # Allow the alias - models.append(model_name) - else: - # This is a base model with config dict - # Check restrictions if enabled - if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name): - continue - models.append(model_name) - - return models - - def list_all_known_models(self) -> list[str]: - """Return all model names known by this provider, including alias targets. 
- - Returns: - List of all model names and alias targets known by this provider - """ - all_models = set() - - for model_name, config in self.SUPPORTED_MODELS.items(): - # Add the model name itself - all_models.add(model_name.lower()) - - # If it's an alias (string value), add the target model too - if isinstance(config, str): - all_models.add(config.lower()) - - return list(all_models) - - def _resolve_model_name(self, model_name: str) -> str: - """Resolve model shorthand to full name.""" - # Check if it's a shorthand - shorthand_value = self.SUPPORTED_MODELS.get(model_name.lower()) - if isinstance(shorthand_value, str): - return shorthand_value - return model_name - def _extract_usage(self, response) -> dict[str, int]: """Extract token usage from Gemini response.""" usage = {} diff --git a/providers/openai_provider.py b/providers/openai_provider.py index 3553673..e065ee1 100644 --- a/providers/openai_provider.py +++ b/providers/openai_provider.py @@ -17,71 +17,110 @@ logger = logging.getLogger(__name__) class OpenAIModelProvider(OpenAICompatibleProvider): """Official OpenAI API provider (api.openai.com).""" - # Model configurations + # Model configurations using ModelCapabilities objects SUPPORTED_MODELS = { - "o3": { - "context_window": 200_000, # 200K tokens - "supports_extended_thinking": False, - "supports_images": True, # O3 models support vision - "max_image_size_mb": 20.0, # 20MB per OpenAI docs - "supports_temperature": False, # O3 models don't accept temperature parameter - "temperature_constraint": "fixed", # Fixed at 1.0 - "description": "Strong reasoning (200K context) - Logical problems, code generation, systematic analysis", - }, - "o3-mini": { - "context_window": 200_000, # 200K tokens - "supports_extended_thinking": False, - "supports_images": True, # O3 models support vision - "max_image_size_mb": 20.0, # 20MB per OpenAI docs - "supports_temperature": False, # O3 models don't accept temperature parameter - "temperature_constraint": "fixed", # 
Fixed at 1.0 - "description": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity", - }, - "o3-pro-2025-06-10": { - "context_window": 200_000, # 200K tokens - "supports_extended_thinking": False, - "supports_images": True, # O3 models support vision - "max_image_size_mb": 20.0, # 20MB per OpenAI docs - "supports_temperature": False, # O3 models don't accept temperature parameter - "temperature_constraint": "fixed", # Fixed at 1.0 - "description": "Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.", - }, - # Aliases - "o3-pro": "o3-pro-2025-06-10", - "o4-mini": { - "context_window": 200_000, # 200K tokens - "supports_extended_thinking": False, - "supports_images": True, # O4 models support vision - "max_image_size_mb": 20.0, # 20MB per OpenAI docs - "supports_temperature": False, # O4 models don't accept temperature parameter - "temperature_constraint": "fixed", # Fixed at 1.0 - "description": "Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning", - }, - "o4-mini-high": { - "context_window": 200_000, # 200K tokens - "supports_extended_thinking": False, - "supports_images": True, # O4 models support vision - "max_image_size_mb": 20.0, # 20MB per OpenAI docs - "supports_temperature": False, # O4 models don't accept temperature parameter - "temperature_constraint": "fixed", # Fixed at 1.0 - "description": "Enhanced O4 mini (200K context) - Higher reasoning effort for complex tasks", - }, - "gpt-4.1-2025-04-14": { - "context_window": 1_000_000, # 1M tokens - "supports_extended_thinking": False, - "supports_images": True, # GPT-4.1 supports vision - "max_image_size_mb": 20.0, # 20MB per OpenAI docs - "supports_temperature": True, # Regular models 
accept temperature parameter - "temperature_constraint": "range", # 0.0-2.0 range - "description": "GPT-4.1 (1M context) - Advanced reasoning model with large context window", - }, - # Shorthands - "mini": "o4-mini", # Default 'mini' to latest mini model - "o3mini": "o3-mini", - "o4mini": "o4-mini", - "o4minihigh": "o4-mini-high", - "o4minihi": "o4-mini-high", - "gpt4.1": "gpt-4.1-2025-04-14", + "o3": ModelCapabilities( + provider=ProviderType.OPENAI, + model_name="o3", + friendly_name="OpenAI (O3)", + context_window=200_000, # 200K tokens + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # O3 models support vision + max_image_size_mb=20.0, # 20MB per OpenAI docs + supports_temperature=False, # O3 models don't accept temperature parameter + temperature_constraint=create_temperature_constraint("fixed"), + description="Strong reasoning (200K context) - Logical problems, code generation, systematic analysis", + aliases=[], + ), + "o3-mini": ModelCapabilities( + provider=ProviderType.OPENAI, + model_name="o3-mini", + friendly_name="OpenAI (O3-mini)", + context_window=200_000, # 200K tokens + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # O3 models support vision + max_image_size_mb=20.0, # 20MB per OpenAI docs + supports_temperature=False, # O3 models don't accept temperature parameter + temperature_constraint=create_temperature_constraint("fixed"), + description="Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity", + aliases=["o3mini", "o3-mini"], + ), + "o3-pro-2025-06-10": ModelCapabilities( + provider=ProviderType.OPENAI, + model_name="o3-pro-2025-06-10", + friendly_name="OpenAI (O3-Pro)", + context_window=200_000, # 200K tokens + supports_extended_thinking=False, + 
supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # O3 models support vision + max_image_size_mb=20.0, # 20MB per OpenAI docs + supports_temperature=False, # O3 models don't accept temperature parameter + temperature_constraint=create_temperature_constraint("fixed"), + description="Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.", + aliases=["o3-pro"], + ), + "o4-mini": ModelCapabilities( + provider=ProviderType.OPENAI, + model_name="o4-mini", + friendly_name="OpenAI (O4-mini)", + context_window=200_000, # 200K tokens + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # O4 models support vision + max_image_size_mb=20.0, # 20MB per OpenAI docs + supports_temperature=False, # O4 models don't accept temperature parameter + temperature_constraint=create_temperature_constraint("fixed"), + description="Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning", + aliases=["mini", "o4mini"], + ), + "o4-mini-high": ModelCapabilities( + provider=ProviderType.OPENAI, + model_name="o4-mini-high", + friendly_name="OpenAI (O4-mini-high)", + context_window=200_000, # 200K tokens + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # O4 models support vision + max_image_size_mb=20.0, # 20MB per OpenAI docs + supports_temperature=False, # O4 models don't accept temperature parameter + temperature_constraint=create_temperature_constraint("fixed"), + 
description="Enhanced O4 mini (200K context) - Higher reasoning effort for complex tasks", + aliases=["o4minihigh", "o4minihi", "mini-high"], + ), + "gpt-4.1-2025-04-14": ModelCapabilities( + provider=ProviderType.OPENAI, + model_name="gpt-4.1-2025-04-14", + friendly_name="OpenAI (GPT 4.1)", + context_window=1_000_000, # 1M tokens + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # GPT-4.1 supports vision + max_image_size_mb=20.0, # 20MB per OpenAI docs + supports_temperature=True, # Regular models accept temperature parameter + temperature_constraint=create_temperature_constraint("range"), + description="GPT-4.1 (1M context) - Advanced reasoning model with large context window", + aliases=["gpt4.1"], + ), } def __init__(self, api_key: str, **kwargs): @@ -95,7 +134,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider): # Resolve shorthand resolved_name = self._resolve_model_name(model_name) - if resolved_name not in self.SUPPORTED_MODELS or isinstance(self.SUPPORTED_MODELS[resolved_name], str): + if resolved_name not in self.SUPPORTED_MODELS: raise ValueError(f"Unsupported OpenAI model: {model_name}") # Check if model is allowed by restrictions @@ -105,27 +144,8 @@ class OpenAIModelProvider(OpenAICompatibleProvider): if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name): raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.") - config = self.SUPPORTED_MODELS[resolved_name] - - # Get temperature constraints and support from configuration - supports_temperature = config.get("supports_temperature", True) # Default to True for backward compatibility - temp_constraint_type = config.get("temperature_constraint", "range") # Default to range - temp_constraint = create_temperature_constraint(temp_constraint_type) - - return ModelCapabilities( - provider=ProviderType.OPENAI, - 
model_name=model_name, - friendly_name="OpenAI", - context_window=config["context_window"], - supports_extended_thinking=config["supports_extended_thinking"], - supports_system_prompts=True, - supports_streaming=True, - supports_function_calling=True, - supports_images=config.get("supports_images", False), - max_image_size_mb=config.get("max_image_size_mb", 0.0), - supports_temperature=supports_temperature, - temperature_constraint=temp_constraint, - ) + # Return the ModelCapabilities object directly from SUPPORTED_MODELS + return self.SUPPORTED_MODELS[resolved_name] def get_provider_type(self) -> ProviderType: """Get the provider type.""" @@ -136,7 +156,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider): resolved_name = self._resolve_model_name(model_name) # First check if model is supported - if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict): + if resolved_name not in self.SUPPORTED_MODELS: return False # Then check if model is allowed by restrictions @@ -177,61 +197,3 @@ class OpenAIModelProvider(OpenAICompatibleProvider): # Currently no OpenAI models support extended thinking # This may change with future O3 models return False - - def list_models(self, respect_restrictions: bool = True) -> list[str]: - """Return a list of model names supported by this provider. - - Args: - respect_restrictions: Whether to apply provider-specific restriction logic. 
- - Returns: - List of model names available from this provider - """ - from utils.model_restrictions import get_restriction_service - - restriction_service = get_restriction_service() if respect_restrictions else None - models = [] - - for model_name, config in self.SUPPORTED_MODELS.items(): - # Handle both base models (dict configs) and aliases (string values) - if isinstance(config, str): - # This is an alias - check if the target model would be allowed - target_model = config - if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model): - continue - # Allow the alias - models.append(model_name) - else: - # This is a base model with config dict - # Check restrictions if enabled - if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name): - continue - models.append(model_name) - - return models - - def list_all_known_models(self) -> list[str]: - """Return all model names known by this provider, including alias targets. 
- - Returns: - List of all model names and alias targets known by this provider - """ - all_models = set() - - for model_name, config in self.SUPPORTED_MODELS.items(): - # Add the model name itself - all_models.add(model_name.lower()) - - # If it's an alias (string value), add the target model too - if isinstance(config, str): - all_models.add(config.lower()) - - return list(all_models) - - def _resolve_model_name(self, model_name: str) -> str: - """Resolve model shorthand to full name.""" - # Check if it's a shorthand - shorthand_value = self.SUPPORTED_MODELS.get(model_name) - if isinstance(shorthand_value, str): - return shorthand_value - return model_name diff --git a/providers/openrouter.py b/providers/openrouter.py index 1e22b45..5d29514 100644 --- a/providers/openrouter.py +++ b/providers/openrouter.py @@ -270,3 +270,39 @@ class OpenRouterProvider(OpenAICompatibleProvider): all_models.add(config.model_name.lower()) return list(all_models) + + def get_model_configurations(self) -> dict[str, ModelCapabilities]: + """Get model configurations from the registry. + + For OpenRouter, we convert registry configurations to ModelCapabilities objects. 
+ + Returns: + Dictionary mapping model names to their ModelCapabilities objects + """ + configs = {} + + if self._registry: + # Get all models from registry + for model_name in self._registry.list_models(): + # Only include models that this provider validates + if self.validate_model_name(model_name): + config = self._registry.resolve(model_name) + if config and not config.is_custom: # Only OpenRouter models, not custom ones + # Convert OpenRouterModelConfig to ModelCapabilities + capabilities = config.to_capabilities() + # Override provider type to OPENROUTER + capabilities.provider = ProviderType.OPENROUTER + capabilities.friendly_name = f"{self.FRIENDLY_NAME} ({config.model_name})" + configs[model_name] = capabilities + + return configs + + def get_all_model_aliases(self) -> dict[str, list[str]]: + """Get all model aliases from the registry. + + Returns: + Dictionary mapping model names to their list of aliases + """ + # Since aliases are now included in the configurations, + # we can use the base class implementation + return super().get_all_model_aliases() diff --git a/providers/xai.py b/providers/xai.py index 71d5c8a..2b6fd04 100644 --- a/providers/xai.py +++ b/providers/xai.py @@ -7,7 +7,7 @@ from .base import ( ModelCapabilities, ModelResponse, ProviderType, - RangeTemperatureConstraint, + create_temperature_constraint, ) from .openai_compatible import OpenAICompatibleProvider @@ -19,23 +19,42 @@ class XAIModelProvider(OpenAICompatibleProvider): FRIENDLY_NAME = "X.AI" - # Model configurations + # Model configurations using ModelCapabilities objects SUPPORTED_MODELS = { - "grok-3": { - "context_window": 131_072, # 131K tokens - "supports_extended_thinking": False, - "description": "GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis", - }, - "grok-3-fast": { - "context_window": 131_072, # 131K tokens - "supports_extended_thinking": False, - "description": "GROK-3 Fast (131K context) - Higher performance variant, faster 
processing but more expensive", - }, - # Shorthands for convenience - "grok": "grok-3", # Default to grok-3 - "grok3": "grok-3", - "grok3fast": "grok-3-fast", - "grokfast": "grok-3-fast", + "grok-3": ModelCapabilities( + provider=ProviderType.XAI, + model_name="grok-3", + friendly_name="X.AI (Grok 3)", + context_window=131_072, # 131K tokens + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=False, # Assuming GROK doesn't have JSON mode yet + supports_images=False, # Assuming GROK is text-only for now + max_image_size_mb=0.0, + supports_temperature=True, + temperature_constraint=create_temperature_constraint("range"), + description="GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis", + aliases=["grok", "grok3"], + ), + "grok-3-fast": ModelCapabilities( + provider=ProviderType.XAI, + model_name="grok-3-fast", + friendly_name="X.AI (Grok 3 Fast)", + context_window=131_072, # 131K tokens + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=False, # Assuming GROK doesn't have JSON mode yet + supports_images=False, # Assuming GROK is text-only for now + max_image_size_mb=0.0, + supports_temperature=True, + temperature_constraint=create_temperature_constraint("range"), + description="GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive", + aliases=["grok3fast", "grokfast", "grok3-fast"], + ), } def __init__(self, api_key: str, **kwargs): @@ -49,7 +68,7 @@ class XAIModelProvider(OpenAICompatibleProvider): # Resolve shorthand resolved_name = self._resolve_model_name(model_name) - if resolved_name not in self.SUPPORTED_MODELS or isinstance(self.SUPPORTED_MODELS[resolved_name], str): + if resolved_name not in self.SUPPORTED_MODELS: raise ValueError(f"Unsupported X.AI model: {model_name}") # Check if 
model is allowed by restrictions @@ -59,23 +78,8 @@ class XAIModelProvider(OpenAICompatibleProvider): if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name): raise ValueError(f"X.AI model '{model_name}' is not allowed by restriction policy.") - config = self.SUPPORTED_MODELS[resolved_name] - - # Define temperature constraints for GROK models - # GROK supports the standard OpenAI temperature range - temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7) - - return ModelCapabilities( - provider=ProviderType.XAI, - model_name=resolved_name, - friendly_name=self.FRIENDLY_NAME, - context_window=config["context_window"], - supports_extended_thinking=config["supports_extended_thinking"], - supports_system_prompts=True, - supports_streaming=True, - supports_function_calling=True, - temperature_constraint=temp_constraint, - ) + # Return the ModelCapabilities object directly from SUPPORTED_MODELS + return self.SUPPORTED_MODELS[resolved_name] def get_provider_type(self) -> ProviderType: """Get the provider type.""" @@ -86,7 +90,7 @@ class XAIModelProvider(OpenAICompatibleProvider): resolved_name = self._resolve_model_name(model_name) # First check if model is supported - if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict): + if resolved_name not in self.SUPPORTED_MODELS: return False # Then check if model is allowed by restrictions @@ -127,61 +131,3 @@ class XAIModelProvider(OpenAICompatibleProvider): # Currently GROK models do not support extended thinking # This may change with future GROK model releases return False - - def list_models(self, respect_restrictions: bool = True) -> list[str]: - """Return a list of model names supported by this provider. - - Args: - respect_restrictions: Whether to apply provider-specific restriction logic. 
- - Returns: - List of model names available from this provider - """ - from utils.model_restrictions import get_restriction_service - - restriction_service = get_restriction_service() if respect_restrictions else None - models = [] - - for model_name, config in self.SUPPORTED_MODELS.items(): - # Handle both base models (dict configs) and aliases (string values) - if isinstance(config, str): - # This is an alias - check if the target model would be allowed - target_model = config - if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model): - continue - # Allow the alias - models.append(model_name) - else: - # This is a base model with config dict - # Check restrictions if enabled - if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name): - continue - models.append(model_name) - - return models - - def list_all_known_models(self) -> list[str]: - """Return all model names known by this provider, including alias targets. 
- - Returns: - List of all model names and alias targets known by this provider - """ - all_models = set() - - for model_name, config in self.SUPPORTED_MODELS.items(): - # Add the model name itself - all_models.add(model_name.lower()) - - # If it's an alias (string value), add the target model too - if isinstance(config, str): - all_models.add(config.lower()) - - return list(all_models) - - def _resolve_model_name(self, model_name: str) -> str: - """Resolve model shorthand to full name.""" - # Check if it's a shorthand - shorthand_value = self.SUPPORTED_MODELS.get(model_name) - if isinstance(shorthand_value, str): - return shorthand_value - return model_name diff --git a/tests/test_auto_mode.py b/tests/test_auto_mode.py index 1aa4376..74d8ae3 100644 --- a/tests/test_auto_mode.py +++ b/tests/test_auto_mode.py @@ -59,7 +59,7 @@ class TestAutoMode: continue # Check that model has description - description = config.get("description", "") + description = config.description if hasattr(config, "description") else "" if description: models_with_descriptions[model_name] = description diff --git a/tests/test_auto_mode_comprehensive.py b/tests/test_auto_mode_comprehensive.py index 8539fdf..4d699b0 100644 --- a/tests/test_auto_mode_comprehensive.py +++ b/tests/test_auto_mode_comprehensive.py @@ -319,7 +319,18 @@ class TestAutoModeComprehensive: m for m in available_models if not m.startswith("gemini") - and m not in ["flash", "pro", "flash-2.0", "flash2", "flashlite", "flash-lite"] + and m + not in [ + "flash", + "pro", + "flash-2.0", + "flash2", + "flashlite", + "flash-lite", + "flash2.5", + "gemini pro", + "gemini-pro", + ] ] assert ( len(non_gemini_models) == 0 diff --git a/tests/test_dial_provider.py b/tests/test_dial_provider.py index 4a22cb6..62af59c 100644 --- a/tests/test_dial_provider.py +++ b/tests/test_dial_provider.py @@ -84,7 +84,7 @@ class TestDIALProvider: # Test O3 capabilities capabilities = provider.get_capabilities("o3") assert capabilities.model_name == 
"o3-2025-04-16" - assert capabilities.friendly_name == "DIAL" + assert capabilities.friendly_name == "DIAL (O3)" assert capabilities.context_window == 200_000 assert capabilities.provider == ProviderType.DIAL assert capabilities.supports_images is True diff --git a/tests/test_openai_provider.py b/tests/test_openai_provider.py index e9e3ae8..baab182 100644 --- a/tests/test_openai_provider.py +++ b/tests/test_openai_provider.py @@ -85,7 +85,7 @@ class TestOpenAIProvider: capabilities = provider.get_capabilities("o3") assert capabilities.model_name == "o3" # Should NOT be resolved in capabilities - assert capabilities.friendly_name == "OpenAI" + assert capabilities.friendly_name == "OpenAI (O3)" assert capabilities.context_window == 200_000 assert capabilities.provider == ProviderType.OPENAI assert not capabilities.supports_extended_thinking @@ -101,8 +101,8 @@ class TestOpenAIProvider: provider = OpenAIModelProvider("test-key") capabilities = provider.get_capabilities("mini") - assert capabilities.model_name == "mini" # Capabilities should show original request - assert capabilities.friendly_name == "OpenAI" + assert capabilities.model_name == "o4-mini" # Capabilities should show resolved model name + assert capabilities.friendly_name == "OpenAI (O4-mini)" assert capabilities.context_window == 200_000 assert capabilities.provider == ProviderType.OPENAI diff --git a/tests/test_supported_models_aliases.py b/tests/test_supported_models_aliases.py new file mode 100644 index 0000000..6ed899f --- /dev/null +++ b/tests/test_supported_models_aliases.py @@ -0,0 +1,206 @@ +"""Test the SUPPORTED_MODELS aliases structure across all providers.""" + +from providers.dial import DIALModelProvider +from providers.gemini import GeminiModelProvider +from providers.openai_provider import OpenAIModelProvider +from providers.xai import XAIModelProvider + + +class TestSupportedModelsAliases: + """Test that all providers have correctly structured SUPPORTED_MODELS with aliases.""" + + def 
test_gemini_provider_aliases(self): + """Test Gemini provider's alias structure.""" + provider = GeminiModelProvider("test-key") + + # Check that all models have ModelCapabilities with aliases + for model_name, config in provider.SUPPORTED_MODELS.items(): + assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute" + assert isinstance(config.aliases, list), f"{model_name} aliases must be a list" + + # Test specific aliases + assert "flash" in provider.SUPPORTED_MODELS["gemini-2.5-flash"].aliases + assert "pro" in provider.SUPPORTED_MODELS["gemini-2.5-pro"].aliases + assert "flash-2.0" in provider.SUPPORTED_MODELS["gemini-2.0-flash"].aliases + assert "flash2" in provider.SUPPORTED_MODELS["gemini-2.0-flash"].aliases + assert "flashlite" in provider.SUPPORTED_MODELS["gemini-2.0-flash-lite"].aliases + assert "flash-lite" in provider.SUPPORTED_MODELS["gemini-2.0-flash-lite"].aliases + + # Test alias resolution + assert provider._resolve_model_name("flash") == "gemini-2.5-flash" + assert provider._resolve_model_name("pro") == "gemini-2.5-pro" + assert provider._resolve_model_name("flash-2.0") == "gemini-2.0-flash" + assert provider._resolve_model_name("flash2") == "gemini-2.0-flash" + assert provider._resolve_model_name("flashlite") == "gemini-2.0-flash-lite" + + # Test case insensitive resolution + assert provider._resolve_model_name("Flash") == "gemini-2.5-flash" + assert provider._resolve_model_name("PRO") == "gemini-2.5-pro" + + def test_openai_provider_aliases(self): + """Test OpenAI provider's alias structure.""" + provider = OpenAIModelProvider("test-key") + + # Check that all models have ModelCapabilities with aliases + for model_name, config in provider.SUPPORTED_MODELS.items(): + assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute" + assert isinstance(config.aliases, list), f"{model_name} aliases must be a list" + + # Test specific aliases + assert "mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases + assert 
"o4mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases + assert "o3mini" in provider.SUPPORTED_MODELS["o3-mini"].aliases + assert "o3-pro" in provider.SUPPORTED_MODELS["o3-pro-2025-06-10"].aliases + assert "o4minihigh" in provider.SUPPORTED_MODELS["o4-mini-high"].aliases + assert "o4minihi" in provider.SUPPORTED_MODELS["o4-mini-high"].aliases + assert "gpt4.1" in provider.SUPPORTED_MODELS["gpt-4.1-2025-04-14"].aliases + + # Test alias resolution + assert provider._resolve_model_name("mini") == "o4-mini" + assert provider._resolve_model_name("o3mini") == "o3-mini" + assert provider._resolve_model_name("o3-pro") == "o3-pro-2025-06-10" + assert provider._resolve_model_name("o4minihigh") == "o4-mini-high" + assert provider._resolve_model_name("gpt4.1") == "gpt-4.1-2025-04-14" + + # Test case insensitive resolution + assert provider._resolve_model_name("Mini") == "o4-mini" + assert provider._resolve_model_name("O3MINI") == "o3-mini" + + def test_xai_provider_aliases(self): + """Test XAI provider's alias structure.""" + provider = XAIModelProvider("test-key") + + # Check that all models have ModelCapabilities with aliases + for model_name, config in provider.SUPPORTED_MODELS.items(): + assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute" + assert isinstance(config.aliases, list), f"{model_name} aliases must be a list" + + # Test specific aliases + assert "grok" in provider.SUPPORTED_MODELS["grok-3"].aliases + assert "grok3" in provider.SUPPORTED_MODELS["grok-3"].aliases + assert "grok3fast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases + assert "grokfast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases + + # Test alias resolution + assert provider._resolve_model_name("grok") == "grok-3" + assert provider._resolve_model_name("grok3") == "grok-3" + assert provider._resolve_model_name("grok3fast") == "grok-3-fast" + assert provider._resolve_model_name("grokfast") == "grok-3-fast" + + # Test case insensitive resolution + assert 
provider._resolve_model_name("Grok") == "grok-3" + assert provider._resolve_model_name("GROKFAST") == "grok-3-fast" + + def test_dial_provider_aliases(self): + """Test DIAL provider's alias structure.""" + provider = DIALModelProvider("test-key") + + # Check that all models have ModelCapabilities with aliases + for model_name, config in provider.SUPPORTED_MODELS.items(): + assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute" + assert isinstance(config.aliases, list), f"{model_name} aliases must be a list" + + # Test specific aliases + assert "o3" in provider.SUPPORTED_MODELS["o3-2025-04-16"].aliases + assert "o4-mini" in provider.SUPPORTED_MODELS["o4-mini-2025-04-16"].aliases + assert "sonnet-4" in provider.SUPPORTED_MODELS["anthropic.claude-sonnet-4-20250514-v1:0"].aliases + assert "opus-4" in provider.SUPPORTED_MODELS["anthropic.claude-opus-4-20250514-v1:0"].aliases + assert "gemini-2.5-pro" in provider.SUPPORTED_MODELS["gemini-2.5-pro-preview-05-06"].aliases + + # Test alias resolution + assert provider._resolve_model_name("o3") == "o3-2025-04-16" + assert provider._resolve_model_name("o4-mini") == "o4-mini-2025-04-16" + assert provider._resolve_model_name("sonnet-4") == "anthropic.claude-sonnet-4-20250514-v1:0" + assert provider._resolve_model_name("opus-4") == "anthropic.claude-opus-4-20250514-v1:0" + + # Test case insensitive resolution + assert provider._resolve_model_name("O3") == "o3-2025-04-16" + assert provider._resolve_model_name("SONNET-4") == "anthropic.claude-sonnet-4-20250514-v1:0" + + def test_list_models_includes_aliases(self): + """Test that list_models returns both base models and aliases.""" + # Test Gemini + gemini_provider = GeminiModelProvider("test-key") + gemini_models = gemini_provider.list_models(respect_restrictions=False) + assert "gemini-2.5-flash" in gemini_models + assert "flash" in gemini_models + assert "gemini-2.5-pro" in gemini_models + assert "pro" in gemini_models + + # Test OpenAI + openai_provider = 
OpenAIModelProvider("test-key") + openai_models = openai_provider.list_models(respect_restrictions=False) + assert "o4-mini" in openai_models + assert "mini" in openai_models + assert "o3-mini" in openai_models + assert "o3mini" in openai_models + + # Test XAI + xai_provider = XAIModelProvider("test-key") + xai_models = xai_provider.list_models(respect_restrictions=False) + assert "grok-3" in xai_models + assert "grok" in xai_models + assert "grok-3-fast" in xai_models + assert "grokfast" in xai_models + + # Test DIAL + dial_provider = DIALModelProvider("test-key") + dial_models = dial_provider.list_models(respect_restrictions=False) + assert "o3-2025-04-16" in dial_models + assert "o3" in dial_models + + def test_list_all_known_models_includes_aliases(self): + """Test that list_all_known_models returns all models and aliases in lowercase.""" + # Test Gemini + gemini_provider = GeminiModelProvider("test-key") + gemini_all = gemini_provider.list_all_known_models() + assert "gemini-2.5-flash" in gemini_all + assert "flash" in gemini_all + assert "gemini-2.5-pro" in gemini_all + assert "pro" in gemini_all + # All should be lowercase + assert all(model == model.lower() for model in gemini_all) + + # Test OpenAI + openai_provider = OpenAIModelProvider("test-key") + openai_all = openai_provider.list_all_known_models() + assert "o4-mini" in openai_all + assert "mini" in openai_all + assert "o3-mini" in openai_all + assert "o3mini" in openai_all + # All should be lowercase + assert all(model == model.lower() for model in openai_all) + + def test_no_string_shorthand_in_supported_models(self): + """Test that no provider has string-based shorthands anymore.""" + providers = [ + GeminiModelProvider("test-key"), + OpenAIModelProvider("test-key"), + XAIModelProvider("test-key"), + DIALModelProvider("test-key"), + ] + + for provider in providers: + for model_name, config in provider.SUPPORTED_MODELS.items(): + # All values must be ModelCapabilities objects, not strings or dicts + 
from providers.base import ModelCapabilities + + assert isinstance(config, ModelCapabilities), ( + f"{provider.__class__.__name__}.SUPPORTED_MODELS['{model_name}'] " + f"must be a ModelCapabilities object, not {type(config).__name__}" + ) + + def test_resolve_returns_original_if_not_found(self): + """Test that _resolve_model_name returns original name if alias not found.""" + providers = [ + GeminiModelProvider("test-key"), + OpenAIModelProvider("test-key"), + XAIModelProvider("test-key"), + DIALModelProvider("test-key"), + ] + + for provider in providers: + # Test with unknown model name + assert provider._resolve_model_name("unknown-model") == "unknown-model" + assert provider._resolve_model_name("gpt-4") == "gpt-4" + assert provider._resolve_model_name("claude-3") == "claude-3" diff --git a/tests/test_xai_provider.py b/tests/test_xai_provider.py index e002636..978d9c1 100644 --- a/tests/test_xai_provider.py +++ b/tests/test_xai_provider.py @@ -77,7 +77,7 @@ class TestXAIProvider: capabilities = provider.get_capabilities("grok-3") assert capabilities.model_name == "grok-3" - assert capabilities.friendly_name == "X.AI" + assert capabilities.friendly_name == "X.AI (Grok 3)" assert capabilities.context_window == 131_072 assert capabilities.provider == ProviderType.XAI assert not capabilities.supports_extended_thinking @@ -96,7 +96,7 @@ class TestXAIProvider: capabilities = provider.get_capabilities("grok-3-fast") assert capabilities.model_name == "grok-3-fast" - assert capabilities.friendly_name == "X.AI" + assert capabilities.friendly_name == "X.AI (Grok 3 Fast)" assert capabilities.context_window == 131_072 assert capabilities.provider == ProviderType.XAI assert not capabilities.supports_extended_thinking @@ -212,31 +212,34 @@ class TestXAIProvider: assert provider.FRIENDLY_NAME == "X.AI" capabilities = provider.get_capabilities("grok-3") - assert capabilities.friendly_name == "X.AI" + assert capabilities.friendly_name == "X.AI (Grok 3)" def 
test_supported_models_structure(self): """Test that SUPPORTED_MODELS has the correct structure.""" provider = XAIModelProvider("test-key") - # Check that all expected models are present + # Check that all expected base models are present assert "grok-3" in provider.SUPPORTED_MODELS assert "grok-3-fast" in provider.SUPPORTED_MODELS - assert "grok" in provider.SUPPORTED_MODELS - assert "grok3" in provider.SUPPORTED_MODELS - assert "grokfast" in provider.SUPPORTED_MODELS - assert "grok3fast" in provider.SUPPORTED_MODELS # Check model configs have required fields - grok3_config = provider.SUPPORTED_MODELS["grok-3"] - assert isinstance(grok3_config, dict) - assert "context_window" in grok3_config - assert "supports_extended_thinking" in grok3_config - assert grok3_config["context_window"] == 131_072 - assert grok3_config["supports_extended_thinking"] is False + from providers.base import ModelCapabilities - # Check shortcuts point to full names - assert provider.SUPPORTED_MODELS["grok"] == "grok-3" - assert provider.SUPPORTED_MODELS["grokfast"] == "grok-3-fast" + grok3_config = provider.SUPPORTED_MODELS["grok-3"] + assert isinstance(grok3_config, ModelCapabilities) + assert hasattr(grok3_config, "context_window") + assert hasattr(grok3_config, "supports_extended_thinking") + assert hasattr(grok3_config, "aliases") + assert grok3_config.context_window == 131_072 + assert grok3_config.supports_extended_thinking is False + + # Check aliases are correctly structured + assert "grok" in grok3_config.aliases + assert "grok3" in grok3_config.aliases + + grok3fast_config = provider.SUPPORTED_MODELS["grok-3-fast"] + assert "grok3fast" in grok3fast_config.aliases + assert "grokfast" in grok3fast_config.aliases @patch("providers.openai_compatible.OpenAI") def test_generate_content_resolves_alias_before_api_call(self, mock_openai_class): diff --git a/tools/listmodels.py b/tools/listmodels.py index 265fbcc..0813ee7 100644 --- a/tools/listmodels.py +++ b/tools/listmodels.py @@ -99,15 
+99,11 @@ class ListModelsTool(BaseTool): output_lines.append("**Status**: Configured and available") output_lines.append("\n**Models**:") - # Get models from the provider's SUPPORTED_MODELS - for model_name, config in provider.SUPPORTED_MODELS.items(): - # Skip alias entries (string values) - if isinstance(config, str): - continue - - # Get description and context from the model config - description = config.get("description", "No description available") - context_window = config.get("context_window", 0) + # Get models from the provider's model configurations + for model_name, capabilities in provider.get_model_configurations().items(): + # Get description and context from the ModelCapabilities object + description = capabilities.description or "No description available" + context_window = capabilities.context_window # Format context window if context_window >= 1_000_000: @@ -133,13 +129,14 @@ class ListModelsTool(BaseTool): # Show aliases for this provider aliases = [] - for alias_name, target in provider.SUPPORTED_MODELS.items(): - if isinstance(target, str): # This is an alias - aliases.append(f"- `{alias_name}` → `{target}`") + for model_name, capabilities in provider.get_model_configurations().items(): + if capabilities.aliases: + for alias in capabilities.aliases: + aliases.append(f"- `{alias}` → `{model_name}`") if aliases: output_lines.append("\n**Aliases**:") - output_lines.extend(aliases) + output_lines.extend(sorted(aliases)) # Sort for consistent output else: output_lines.append(f"**Status**: Not configured (set {info['env_key']})") @@ -237,7 +234,7 @@ class ListModelsTool(BaseTool): for alias in registry.list_aliases(): config = registry.resolve(alias) - if config and hasattr(config, "is_custom") and config.is_custom: + if config and config.is_custom: custom_models.append((alias, config)) if custom_models: diff --git a/tools/shared/base_tool.py b/tools/shared/base_tool.py index a98baf8..10b223f 100644 --- a/tools/shared/base_tool.py +++ 
b/tools/shared/base_tool.py @@ -256,8 +256,8 @@ class BaseTool(ABC): # Find all custom models (is_custom=true) for alias in registry.list_aliases(): config = registry.resolve(alias) - # Use hasattr for defensive programming - is_custom is optional with default False - if config and hasattr(config, "is_custom") and config.is_custom: + # Check if this is a custom model that requires custom endpoints + if config and config.is_custom: if alias not in all_models: all_models.append(alias) except Exception as e: @@ -311,12 +311,16 @@ class BaseTool(ABC): ProviderType.GOOGLE: "Gemini models", ProviderType.OPENAI: "OpenAI models", ProviderType.XAI: "X.AI GROK models", + ProviderType.DIAL: "DIAL models", ProviderType.CUSTOM: "Custom models", ProviderType.OPENROUTER: "OpenRouter models", } # Check available providers and add their model descriptions - for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI]: + + # Start with native providers + for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI, ProviderType.DIAL]: + # Only if this is registered / available provider = ModelProviderRegistry.get_provider(provider_type) if provider: provider_section_added = False @@ -324,13 +328,13 @@ class BaseTool(ABC): try: # Get model config to extract description model_config = provider.SUPPORTED_MODELS.get(model_name) - if isinstance(model_config, dict) and "description" in model_config: + if model_config and model_config.description: if not provider_section_added: model_desc_parts.append( f"\n{provider_names[provider_type]} - Available when {provider_type.value.upper()}_API_KEY is configured:" ) provider_section_added = True - model_desc_parts.append(f"- '{model_name}': {model_config['description']}") + model_desc_parts.append(f"- '{model_name}': {model_config.description}") except Exception: # Skip models without descriptions continue @@ -346,8 +350,8 @@ class BaseTool(ABC): # Find all custom models (is_custom=true) for alias in 
registry.list_aliases(): config = registry.resolve(alias) - # Use hasattr for defensive programming - is_custom is optional with default False - if config and hasattr(config, "is_custom") and config.is_custom: + # Check if this is a custom model that requires custom endpoints + if config and config.is_custom: # Format context window context_tokens = config.context_window if context_tokens >= 1_000_000: From 14eaf930ed7922b8472c811ab67128717468be01 Mon Sep 17 00:00:00 2001 From: Fahad Date: Mon, 23 Jun 2025 17:39:47 +0400 Subject: [PATCH 3/8] Cleanup, use ModelCapabilities only --- providers/custom.py | 9 +--- providers/openrouter.py | 8 +-- providers/openrouter_registry.py | 85 ++++++++++--------------------- tests/test_openrouter_provider.py | 2 +- tests/test_openrouter_registry.py | 27 +++++----- 5 files changed, 47 insertions(+), 84 deletions(-) diff --git a/providers/custom.py b/providers/custom.py index 52d9b94..021bba5 100644 --- a/providers/custom.py +++ b/providers/custom.py @@ -291,7 +291,6 @@ class CustomProvider(OpenAICompatibleProvider): Returns: Dictionary mapping model names to their ModelCapabilities objects """ - from .base import ProviderType configs = {} @@ -302,12 +301,8 @@ class CustomProvider(OpenAICompatibleProvider): if self.validate_model_name(model_name): config = self._registry.resolve(model_name) if config and config.is_custom: - # Convert OpenRouterModelConfig to ModelCapabilities - capabilities = config.to_capabilities() - # Override provider type to CUSTOM for local models - capabilities.provider = ProviderType.CUSTOM - capabilities.friendly_name = f"{self.FRIENDLY_NAME} ({config.model_name})" - configs[model_name] = capabilities + # Use ModelCapabilities directly from registry + configs[model_name] = config return configs diff --git a/providers/openrouter.py b/providers/openrouter.py index 5d29514..3d90238 100644 --- a/providers/openrouter.py +++ b/providers/openrouter.py @@ -288,12 +288,8 @@ class 
OpenRouterProvider(OpenAICompatibleProvider): if self.validate_model_name(model_name): config = self._registry.resolve(model_name) if config and not config.is_custom: # Only OpenRouter models, not custom ones - # Convert OpenRouterModelConfig to ModelCapabilities - capabilities = config.to_capabilities() - # Override provider type to OPENROUTER - capabilities.provider = ProviderType.OPENROUTER - capabilities.friendly_name = f"{self.FRIENDLY_NAME} ({config.model_name})" - configs[model_name] = capabilities + # Use ModelCapabilities directly from registry + configs[model_name] = config return configs diff --git a/providers/openrouter_registry.py b/providers/openrouter_registry.py index 47258c8..97b8f60 100644 --- a/providers/openrouter_registry.py +++ b/providers/openrouter_registry.py @@ -2,7 +2,6 @@ import logging import os -from dataclasses import dataclass, field from pathlib import Path from typing import Optional @@ -11,58 +10,10 @@ from utils.file_utils import read_json_file from .base import ( ModelCapabilities, ProviderType, - TemperatureConstraint, create_temperature_constraint, ) -@dataclass -class OpenRouterModelConfig: - """Configuration for an OpenRouter model.""" - - model_name: str - aliases: list[str] = field(default_factory=list) - context_window: int = 32768 # Total context window size in tokens - supports_extended_thinking: bool = False - supports_system_prompts: bool = True - supports_streaming: bool = True - supports_function_calling: bool = False - supports_json_mode: bool = False - supports_images: bool = False # Whether model can process images - max_image_size_mb: float = 0.0 # Maximum total size for all images in MB - supports_temperature: bool = True # Whether model accepts temperature parameter in API calls - temperature_constraint: Optional[str] = ( - None # Type of temperature constraint: "fixed", "range", "discrete", or None for default range - ) - is_custom: bool = False # True for models that should only be used with custom endpoints 
- description: str = "" - - def _create_temperature_constraint(self) -> TemperatureConstraint: - """Create temperature constraint object from configuration. - - Returns: - TemperatureConstraint object based on configuration - """ - return create_temperature_constraint(self.temperature_constraint or "range") - - def to_capabilities(self) -> ModelCapabilities: - """Convert to ModelCapabilities object.""" - return ModelCapabilities( - provider=ProviderType.OPENROUTER, - model_name=self.model_name, - friendly_name="OpenRouter", - context_window=self.context_window, - supports_extended_thinking=self.supports_extended_thinking, - supports_system_prompts=self.supports_system_prompts, - supports_streaming=self.supports_streaming, - supports_function_calling=self.supports_function_calling, - supports_images=self.supports_images, - max_image_size_mb=self.max_image_size_mb, - supports_temperature=self.supports_temperature, - temperature_constraint=self._create_temperature_constraint(), - ) - - class OpenRouterModelRegistry: """Registry for managing OpenRouter model configurations and aliases.""" @@ -73,7 +24,7 @@ class OpenRouterModelRegistry: config_path: Path to config file. If None, uses default locations. """ self.alias_map: dict[str, str] = {} # alias -> model_name - self.model_map: dict[str, OpenRouterModelConfig] = {} # model_name -> config + self.model_map: dict[str, ModelCapabilities] = {} # model_name -> config # Determine config path if config_path: @@ -139,7 +90,7 @@ class OpenRouterModelRegistry: self.alias_map = {} self.model_map = {} - def _read_config(self) -> list[OpenRouterModelConfig]: + def _read_config(self) -> list[ModelCapabilities]: """Read configuration from file. 
Returns: @@ -158,7 +109,27 @@ class OpenRouterModelRegistry: # Parse models configs = [] for model_data in data.get("models", []): - config = OpenRouterModelConfig(**model_data) + # Create ModelCapabilities directly from JSON data + # Handle temperature_constraint conversion + temp_constraint_str = model_data.get("temperature_constraint") + temp_constraint = create_temperature_constraint(temp_constraint_str or "range") + + # Set provider-specific defaults based on is_custom flag + is_custom = model_data.get("is_custom", False) + if is_custom: + model_data.setdefault("provider", ProviderType.CUSTOM) + model_data.setdefault("friendly_name", f"Custom ({model_data.get('model_name', 'Unknown')})") + else: + model_data.setdefault("provider", ProviderType.OPENROUTER) + model_data.setdefault("friendly_name", f"OpenRouter ({model_data.get('model_name', 'Unknown')})") + model_data["temperature_constraint"] = temp_constraint + + # Remove the string version of temperature_constraint before creating ModelCapabilities + if "temperature_constraint" in model_data and isinstance(model_data["temperature_constraint"], str): + del model_data["temperature_constraint"] + model_data["temperature_constraint"] = temp_constraint + + config = ModelCapabilities(**model_data) configs.append(config) return configs @@ -168,7 +139,7 @@ class OpenRouterModelRegistry: except Exception as e: raise ValueError(f"Error reading config from {self.config_path}: {e}") - def _build_maps(self, configs: list[OpenRouterModelConfig]) -> None: + def _build_maps(self, configs: list[ModelCapabilities]) -> None: """Build alias and model maps from configurations. Args: @@ -211,7 +182,7 @@ class OpenRouterModelRegistry: self.alias_map = alias_map self.model_map = model_map - def resolve(self, name_or_alias: str) -> Optional[OpenRouterModelConfig]: + def resolve(self, name_or_alias: str) -> Optional[ModelCapabilities]: """Resolve a model name or alias to configuration. 
Args: @@ -237,10 +208,8 @@ class OpenRouterModelRegistry: Returns: ModelCapabilities if found, None otherwise """ - config = self.resolve(name_or_alias) - if config: - return config.to_capabilities() - return None + # Registry now returns ModelCapabilities directly + return self.resolve(name_or_alias) def list_models(self) -> list[str]: """List all available model names.""" diff --git a/tests/test_openrouter_provider.py b/tests/test_openrouter_provider.py index da10678..6d427ba 100644 --- a/tests/test_openrouter_provider.py +++ b/tests/test_openrouter_provider.py @@ -57,7 +57,7 @@ class TestOpenRouterProvider: caps = provider.get_capabilities("o3") assert caps.provider == ProviderType.OPENROUTER assert caps.model_name == "openai/o3" # Resolved name - assert caps.friendly_name == "OpenRouter" + assert caps.friendly_name == "OpenRouter (openai/o3)" # Test with a model not in registry - should get generic capabilities caps = provider.get_capabilities("unknown-model") diff --git a/tests/test_openrouter_registry.py b/tests/test_openrouter_registry.py index 4b8bbbf..f6ea000 100644 --- a/tests/test_openrouter_registry.py +++ b/tests/test_openrouter_registry.py @@ -6,8 +6,8 @@ import tempfile import pytest -from providers.base import ProviderType -from providers.openrouter_registry import OpenRouterModelConfig, OpenRouterModelRegistry +from providers.base import ModelCapabilities, ProviderType +from providers.openrouter_registry import OpenRouterModelRegistry class TestOpenRouterModelRegistry: @@ -110,18 +110,18 @@ class TestOpenRouterModelRegistry: assert registry.resolve("non-existent") is None def test_model_capabilities_conversion(self): - """Test conversion to ModelCapabilities.""" + """Test that registry returns ModelCapabilities directly.""" registry = OpenRouterModelRegistry() config = registry.resolve("opus") assert config is not None - caps = config.to_capabilities() - assert caps.provider == ProviderType.OPENROUTER - assert caps.model_name == 
"anthropic/claude-opus-4" - assert caps.friendly_name == "OpenRouter" - assert caps.context_window == 200000 - assert not caps.supports_extended_thinking + # Registry now returns ModelCapabilities objects directly + assert config.provider == ProviderType.OPENROUTER + assert config.model_name == "anthropic/claude-opus-4" + assert config.friendly_name == "OpenRouter (anthropic/claude-opus-4)" + assert config.context_window == 200000 + assert not config.supports_extended_thinking def test_duplicate_alias_detection(self): """Test that duplicate aliases are detected.""" @@ -199,8 +199,12 @@ class TestOpenRouterModelRegistry: def test_model_with_all_capabilities(self): """Test model with all capability flags.""" - config = OpenRouterModelConfig( + from providers.base import create_temperature_constraint + + caps = ModelCapabilities( + provider=ProviderType.OPENROUTER, model_name="test/full-featured", + friendly_name="OpenRouter (test/full-featured)", aliases=["full"], context_window=128000, supports_extended_thinking=True, @@ -209,9 +213,8 @@ class TestOpenRouterModelRegistry: supports_function_calling=True, supports_json_mode=True, description="Fully featured test model", + temperature_constraint=create_temperature_constraint("range"), ) - - caps = config.to_capabilities() assert caps.context_window == 128000 assert caps.supports_extended_thinking assert caps.supports_system_prompts From 9167e6d8456a64d7c663b113fed4a1c8a85763ef Mon Sep 17 00:00:00 2001 From: Fahad Date: Mon, 23 Jun 2025 17:53:03 +0400 Subject: [PATCH 4/8] Quick test mode for simulation tests --- CLAUDE.md | 36 +++++++++++++++++++++++++++------ communication_simulator_test.py | 33 ++++++++++++++++++++++++++++-- 2 files changed, 61 insertions(+), 8 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 8fd708f..db9a335 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -128,7 +128,26 @@ python communication_simulator_test.py python communication_simulator_test.py --verbose ``` -#### Run Individual Simulator Tests 
(Recommended) +#### Quick Test Mode (Recommended for Time-Limited Testing) +```bash +# Run quick test mode - 6 essential tests that provide maximum functionality coverage +python communication_simulator_test.py --quick + +# Run quick test mode with verbose output +python communication_simulator_test.py --quick --verbose +``` + +**Quick mode runs these 6 essential tests:** +- `cross_tool_continuation` - Cross-tool conversation memory testing (chat, thinkdeep, codereview, analyze, debug) +- `conversation_chain_validation` - Core conversation threading and memory validation +- `consensus_workflow_accurate` - Consensus tool with flash model and stance testing +- `codereview_validation` - CodeReview tool with flash model and multi-step workflows +- `planner_validation` - Planner tool with flash model and complex planning workflows +- `token_allocation_validation` - Token allocation and conversation history buildup testing + +**Why these 6 tests:** They cover all major tools (chat, planner, consensus, codereview + analyze, debug, thinkdeep), extensively test conversation memory functionality, use flash/flashlite models, and provide comprehensive app functionality coverage in minimal time. + +#### Run Individual Simulator Tests (For Detailed Testing) ```bash # List all available tests python communication_simulator_test.py --list-tests @@ -223,15 +242,17 @@ python -m pytest tests/ -v #### After Making Changes 1. Run quality checks again: `./code_quality_checks.sh` 2. Run integration tests locally: `./run_integration_tests.sh` -3. Run relevant simulator tests: `python communication_simulator_test.py --individual ` -4. Check logs for any issues: `tail -n 100 logs/mcp_server.log` -5. Restart Claude session to use updated code +3. Run quick test mode for fast validation: `python communication_simulator_test.py --quick` +4. Run relevant specific simulator tests if needed: `python communication_simulator_test.py --individual ` +5. 
Check logs for any issues: `tail -n 100 logs/mcp_server.log` +6. Restart Claude session to use updated code #### Before Committing/PR 1. Final quality check: `./code_quality_checks.sh` 2. Run integration tests: `./run_integration_tests.sh` -3. Run full simulator test suite: `./run_integration_tests.sh --with-simulator` -4. Verify all tests pass 100% +3. Run quick test mode: `python communication_simulator_test.py --quick` +4. Run full simulator test suite (optional): `./run_integration_tests.sh --with-simulator` +5. Verify all tests pass 100% ### Common Troubleshooting @@ -250,6 +271,9 @@ which python #### Test Failures ```bash +# First try quick test mode to see if it's a general issue +python communication_simulator_test.py --quick --verbose + # Run individual failing test with verbose output python communication_simulator_test.py --individual --verbose diff --git a/communication_simulator_test.py b/communication_simulator_test.py index 9c5cb89..93a1695 100644 --- a/communication_simulator_test.py +++ b/communication_simulator_test.py @@ -38,6 +38,15 @@ Available tests: debug_validation - Debug tool validation with actual bugs conversation_chain_validation - Conversation chain continuity validation +Quick Test Mode (for time-limited testing): + Use --quick to run the essential 6 tests that provide maximum coverage: + - cross_tool_continuation + - conversation_chain_validation + - consensus_workflow_accurate + - codereview_validation + - planner_validation + - token_allocation_validation + Examples: # Run all tests python communication_simulator_test.py @@ -48,6 +57,9 @@ Examples: # Run a single test individually (with full standalone setup) python communication_simulator_test.py --individual content_validation + # Run quick test mode (essential 6 tests for time-limited testing) + python communication_simulator_test.py --quick + # Force setup standalone server environment before running tests python communication_simulator_test.py --setup @@ -68,12 +80,13 @@ class 
CommunicationSimulator: """Simulates real-world Claude CLI communication with MCP Gemini server""" def __init__( - self, verbose: bool = False, keep_logs: bool = False, selected_tests: list[str] = None, setup: bool = False + self, verbose: bool = False, keep_logs: bool = False, selected_tests: list[str] = None, setup: bool = False, quick_mode: bool = False ): self.verbose = verbose self.keep_logs = keep_logs self.selected_tests = selected_tests or [] self.setup = setup + self.quick_mode = quick_mode self.temp_dir = None self.server_process = None self.python_path = self._get_python_path() @@ -83,6 +96,21 @@ class CommunicationSimulator: self.test_registry = TEST_REGISTRY + # Define quick mode tests (essential tests for time-limited testing) + self.quick_mode_tests = [ + "cross_tool_continuation", + "conversation_chain_validation", + "consensus_workflow_accurate", + "codereview_validation", + "planner_validation", + "token_allocation_validation" + ] + + # If quick mode is enabled, override selected_tests + if self.quick_mode: + self.selected_tests = self.quick_mode_tests + self.logger.info(f"Quick mode enabled - running {len(self.quick_mode_tests)} essential tests") + # Available test methods mapping self.available_tests = { name: self._create_test_runner(test_class) for name, test_class in self.test_registry.items() @@ -415,6 +443,7 @@ def parse_arguments(): parser.add_argument("--tests", "-t", nargs="+", help="Specific tests to run (space-separated)") parser.add_argument("--list-tests", action="store_true", help="List available tests and exit") parser.add_argument("--individual", "-i", help="Run a single test individually") + parser.add_argument("--quick", "-q", action="store_true", help="Run quick test mode (6 essential tests for time-limited testing)") parser.add_argument( "--setup", action="store_true", help="Force setup standalone server environment using run-server.sh" ) @@ -492,7 +521,7 @@ def main(): # Initialize simulator consistently for all use cases 
simulator = CommunicationSimulator( - verbose=args.verbose, keep_logs=args.keep_logs, selected_tests=args.tests, setup=args.setup + verbose=args.verbose, keep_logs=args.keep_logs, selected_tests=args.tests, setup=args.setup, quick_mode=args.quick ) # Determine execution mode and run From 8c1814d4ebba09f43683b8ad0b0c10b567fb5f58 Mon Sep 17 00:00:00 2001 From: Fahad Date: Mon, 23 Jun 2025 18:05:31 +0400 Subject: [PATCH 5/8] Quick test mode for simulation tests --- CLAUDE.md | 4 +++- communication_simulator_test.py | 35 +++++++++++++++++---------------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index db9a335..89db9d9 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -145,7 +145,9 @@ python communication_simulator_test.py --quick --verbose - `planner_validation` - Planner tool with flash model and complex planning workflows - `token_allocation_validation` - Token allocation and conversation history buildup testing -**Why these 6 tests:** They cover all major tools (chat, planner, consensus, codereview + analyze, debug, thinkdeep), extensively test conversation memory functionality, use flash/flashlite models, and provide comprehensive app functionality coverage in minimal time. +**Why these 6 tests:** They cover the core functionality including conversation memory (`utils/conversation_memory.py`), chat tool functionality, file processing and deduplication, model selection (flash/flashlite/o3), and cross-tool conversation workflows. These tests validate the most critical parts of the system in minimal time. + +**Note:** Some workflow tools (analyze, codereview, planner, consensus, etc.) require specific workflow parameters and may need individual testing rather than quick mode testing. 
#### Run Individual Simulator Tests (For Detailed Testing) ```bash diff --git a/communication_simulator_test.py b/communication_simulator_test.py index 93a1695..6bd1a07 100644 --- a/communication_simulator_test.py +++ b/communication_simulator_test.py @@ -40,12 +40,12 @@ Available tests: Quick Test Mode (for time-limited testing): Use --quick to run the essential 6 tests that provide maximum coverage: - - cross_tool_continuation - - conversation_chain_validation - - consensus_workflow_accurate - - codereview_validation - - planner_validation - - token_allocation_validation + - cross_tool_continuation (cross-tool conversation memory) + - basic_conversation (basic chat functionality) + - content_validation (content validation and deduplication) + - model_thinking_config (flash/flashlite model testing) + - o3_model_selection (o3 model selection testing) + - per_tool_deduplication (file deduplication for individual tools) Examples: # Run all tests @@ -91,19 +91,25 @@ class CommunicationSimulator: self.server_process = None self.python_path = self._get_python_path() + # Configure logging first + log_level = logging.DEBUG if verbose else logging.INFO + logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s") + self.logger = logging.getLogger(__name__) + # Import test registry from simulator_tests import TEST_REGISTRY self.test_registry = TEST_REGISTRY # Define quick mode tests (essential tests for time-limited testing) + # Focus on tests that work with current tool configurations self.quick_mode_tests = [ - "cross_tool_continuation", - "conversation_chain_validation", - "consensus_workflow_accurate", - "codereview_validation", - "planner_validation", - "token_allocation_validation" + "cross_tool_continuation", # Cross-tool conversation memory + "basic_conversation", # Basic chat functionality + "content_validation", # Content validation and deduplication + "model_thinking_config", # Flash/flashlite model testing + "o3_model_selection", # 
O3 model selection testing + "per_tool_deduplication" # File deduplication for individual tools ] # If quick mode is enabled, override selected_tests @@ -119,11 +125,6 @@ class CommunicationSimulator: # Test result tracking self.test_results = dict.fromkeys(self.test_registry.keys(), False) - # Configure logging - log_level = logging.DEBUG if verbose else logging.INFO - logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s") - self.logger = logging.getLogger(__name__) - def _get_python_path(self) -> str: """Get the Python path for the virtual environment""" current_dir = os.getcwd() From ce6c1fd7ea4fd336179d3b438d416e62237c5d83 Mon Sep 17 00:00:00 2001 From: Fahad Date: Mon, 23 Jun 2025 18:33:47 +0400 Subject: [PATCH 6/8] Quick test mode for simulation tests Fixed o4-mini name, OpenAI removed o4-mini-high Add max_output_tokens property to ModelCapabilities --- README.md | 2 +- communication_simulator_test.py | 29 ++++++++++++++++-------- conf/custom_models.json | 30 ++++++++++++++----------- docs/advanced-usage.md | 15 ++++++------- docs/configuration.md | 4 +--- docs/tools/analyze.md | 2 +- docs/tools/chat.md | 2 +- docs/tools/codereview.md | 2 +- docs/tools/debug.md | 2 +- docs/tools/precommit.md | 2 +- docs/tools/refactor.md | 2 +- docs/tools/secaudit.md | 2 +- docs/tools/testgen.md | 2 +- docs/tools/thinkdeep.md | 2 +- providers/base.py | 1 + providers/custom.py | 7 +++--- providers/dial.py | 9 ++++++++ providers/gemini.py | 4 ++++ providers/openai_compatible.py | 1 - providers/openai_provider.py | 24 +++++--------------- providers/openrouter.py | 1 + providers/registry.py | 2 -- providers/xai.py | 2 ++ tests/mock_helpers.py | 1 + tests/test_alias_target_restrictions.py | 19 ++++++++-------- tests/test_auto_mode.py | 2 +- tests/test_buggy_behavior_prevention.py | 16 +++++++------ tests/test_model_enumeration.py | 2 +- tests/test_model_restrictions.py | 4 ++-- tests/test_o3_temperature_fix_simple.py | 2 +- 
tests/test_openai_provider.py | 18 +++++++-------- tests/test_openrouter_provider.py | 2 +- tests/test_openrouter_registry.py | 21 ++++++++++++++--- tests/test_providers.py | 6 ++--- tests/test_supported_models_aliases.py | 5 ++--- 35 files changed, 137 insertions(+), 110 deletions(-) diff --git a/README.md b/README.md index 40552da..504a5a2 100644 --- a/README.md +++ b/README.md @@ -409,7 +409,7 @@ for most debugging workflows, as Claude is usually able to confidently find the When in doubt, you can always follow up with a new prompt and ask Claude to share its findings with another model: ```text -Use continuation with thinkdeep, share details with o4-mini-high to find out what the best fix is for this +Use continuation with thinkdeep, share details with o4-mini to find out what the best fix is for this ``` **[📖 Read More](docs/tools/debug.md)** - Step-by-step investigation methodology with workflow enforcement diff --git a/communication_simulator_test.py b/communication_simulator_test.py index 6bd1a07..e471b33 100644 --- a/communication_simulator_test.py +++ b/communication_simulator_test.py @@ -80,7 +80,12 @@ class CommunicationSimulator: """Simulates real-world Claude CLI communication with MCP Gemini server""" def __init__( - self, verbose: bool = False, keep_logs: bool = False, selected_tests: list[str] = None, setup: bool = False, quick_mode: bool = False + self, + verbose: bool = False, + keep_logs: bool = False, + selected_tests: list[str] = None, + setup: bool = False, + quick_mode: bool = False, ): self.verbose = verbose self.keep_logs = keep_logs @@ -104,12 +109,12 @@ class CommunicationSimulator: # Define quick mode tests (essential tests for time-limited testing) # Focus on tests that work with current tool configurations self.quick_mode_tests = [ - "cross_tool_continuation", # Cross-tool conversation memory - "basic_conversation", # Basic chat functionality - "content_validation", # Content validation and deduplication - "model_thinking_config", # 
Flash/flashlite model testing - "o3_model_selection", # O3 model selection testing - "per_tool_deduplication" # File deduplication for individual tools + "cross_tool_continuation", # Cross-tool conversation memory + "basic_conversation", # Basic chat functionality + "content_validation", # Content validation and deduplication + "model_thinking_config", # Flash/flashlite model testing + "o3_model_selection", # O3 model selection testing + "per_tool_deduplication", # File deduplication for individual tools ] # If quick mode is enabled, override selected_tests @@ -444,7 +449,9 @@ def parse_arguments(): parser.add_argument("--tests", "-t", nargs="+", help="Specific tests to run (space-separated)") parser.add_argument("--list-tests", action="store_true", help="List available tests and exit") parser.add_argument("--individual", "-i", help="Run a single test individually") - parser.add_argument("--quick", "-q", action="store_true", help="Run quick test mode (6 essential tests for time-limited testing)") + parser.add_argument( + "--quick", "-q", action="store_true", help="Run quick test mode (6 essential tests for time-limited testing)" + ) parser.add_argument( "--setup", action="store_true", help="Force setup standalone server environment using run-server.sh" ) @@ -522,7 +529,11 @@ def main(): # Initialize simulator consistently for all use cases simulator = CommunicationSimulator( - verbose=args.verbose, keep_logs=args.keep_logs, selected_tests=args.tests, setup=args.setup, quick_mode=args.quick + verbose=args.verbose, + keep_logs=args.keep_logs, + selected_tests=args.tests, + setup=args.setup, + quick_mode=args.quick, ) # Determine execution mode and run diff --git a/conf/custom_models.json b/conf/custom_models.json index 2a3bcf3..f794d00 100644 --- a/conf/custom_models.json +++ b/conf/custom_models.json @@ -22,6 +22,7 @@ "model_name": "The model identifier - OpenRouter format (e.g., 'anthropic/claude-opus-4') or custom model name (e.g., 'llama3.2')", "aliases": "Array 
of short names users can type instead of the full model name", "context_window": "Total number of tokens the model can process (input + output combined)", + "max_output_tokens": "Maximum number of tokens the model can generate in a single response", "supports_extended_thinking": "Whether the model supports extended reasoning tokens (currently none do via OpenRouter or custom APIs)", "supports_json_mode": "Whether the model can guarantee valid JSON output", "supports_function_calling": "Whether the model supports function/tool calling", @@ -36,6 +37,7 @@ "model_name": "my-local-model", "aliases": ["shortname", "nickname", "abbrev"], "context_window": 128000, + "max_output_tokens": 32768, "supports_extended_thinking": false, "supports_json_mode": true, "supports_function_calling": true, @@ -52,6 +54,7 @@ "model_name": "anthropic/claude-opus-4", "aliases": ["opus", "claude-opus", "claude4-opus", "claude-4-opus"], "context_window": 200000, + "max_output_tokens": 64000, "supports_extended_thinking": false, "supports_json_mode": false, "supports_function_calling": false, @@ -63,6 +66,7 @@ "model_name": "anthropic/claude-sonnet-4", "aliases": ["sonnet", "claude-sonnet", "claude4-sonnet", "claude-4-sonnet", "claude"], "context_window": 200000, + "max_output_tokens": 64000, "supports_extended_thinking": false, "supports_json_mode": false, "supports_function_calling": false, @@ -74,6 +78,7 @@ "model_name": "anthropic/claude-3.5-haiku", "aliases": ["haiku", "claude-haiku", "claude3-haiku", "claude-3-haiku"], "context_window": 200000, + "max_output_tokens": 64000, "supports_extended_thinking": false, "supports_json_mode": false, "supports_function_calling": false, @@ -85,6 +90,7 @@ "model_name": "google/gemini-2.5-pro", "aliases": ["pro","gemini-pro", "gemini", "pro-openrouter"], "context_window": 1048576, + "max_output_tokens": 65536, "supports_extended_thinking": false, "supports_json_mode": true, "supports_function_calling": false, @@ -96,6 +102,7 @@ "model_name": 
"google/gemini-2.5-flash", "aliases": ["flash","gemini-flash", "flash-openrouter", "flash-2.5"], "context_window": 1048576, + "max_output_tokens": 65536, "supports_extended_thinking": false, "supports_json_mode": true, "supports_function_calling": false, @@ -107,6 +114,7 @@ "model_name": "mistralai/mistral-large-2411", "aliases": ["mistral-large", "mistral"], "context_window": 128000, + "max_output_tokens": 32000, "supports_extended_thinking": false, "supports_json_mode": true, "supports_function_calling": true, @@ -118,6 +126,7 @@ "model_name": "meta-llama/llama-3-70b", "aliases": ["llama", "llama3", "llama3-70b", "llama-70b", "llama3-openrouter"], "context_window": 8192, + "max_output_tokens": 8192, "supports_extended_thinking": false, "supports_json_mode": false, "supports_function_calling": false, @@ -129,6 +138,7 @@ "model_name": "deepseek/deepseek-r1-0528", "aliases": ["deepseek-r1", "deepseek", "r1", "deepseek-thinking"], "context_window": 65536, + "max_output_tokens": 32768, "supports_extended_thinking": true, "supports_json_mode": true, "supports_function_calling": false, @@ -140,6 +150,7 @@ "model_name": "perplexity/llama-3-sonar-large-32k-online", "aliases": ["perplexity", "sonar", "perplexity-online"], "context_window": 32768, + "max_output_tokens": 32768, "supports_extended_thinking": false, "supports_json_mode": false, "supports_function_calling": false, @@ -151,6 +162,7 @@ "model_name": "openai/o3", "aliases": ["o3"], "context_window": 200000, + "max_output_tokens": 100000, "supports_extended_thinking": false, "supports_json_mode": true, "supports_function_calling": true, @@ -164,6 +176,7 @@ "model_name": "openai/o3-mini", "aliases": ["o3-mini", "o3mini"], "context_window": 200000, + "max_output_tokens": 100000, "supports_extended_thinking": false, "supports_json_mode": true, "supports_function_calling": true, @@ -177,6 +190,7 @@ "model_name": "openai/o3-mini-high", "aliases": ["o3-mini-high", "o3mini-high"], "context_window": 200000, + 
"max_output_tokens": 100000, "supports_extended_thinking": false, "supports_json_mode": true, "supports_function_calling": true, @@ -190,6 +204,7 @@ "model_name": "openai/o3-pro", "aliases": ["o3-pro", "o3pro"], "context_window": 200000, + "max_output_tokens": 100000, "supports_extended_thinking": false, "supports_json_mode": true, "supports_function_calling": true, @@ -203,6 +218,7 @@ "model_name": "openai/o4-mini", "aliases": ["o4-mini", "o4mini"], "context_window": 200000, + "max_output_tokens": 100000, "supports_extended_thinking": false, "supports_json_mode": true, "supports_function_calling": true, @@ -212,23 +228,11 @@ "temperature_constraint": "fixed", "description": "OpenAI's o4-mini model - optimized for shorter contexts with rapid reasoning and vision" }, - { - "model_name": "openai/o4-mini-high", - "aliases": ["o4-mini-high", "o4mini-high", "o4minihigh", "o4minihi"], - "context_window": 200000, - "supports_extended_thinking": false, - "supports_json_mode": true, - "supports_function_calling": true, - "supports_images": true, - "max_image_size_mb": 20.0, - "supports_temperature": false, - "temperature_constraint": "fixed", - "description": "OpenAI's o4-mini with high reasoning effort - enhanced for complex tasks with vision" - }, { "model_name": "llama3.2", "aliases": ["local-llama", "local", "llama3.2", "ollama-llama"], "context_window": 128000, + "max_output_tokens": 64000, "supports_extended_thinking": false, "supports_json_mode": false, "supports_function_calling": false, diff --git a/docs/advanced-usage.md b/docs/advanced-usage.md index 65dc7f3..9383354 100644 --- a/docs/advanced-usage.md +++ b/docs/advanced-usage.md @@ -38,7 +38,6 @@ Regardless of your default configuration, you can specify models per request: | **`o3`** | OpenAI | 200K tokens | Strong logical reasoning | Debugging logic errors, systematic analysis | | **`o3-mini`** | OpenAI | 200K tokens | Balanced speed/quality | Moderate complexity tasks | | **`o4-mini`** | OpenAI | 200K tokens 
| Latest reasoning model | Optimized for shorter contexts | -| **`o4-mini-high`** | OpenAI | 200K tokens | Enhanced reasoning | Complex tasks requiring deeper analysis | | **`gpt4.1`** | OpenAI | 1M tokens | Latest GPT-4 with extended context | Large codebase analysis, comprehensive reviews | | **`llama`** (Llama 3.2) | Custom/Local | 128K tokens | Local inference, privacy | On-device analysis, cost-free processing | | **Any model** | OpenRouter | Varies | Access to GPT-4, Claude, Llama, etc. | User-specified or based on task requirements | @@ -69,7 +68,7 @@ OPENAI_ALLOWED_MODELS=o4-mini # High-performance: Quality over cost GOOGLE_ALLOWED_MODELS=pro -OPENAI_ALLOWED_MODELS=o3,o4-mini-high +OPENAI_ALLOWED_MODELS=o3,o4-mini ``` **Important Notes:** @@ -144,7 +143,7 @@ All tools that work with files support **both individual files and entire direct **`analyze`** - Analyze files or directories - `files`: List of file paths or directories (required) - `question`: What to analyze (required) -- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default) +- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default) - `analysis_type`: architecture|performance|security|quality|general - `output_format`: summary|detailed|actionable - `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only) @@ -159,7 +158,7 @@ All tools that work with files support **both individual files and entire direct **`codereview`** - Review code files or directories - `files`: List of file paths or directories (required) -- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default) +- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default) - `review_type`: full|security|performance|quick - `focus_on`: Specific aspects to focus on - `standards`: Coding standards to enforce @@ -175,7 +174,7 @@ All tools that work with files support **both individual files and entire direct **`debug`** - 
Debug with file context - `error_description`: Description of the issue (required) -- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default) +- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default) - `error_context`: Stack trace or logs - `files`: Files or directories related to the issue - `runtime_info`: Environment details @@ -191,7 +190,7 @@ All tools that work with files support **both individual files and entire direct **`thinkdeep`** - Extended analysis with file context - `current_analysis`: Your current thinking (required) -- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default) +- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default) - `problem_context`: Additional context - `focus_areas`: Specific aspects to focus on - `files`: Files or directories for context @@ -207,7 +206,7 @@ All tools that work with files support **both individual files and entire direct **`testgen`** - Comprehensive test generation with edge case coverage - `files`: Code files or directories to generate tests for (required) - `prompt`: Description of what to test, testing objectives, and scope (required) -- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default) +- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default) - `test_examples`: Optional existing test files as style/pattern reference - `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only) @@ -222,7 +221,7 @@ All tools that work with files support **both individual files and entire direct - `files`: Code files or directories to analyze for refactoring opportunities (required) - `prompt`: Description of refactoring goals, context, and specific areas of focus (required) - `refactor_type`: codesmells|decompose|modernize|organization (required) -- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default) +- 
`model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default) - `focus_areas`: Specific areas to focus on (e.g., 'performance', 'readability', 'maintainability', 'security') - `style_guide_examples`: Optional existing code files to use as style/pattern reference - `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only) diff --git a/docs/configuration.md b/docs/configuration.md index 8107cc4..473b6de 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -63,7 +63,7 @@ CUSTOM_MODEL_NAME=llama3.2 # Default model **Default Model Selection:** ```env -# Options: 'auto', 'pro', 'flash', 'o3', 'o3-mini', 'o4-mini', 'o4-mini-high', etc. +# Options: 'auto', 'pro', 'flash', 'o3', 'o3-mini', 'o4-mini', etc. DEFAULT_MODEL=auto # Claude picks best model for each task (recommended) ``` @@ -74,7 +74,6 @@ DEFAULT_MODEL=auto # Claude picks best model for each task (recommended) - **`o3`**: Strong logical reasoning (200K context) - **`o3-mini`**: Balanced speed/quality (200K context) - **`o4-mini`**: Latest reasoning model, optimized for shorter contexts -- **`o4-mini-high`**: Enhanced O4 with higher reasoning effort - **`grok`**: GROK-3 advanced reasoning (131K context) - **Custom models**: via OpenRouter or local APIs @@ -120,7 +119,6 @@ OPENROUTER_ALLOWED_MODELS=opus,sonnet,mistral - `o3` (200K context, high reasoning) - `o3-mini` (200K context, balanced) - `o4-mini` (200K context, latest balanced) -- `o4-mini-high` (200K context, enhanced reasoning) - `mini` (shorthand for o4-mini) **Gemini Models:** diff --git a/docs/tools/analyze.md b/docs/tools/analyze.md index 379b20d..618a0be 100644 --- a/docs/tools/analyze.md +++ b/docs/tools/analyze.md @@ -65,7 +65,7 @@ This workflow ensures methodical analysis before expert insights, resulting in d **Initial Configuration (used in step 1):** - `prompt`: What to analyze or look for (required) -- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default) +- `model`: 
auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default) - `analysis_type`: architecture|performance|security|quality|general (default: general) - `output_format`: summary|detailed|actionable (default: detailed) - `temperature`: Temperature for analysis (0-1, default 0.2) diff --git a/docs/tools/chat.md b/docs/tools/chat.md index 1c2b507..b7557eb 100644 --- a/docs/tools/chat.md +++ b/docs/tools/chat.md @@ -33,7 +33,7 @@ and then debate with the other models to give me a final verdict ## Tool Parameters - `prompt`: Your question or discussion topic (required) -- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default) +- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default) - `files`: Optional files for context (absolute paths) - `images`: Optional images for visual context (absolute paths) - `temperature`: Response creativity (0-1, default 0.5) diff --git a/docs/tools/codereview.md b/docs/tools/codereview.md index 9ba650c..9037cc2 100644 --- a/docs/tools/codereview.md +++ b/docs/tools/codereview.md @@ -80,7 +80,7 @@ The above prompt will simultaneously run two separate `codereview` tools with tw **Initial Review Configuration (used in step 1):** - `prompt`: User's summary of what the code does, expected behavior, constraints, and review objectives (required) -- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default) +- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default) - `review_type`: full|security|performance|quick (default: full) - `focus_on`: Specific aspects to focus on (e.g., "security vulnerabilities", "performance bottlenecks") - `standards`: Coding standards to enforce (e.g., "PEP8", "ESLint", "Google Style Guide") diff --git a/docs/tools/debug.md b/docs/tools/debug.md index 7efc454..6e7f20d 100644 --- a/docs/tools/debug.md +++ b/docs/tools/debug.md @@ -73,7 +73,7 @@ This structured approach ensures Claude performs methodical groundwork 
before ex - `images`: Visual debugging materials (error screenshots, logs, etc.) **Model Selection:** -- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high (default: server default) +- `model`: auto|pro|flash|o3|o3-mini|o4-mini (default: server default) - `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only) - `use_websearch`: Enable web search for documentation and solutions (default: true) - `use_assistant_model`: Whether to use expert analysis phase (default: true, set to false to use Claude only) diff --git a/docs/tools/precommit.md b/docs/tools/precommit.md index a218bd4..d70c1ab 100644 --- a/docs/tools/precommit.md +++ b/docs/tools/precommit.md @@ -135,7 +135,7 @@ Use zen and perform a thorough precommit ensuring there aren't any new regressio **Initial Configuration (used in step 1):** - `path`: Starting directory to search for repos (default: current directory, absolute path required) - `prompt`: The original user request description for the changes (required for context) -- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default) +- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default) - `compare_to`: Compare against a branch/tag instead of local changes (optional) - `severity_filter`: critical|high|medium|low|all (default: all) - `include_staged`: Include staged changes in the review (default: true) diff --git a/docs/tools/refactor.md b/docs/tools/refactor.md index 8314a4e..6407a4a 100644 --- a/docs/tools/refactor.md +++ b/docs/tools/refactor.md @@ -103,7 +103,7 @@ This results in Claude first performing its own expert analysis, encouraging it **Initial Configuration (used in step 1):** - `prompt`: Description of refactoring goals, context, and specific areas of focus (required) - `refactor_type`: codesmells|decompose|modernize|organization (default: codesmells) -- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default) +- `model`: 
auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default) - `focus_areas`: Specific areas to focus on (e.g., 'performance', 'readability', 'maintainability', 'security') - `style_guide_examples`: Optional existing code files to use as style/pattern reference (absolute paths) - `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only) diff --git a/docs/tools/secaudit.md b/docs/tools/secaudit.md index 36c4b8f..280452f 100644 --- a/docs/tools/secaudit.md +++ b/docs/tools/secaudit.md @@ -86,7 +86,7 @@ security remediation plan using planner - `images`: Architecture diagrams, security documentation, or visual references **Initial Security Configuration (used in step 1):** -- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default) +- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default) - `security_scope`: Application context, technology stack, and security boundary definition (required) - `threat_level`: low|medium|high|critical (default: medium) - determines assessment depth and urgency - `compliance_requirements`: List of compliance frameworks to assess against (e.g., ["PCI DSS", "SOC2"]) diff --git a/docs/tools/testgen.md b/docs/tools/testgen.md index e19d042..0d74a98 100644 --- a/docs/tools/testgen.md +++ b/docs/tools/testgen.md @@ -70,7 +70,7 @@ Test generation excels with extended reasoning models like Gemini Pro or O3, whi **Initial Configuration (used in step 1):** - `prompt`: Description of what to test, testing objectives, and specific scope/focus areas (required) -- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default) +- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default) - `test_examples`: Optional existing test files or directories to use as style/pattern reference (absolute paths) - `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only) - `use_assistant_model`: Whether to use expert test 
generation phase (default: true, set to false to use Claude only) diff --git a/docs/tools/thinkdeep.md b/docs/tools/thinkdeep.md index 5180a8b..26d5322 100644 --- a/docs/tools/thinkdeep.md +++ b/docs/tools/thinkdeep.md @@ -30,7 +30,7 @@ with the best architecture for my project ## Tool Parameters - `prompt`: Your current thinking/analysis to extend and validate (required) -- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default) +- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default) - `problem_context`: Additional context about the problem or goal - `focus_areas`: Specific aspects to focus on (architecture, performance, security, etc.) - `files`: Optional file paths or directories for additional context (absolute paths) diff --git a/providers/base.py b/providers/base.py index 06f60fe..aff8705 100644 --- a/providers/base.py +++ b/providers/base.py @@ -132,6 +132,7 @@ class ModelCapabilities: model_name: str friendly_name: str # Human-friendly name like "Gemini" or "OpenAI" context_window: int # Total context window size in tokens + max_output_tokens: int # Maximum output tokens per request supports_extended_thinking: bool = False supports_system_prompts: bool = True supports_streaming: bool = True diff --git a/providers/custom.py b/providers/custom.py index 021bba5..d32d494 100644 --- a/providers/custom.py +++ b/providers/custom.py @@ -158,6 +158,7 @@ class CustomProvider(OpenAICompatibleProvider): model_name=resolved_name, friendly_name=f"{self.FRIENDLY_NAME} ({resolved_name})", context_window=32_768, # Conservative default + max_output_tokens=32_768, # Conservative default max output supports_extended_thinking=False, # Most custom models don't support this supports_system_prompts=True, supports_streaming=True, @@ -187,7 +188,7 @@ class CustomProvider(OpenAICompatibleProvider): Returns: True if model is intended for custom/local endpoint """ - logging.debug(f"Custom provider validating model: 
'{model_name}'") + # logging.debug(f"Custom provider validating model: '{model_name}'") # Try to resolve through registry first config = self._registry.resolve(model_name) @@ -195,12 +196,12 @@ class CustomProvider(OpenAICompatibleProvider): model_id = config.model_name # Use explicit is_custom flag for clean validation if config.is_custom: - logging.debug(f"Model '{model_name}' -> '{model_id}' validated via registry (custom model)") + logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' validated via registry") return True else: # This is a cloud/OpenRouter model - CustomProvider should NOT handle these # Let OpenRouter provider handle them instead - logging.debug(f"Model '{model_name}' -> '{model_id}' rejected (cloud model, defer to OpenRouter)") + # logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' not custom (defer to OpenRouter)") return False # Handle version tags for unknown models (e.g., "my-model:latest") diff --git a/providers/dial.py b/providers/dial.py index f019415..e0c4a29 100644 --- a/providers/dial.py +++ b/providers/dial.py @@ -37,6 +37,7 @@ class DIALModelProvider(OpenAICompatibleProvider): model_name="o3-2025-04-16", friendly_name="DIAL (O3)", context_window=200_000, + max_output_tokens=100_000, supports_extended_thinking=False, supports_system_prompts=True, supports_streaming=True, @@ -54,6 +55,7 @@ class DIALModelProvider(OpenAICompatibleProvider): model_name="o4-mini-2025-04-16", friendly_name="DIAL (O4-mini)", context_window=200_000, + max_output_tokens=100_000, supports_extended_thinking=False, supports_system_prompts=True, supports_streaming=True, @@ -71,6 +73,7 @@ class DIALModelProvider(OpenAICompatibleProvider): model_name="anthropic.claude-sonnet-4-20250514-v1:0", friendly_name="DIAL (Sonnet 4)", context_window=200_000, + max_output_tokens=64_000, supports_extended_thinking=False, supports_system_prompts=True, supports_streaming=True, @@ -88,6 +91,7 @@ class DIALModelProvider(OpenAICompatibleProvider): 
model_name="anthropic.claude-sonnet-4-20250514-v1:0-with-thinking", friendly_name="DIAL (Sonnet 4 Thinking)", context_window=200_000, + max_output_tokens=64_000, supports_extended_thinking=True, # Thinking mode variant supports_system_prompts=True, supports_streaming=True, @@ -105,6 +109,7 @@ class DIALModelProvider(OpenAICompatibleProvider): model_name="anthropic.claude-opus-4-20250514-v1:0", friendly_name="DIAL (Opus 4)", context_window=200_000, + max_output_tokens=64_000, supports_extended_thinking=False, supports_system_prompts=True, supports_streaming=True, @@ -122,6 +127,7 @@ class DIALModelProvider(OpenAICompatibleProvider): model_name="anthropic.claude-opus-4-20250514-v1:0-with-thinking", friendly_name="DIAL (Opus 4 Thinking)", context_window=200_000, + max_output_tokens=64_000, supports_extended_thinking=True, # Thinking mode variant supports_system_prompts=True, supports_streaming=True, @@ -139,6 +145,7 @@ class DIALModelProvider(OpenAICompatibleProvider): model_name="gemini-2.5-pro-preview-03-25-google-search", friendly_name="DIAL (Gemini 2.5 Pro Search)", context_window=1_000_000, + max_output_tokens=65_536, supports_extended_thinking=False, # DIAL doesn't expose thinking mode supports_system_prompts=True, supports_streaming=True, @@ -156,6 +163,7 @@ class DIALModelProvider(OpenAICompatibleProvider): model_name="gemini-2.5-pro-preview-05-06", friendly_name="DIAL (Gemini 2.5 Pro)", context_window=1_000_000, + max_output_tokens=65_536, supports_extended_thinking=False, supports_system_prompts=True, supports_streaming=True, @@ -173,6 +181,7 @@ class DIALModelProvider(OpenAICompatibleProvider): model_name="gemini-2.5-flash-preview-05-20", friendly_name="DIAL (Gemini Flash 2.5)", context_window=1_000_000, + max_output_tokens=65_536, supports_extended_thinking=False, supports_system_prompts=True, supports_streaming=True, diff --git a/providers/gemini.py b/providers/gemini.py index 1118699..51916b0 100644 --- a/providers/gemini.py +++ b/providers/gemini.py @@ 
-24,6 +24,7 @@ class GeminiModelProvider(ModelProvider): model_name="gemini-2.0-flash", friendly_name="Gemini (Flash 2.0)", context_window=1_048_576, # 1M tokens + max_output_tokens=65_536, supports_extended_thinking=True, # Experimental thinking mode supports_system_prompts=True, supports_streaming=True, @@ -42,6 +43,7 @@ class GeminiModelProvider(ModelProvider): model_name="gemini-2.0-flash-lite", friendly_name="Gemin (Flash Lite 2.0)", context_window=1_048_576, # 1M tokens + max_output_tokens=65_536, supports_extended_thinking=False, # Not supported per user request supports_system_prompts=True, supports_streaming=True, @@ -59,6 +61,7 @@ class GeminiModelProvider(ModelProvider): model_name="gemini-2.5-flash", friendly_name="Gemini (Flash 2.5)", context_window=1_048_576, # 1M tokens + max_output_tokens=65_536, supports_extended_thinking=True, supports_system_prompts=True, supports_streaming=True, @@ -77,6 +80,7 @@ class GeminiModelProvider(ModelProvider): model_name="gemini-2.5-pro", friendly_name="Gemini (Pro 2.5)", context_window=1_048_576, # 1M tokens + max_output_tokens=65_536, supports_extended_thinking=True, supports_system_prompts=True, supports_streaming=True, diff --git a/providers/openai_compatible.py b/providers/openai_compatible.py index fec4484..17ce60d 100644 --- a/providers/openai_compatible.py +++ b/providers/openai_compatible.py @@ -687,7 +687,6 @@ class OpenAICompatibleProvider(ModelProvider): "o3-mini", "o3-pro", "o4-mini", - "o4-mini-high", # Note: Claude models would be handled by a separate provider } supports = model_name.lower() in vision_models diff --git a/providers/openai_provider.py b/providers/openai_provider.py index e065ee1..d977869 100644 --- a/providers/openai_provider.py +++ b/providers/openai_provider.py @@ -24,6 +24,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider): model_name="o3", friendly_name="OpenAI (O3)", context_window=200_000, # 200K tokens + max_output_tokens=100_000, # 100K max output tokens (matches DIAL o3 entry)
supports_extended_thinking=False, supports_system_prompts=True, supports_streaming=True, @@ -41,6 +42,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider): model_name="o3-mini", friendly_name="OpenAI (O3-mini)", context_window=200_000, # 200K tokens + max_output_tokens=100_000, # 100K max output tokens supports_extended_thinking=False, supports_system_prompts=True, supports_streaming=True, @@ -58,6 +60,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider): model_name="o3-pro-2025-06-10", friendly_name="OpenAI (O3-Pro)", context_window=200_000, # 200K tokens + max_output_tokens=100_000, # 100K max output tokens supports_extended_thinking=False, supports_system_prompts=True, supports_streaming=True, @@ -75,6 +78,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider): model_name="o4-mini", friendly_name="OpenAI (O4-mini)", context_window=200_000, # 200K tokens + max_output_tokens=100_000, # 100K max output tokens (matches DIAL o4-mini entry) supports_extended_thinking=False, supports_system_prompts=True, supports_streaming=True, @@ -85,30 +89,14 @@ supports_temperature=False, # O4 models don't accept temperature parameter temperature_constraint=create_temperature_constraint("fixed"), description="Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning", - aliases=["mini", "o4mini"], - ), - "o4-mini-high": ModelCapabilities( - provider=ProviderType.OPENAI, - model_name="o4-mini-high", - friendly_name="OpenAI (O4-mini-high)", - context_window=200_000, # 200K tokens - supports_extended_thinking=False, - supports_system_prompts=True, - supports_streaming=True, - supports_function_calling=True, - supports_json_mode=True, - supports_images=True, # O4 models support vision - max_image_size_mb=20.0, # 20MB per OpenAI docs - supports_temperature=False, # O4 models don't accept temperature parameter - temperature_constraint=create_temperature_constraint("fixed"), - description="Enhanced O4 mini (200K context) - Higher
reasoning effort for complex tasks", - aliases=["o4minihigh", "o4minihi", "mini-high"], + aliases=["mini", "o4mini", "o4-mini"], ), "gpt-4.1-2025-04-14": ModelCapabilities( provider=ProviderType.OPENAI, model_name="gpt-4.1-2025-04-14", friendly_name="OpenAI (GPT 4.1)", context_window=1_000_000, # 1M tokens + max_output_tokens=32_768, supports_extended_thinking=False, supports_system_prompts=True, supports_streaming=True, diff --git a/providers/openrouter.py b/providers/openrouter.py index 3d90238..18d3d5e 100644 --- a/providers/openrouter.py +++ b/providers/openrouter.py @@ -101,6 +101,7 @@ class OpenRouterProvider(OpenAICompatibleProvider): model_name=resolved_name, friendly_name=self.FRIENDLY_NAME, context_window=32_768, # Conservative default context window + max_output_tokens=32_768, supports_extended_thinking=False, supports_system_prompts=True, supports_streaming=True, diff --git a/providers/registry.py b/providers/registry.py index da7a9b5..4ab5732 100644 --- a/providers/registry.py +++ b/providers/registry.py @@ -24,8 +24,6 @@ class ModelProviderRegistry: cls._instance._providers = {} cls._instance._initialized_providers = {} logging.debug(f"REGISTRY: Created instance {cls._instance}") - else: - logging.debug(f"REGISTRY: Returning existing instance {cls._instance}") return cls._instance @classmethod diff --git a/providers/xai.py b/providers/xai.py index 2b6fd04..dcb14a1 100644 --- a/providers/xai.py +++ b/providers/xai.py @@ -26,6 +26,7 @@ class XAIModelProvider(OpenAICompatibleProvider): model_name="grok-3", friendly_name="X.AI (Grok 3)", context_window=131_072, # 131K tokens + max_output_tokens=131072, supports_extended_thinking=False, supports_system_prompts=True, supports_streaming=True, @@ -43,6 +44,7 @@ class XAIModelProvider(OpenAICompatibleProvider): model_name="grok-3-fast", friendly_name="X.AI (Grok 3 Fast)", context_window=131_072, # 131K tokens + max_output_tokens=131072, supports_extended_thinking=False, supports_system_prompts=True, 
supports_streaming=True, diff --git a/tests/mock_helpers.py b/tests/mock_helpers.py index eb283b6..1122af1 100644 --- a/tests/mock_helpers.py +++ b/tests/mock_helpers.py @@ -15,6 +15,7 @@ def create_mock_provider(model_name="gemini-2.5-flash", context_window=1_048_576 model_name=model_name, friendly_name="Gemini", context_window=context_window, + max_output_tokens=8192, supports_extended_thinking=False, supports_system_prompts=True, supports_streaming=True, diff --git a/tests/test_alias_target_restrictions.py b/tests/test_alias_target_restrictions.py index 7b182e6..dd36b83 100644 --- a/tests/test_alias_target_restrictions.py +++ b/tests/test_alias_target_restrictions.py @@ -211,7 +211,7 @@ class TestAliasTargetRestrictions: # Verify the polymorphic method was called mock_provider.list_all_known_models.assert_called_once() - @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini-high"}) # Restrict to specific model + @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini"}) # Restrict to specific model def test_complex_alias_chains_handled_correctly(self): """Test that complex alias chains are handled correctly in restrictions.""" # Clear cached restriction service @@ -221,12 +221,11 @@ class TestAliasTargetRestrictions: provider = OpenAIModelProvider(api_key="test-key") - # Only o4-mini-high should be allowed - assert provider.validate_model_name("o4-mini-high") + # Only o4-mini should be allowed + assert provider.validate_model_name("o4-mini") # Other models should be blocked - assert not provider.validate_model_name("o4-mini") - assert not provider.validate_model_name("mini") # This resolves to o4-mini + assert not provider.validate_model_name("o3") assert not provider.validate_model_name("o3-mini") def test_critical_regression_validation_sees_alias_targets(self): @@ -307,7 +306,7 @@ class TestAliasTargetRestrictions: it appear that target-based restrictions don't work. 
""" # Test with a made-up restriction scenario - with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini-high,o3-mini"}): + with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini,o3-mini"}): # Clear cached restriction service import utils.model_restrictions @@ -318,7 +317,7 @@ class TestAliasTargetRestrictions: # These specific target models should be recognized as valid all_known = provider.list_all_known_models() - assert "o4-mini-high" in all_known, "Target model o4-mini-high should be known" + assert "o4-mini" in all_known, "Target model o4-mini should be known" assert "o3-mini" in all_known, "Target model o3-mini should be known" # Validation should not warn about these being unrecognized @@ -329,11 +328,11 @@ class TestAliasTargetRestrictions: # Should not warn about our allowed models being unrecognized all_warnings = [str(call) for call in mock_logger.warning.call_args_list] for warning in all_warnings: - assert "o4-mini-high" not in warning or "not a recognized" not in warning + assert "o4-mini" not in warning or "not a recognized" not in warning assert "o3-mini" not in warning or "not a recognized" not in warning # The restriction should actually work - assert provider.validate_model_name("o4-mini-high") + assert provider.validate_model_name("o4-mini") assert provider.validate_model_name("o3-mini") - assert not provider.validate_model_name("o4-mini") # not in allowed list + assert not provider.validate_model_name("o3-pro") # not in allowed list assert not provider.validate_model_name("o3") # not in allowed list diff --git a/tests/test_auto_mode.py b/tests/test_auto_mode.py index 74d8ae3..f96feb3 100644 --- a/tests/test_auto_mode.py +++ b/tests/test_auto_mode.py @@ -64,7 +64,7 @@ class TestAutoMode: models_with_descriptions[model_name] = description # Check all expected models are present with meaningful descriptions - expected_models = ["flash", "pro", "o3", "o3-mini", "o3-pro", "o4-mini", "o4-mini-high"] + expected_models = ["flash", 
"pro", "o3", "o3-mini", "o3-pro", "o4-mini"] for model in expected_models: # Model should exist somewhere in the providers # Note: Some models might not be available if API keys aren't configured diff --git a/tests/test_buggy_behavior_prevention.py b/tests/test_buggy_behavior_prevention.py index e960f1f..e925e31 100644 --- a/tests/test_buggy_behavior_prevention.py +++ b/tests/test_buggy_behavior_prevention.py @@ -118,7 +118,7 @@ class TestBuggyBehaviorPrevention: provider = OpenAIModelProvider(api_key="test-key") # Simulate a scenario where admin wants to restrict specific targets - with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini-high"}): + with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini"}): # Clear cached restriction service import utils.model_restrictions @@ -126,19 +126,21 @@ class TestBuggyBehaviorPrevention: # These should work because they're explicitly allowed assert provider.validate_model_name("o3-mini") - assert provider.validate_model_name("o4-mini-high") + assert provider.validate_model_name("o4-mini") # These should be blocked - assert not provider.validate_model_name("o4-mini") # Not in allowed list + assert not provider.validate_model_name("o3-pro") # Not in allowed list assert not provider.validate_model_name("o3") # Not in allowed list - assert not provider.validate_model_name("mini") # Resolves to o4-mini, not allowed + + # This should be ALLOWED because it resolves to o4-mini which is in the allowed list + assert provider.validate_model_name("mini") # Resolves to o4-mini, which IS allowed # Verify our list_all_known_models includes the restricted models all_known = provider.list_all_known_models() assert "o3-mini" in all_known # Should be known (and allowed) - assert "o4-mini-high" in all_known # Should be known (and allowed) - assert "o4-mini" in all_known # Should be known (but blocked) - assert "mini" in all_known # Should be known (but blocked) + assert "o4-mini" in all_known # Should be known 
(and allowed) + assert "o3-pro" in all_known # Should be known (but blocked) + assert "mini" in all_known # Should be known (and allowed since it resolves to o4-mini) def test_demonstration_of_old_vs_new_interface(self): """ diff --git a/tests/test_model_enumeration.py b/tests/test_model_enumeration.py index 548f785..0a78b17 100644 --- a/tests/test_model_enumeration.py +++ b/tests/test_model_enumeration.py @@ -149,7 +149,7 @@ class TestModelEnumeration: ("o3", False), # OpenAI - not available without API key ("grok", False), # X.AI - not available without API key ("gemini-2.5-flash", False), # Full Gemini name - not available without API key - ("o4-mini-high", False), # OpenAI variant - not available without API key + ("o4-mini", False), # OpenAI variant - not available without API key ("grok-3-fast", False), # X.AI variant - not available without API key ], ) diff --git a/tests/test_model_restrictions.py b/tests/test_model_restrictions.py index bd34a81..6a93bd5 100644 --- a/tests/test_model_restrictions.py +++ b/tests/test_model_restrictions.py @@ -93,7 +93,7 @@ class TestModelRestrictionService: with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini"}): service = ModelRestrictionService() - models = ["o3", "o3-mini", "o4-mini", "o4-mini-high"] + models = ["o3", "o3-mini", "o4-mini", "o3-pro"] filtered = service.filter_models(ProviderType.OPENAI, models) assert filtered == ["o3-mini", "o4-mini"] @@ -573,7 +573,7 @@ class TestShorthandRestrictions: # Other models should not work assert not openai_provider.validate_model_name("o3") - assert not openai_provider.validate_model_name("o4-mini-high") + assert not openai_provider.validate_model_name("o3-pro") @patch.dict( os.environ, diff --git a/tests/test_o3_temperature_fix_simple.py b/tests/test_o3_temperature_fix_simple.py index da0ea60..0a27256 100644 --- a/tests/test_o3_temperature_fix_simple.py +++ b/tests/test_o3_temperature_fix_simple.py @@ -185,7 +185,7 @@ class 
TestO3TemperatureParameterFixSimple: provider = OpenAIModelProvider(api_key="test-key") # Test O3/O4 models that should NOT support temperature parameter - o3_o4_models = ["o3", "o3-mini", "o3-pro", "o4-mini", "o4-mini-high"] + o3_o4_models = ["o3", "o3-mini", "o3-pro", "o4-mini"] for model in o3_o4_models: capabilities = provider.get_capabilities(model) diff --git a/tests/test_openai_provider.py b/tests/test_openai_provider.py index baab182..3429be9 100644 --- a/tests/test_openai_provider.py +++ b/tests/test_openai_provider.py @@ -47,14 +47,13 @@ class TestOpenAIProvider: assert provider.validate_model_name("o3-mini") is True assert provider.validate_model_name("o3-pro") is True assert provider.validate_model_name("o4-mini") is True - assert provider.validate_model_name("o4-mini-high") is True + assert provider.validate_model_name("o4-mini") is True # Test valid aliases assert provider.validate_model_name("mini") is True assert provider.validate_model_name("o3mini") is True assert provider.validate_model_name("o4mini") is True - assert provider.validate_model_name("o4minihigh") is True - assert provider.validate_model_name("o4minihi") is True + assert provider.validate_model_name("o4mini") is True # Test invalid model assert provider.validate_model_name("invalid-model") is False @@ -69,15 +68,14 @@ class TestOpenAIProvider: assert provider._resolve_model_name("mini") == "o4-mini" assert provider._resolve_model_name("o3mini") == "o3-mini" assert provider._resolve_model_name("o4mini") == "o4-mini" - assert provider._resolve_model_name("o4minihigh") == "o4-mini-high" - assert provider._resolve_model_name("o4minihi") == "o4-mini-high" + assert provider._resolve_model_name("o4mini") == "o4-mini" # Test full name passthrough assert provider._resolve_model_name("o3") == "o3" assert provider._resolve_model_name("o3-mini") == "o3-mini" assert provider._resolve_model_name("o3-pro") == "o3-pro-2025-06-10" assert provider._resolve_model_name("o4-mini") == "o4-mini" - assert 
provider._resolve_model_name("o4-mini-high") == "o4-mini-high" + assert provider._resolve_model_name("o4-mini") == "o4-mini" def test_get_capabilities_o3(self): """Test getting model capabilities for O3.""" @@ -184,11 +182,11 @@ class TestOpenAIProvider: call_kwargs = mock_client.chat.completions.create.call_args[1] assert call_kwargs["model"] == "o3-mini" - # Test o4minihigh -> o4-mini-high - mock_response.model = "o4-mini-high" - provider.generate_content(prompt="Test", model_name="o4minihigh", temperature=1.0) + # Test o4mini -> o4-mini + mock_response.model = "o4-mini" + provider.generate_content(prompt="Test", model_name="o4mini", temperature=1.0) call_kwargs = mock_client.chat.completions.create.call_args[1] - assert call_kwargs["model"] == "o4-mini-high" + assert call_kwargs["model"] == "o4-mini" @patch("providers.openai_compatible.OpenAI") def test_generate_content_no_alias_passthrough(self, mock_openai_class): diff --git a/tests/test_openrouter_provider.py b/tests/test_openrouter_provider.py index 6d427ba..454f372 100644 --- a/tests/test_openrouter_provider.py +++ b/tests/test_openrouter_provider.py @@ -77,7 +77,7 @@ class TestOpenRouterProvider: assert provider._resolve_model_name("o3-mini") == "openai/o3-mini" assert provider._resolve_model_name("o3mini") == "openai/o3-mini" assert provider._resolve_model_name("o4-mini") == "openai/o4-mini" - assert provider._resolve_model_name("o4-mini-high") == "openai/o4-mini-high" + assert provider._resolve_model_name("o4-mini") == "openai/o4-mini" assert provider._resolve_model_name("claude") == "anthropic/claude-sonnet-4" assert provider._resolve_model_name("mistral") == "mistralai/mistral-large-2411" assert provider._resolve_model_name("deepseek") == "deepseek/deepseek-r1-0528" diff --git a/tests/test_openrouter_registry.py b/tests/test_openrouter_registry.py index f6ea000..6387ebe 100644 --- a/tests/test_openrouter_registry.py +++ b/tests/test_openrouter_registry.py @@ -24,7 +24,16 @@ class 
TestOpenRouterModelRegistry: def test_custom_config_path(self): """Test registry with custom config path.""" # Create temporary config - config_data = {"models": [{"model_name": "test/model-1", "aliases": ["test1", "t1"], "context_window": 4096}]} + config_data = { + "models": [ + { + "model_name": "test/model-1", + "aliases": ["test1", "t1"], + "context_window": 4096, + "max_output_tokens": 2048, + } + ] + } with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: json.dump(config_data, f) @@ -42,7 +51,11 @@ class TestOpenRouterModelRegistry: def test_environment_variable_override(self): """Test OPENROUTER_MODELS_PATH environment variable.""" # Create custom config - config_data = {"models": [{"model_name": "env/model", "aliases": ["envtest"], "context_window": 8192}]} + config_data = { + "models": [ + {"model_name": "env/model", "aliases": ["envtest"], "context_window": 8192, "max_output_tokens": 4096} + ] + } with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: json.dump(config_data, f) @@ -127,11 +140,12 @@ class TestOpenRouterModelRegistry: """Test that duplicate aliases are detected.""" config_data = { "models": [ - {"model_name": "test/model-1", "aliases": ["dupe"], "context_window": 4096}, + {"model_name": "test/model-1", "aliases": ["dupe"], "context_window": 4096, "max_output_tokens": 2048}, { "model_name": "test/model-2", "aliases": ["DUPE"], # Same alias, different case "context_window": 8192, + "max_output_tokens": 2048, }, ] } @@ -207,6 +221,7 @@ class TestOpenRouterModelRegistry: friendly_name="OpenRouter (test/full-featured)", aliases=["full"], context_window=128000, + max_output_tokens=8192, supports_extended_thinking=True, supports_system_prompts=True, supports_streaming=True, diff --git a/tests/test_providers.py b/tests/test_providers.py index 5401bc9..036ae9b 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -215,9 +215,7 @@ class TestOpenAIProvider: assert 
provider.validate_model_name("o3-mini") # Backwards compatibility assert provider.validate_model_name("o4-mini") assert provider.validate_model_name("o4mini") - assert provider.validate_model_name("o4-mini-high") - assert provider.validate_model_name("o4minihigh") - assert provider.validate_model_name("o4minihi") + assert provider.validate_model_name("o4-mini") assert not provider.validate_model_name("gpt-4o") assert not provider.validate_model_name("invalid-model") @@ -229,4 +227,4 @@ class TestOpenAIProvider: assert not provider.supports_thinking_mode("o3mini") assert not provider.supports_thinking_mode("o3-mini") assert not provider.supports_thinking_mode("o4-mini") - assert not provider.supports_thinking_mode("o4-mini-high") + assert not provider.supports_thinking_mode("o4-mini") diff --git a/tests/test_supported_models_aliases.py b/tests/test_supported_models_aliases.py index 6ed899f..1eb76b5 100644 --- a/tests/test_supported_models_aliases.py +++ b/tests/test_supported_models_aliases.py @@ -51,15 +51,14 @@ class TestSupportedModelsAliases: assert "o4mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases assert "o3mini" in provider.SUPPORTED_MODELS["o3-mini"].aliases assert "o3-pro" in provider.SUPPORTED_MODELS["o3-pro-2025-06-10"].aliases - assert "o4minihigh" in provider.SUPPORTED_MODELS["o4-mini-high"].aliases - assert "o4minihi" in provider.SUPPORTED_MODELS["o4-mini-high"].aliases + assert "o4mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases assert "gpt4.1" in provider.SUPPORTED_MODELS["gpt-4.1-2025-04-14"].aliases # Test alias resolution assert provider._resolve_model_name("mini") == "o4-mini" assert provider._resolve_model_name("o3mini") == "o3-mini" assert provider._resolve_model_name("o3-pro") == "o3-pro-2025-06-10" - assert provider._resolve_model_name("o4minihigh") == "o4-mini-high" + assert provider._resolve_model_name("o4mini") == "o4-mini" assert provider._resolve_model_name("gpt4.1") == "gpt-4.1-2025-04-14" # Test case insensitive resolution 
From 3b250c95dfcf981d04b336a7c4852c2515ce44e2 Mon Sep 17 00:00:00 2001 From: Fahad Date: Mon, 23 Jun 2025 18:56:47 +0400 Subject: [PATCH 7/8] Quick test mode for simulation tests Fixed o4-mini name, OpenAI removed o4-mini-high Add max_output_tokens property to ModelCapabilities Fixed tests after refactor --- simulator_tests/conversation_base_test.py | 4 ++ .../test_conversation_chain_validation.py | 17 ++++++--- .../test_token_allocation_validation.py | 38 +++++++++++-------- tests/test_conversation_memory.py | 8 ++-- tests/test_image_support_integration.py | 6 +-- tests/test_model_metadata_continuation.py | 8 ++-- 6 files changed, 49 insertions(+), 32 deletions(-) diff --git a/simulator_tests/conversation_base_test.py b/simulator_tests/conversation_base_test.py index 4502af2..f66df25 100644 --- a/simulator_tests/conversation_base_test.py +++ b/simulator_tests/conversation_base_test.py @@ -182,6 +182,10 @@ class ConversationBaseTest(BaseSimulatorTest): # Look for continuation_id in various places if isinstance(response_data, dict): + # Check top-level continuation_id (workflow tools) + if "continuation_id" in response_data: + return response_data["continuation_id"] + # Check metadata metadata = response_data.get("metadata", {}) if "thread_id" in metadata: diff --git a/simulator_tests/test_conversation_chain_validation.py b/simulator_tests/test_conversation_chain_validation.py index b033cab..03313d3 100644 --- a/simulator_tests/test_conversation_chain_validation.py +++ b/simulator_tests/test_conversation_chain_validation.py @@ -91,11 +91,14 @@ class TestClass: response_a2, continuation_id_a2 = self.call_mcp_tool( "analyze", { - "prompt": "Now analyze the code quality and suggest improvements.", - "files": [test_file_path], + "step": "Now analyze the code quality and suggest improvements.", + "step_number": 1, + "total_steps": 2, + "next_step_required": False, + "findings": "Continuing analysis from previous chat conversation to analyze code quality.", + 
"relevant_files": [test_file_path], "continuation_id": continuation_id_a1, "model": "flash", - "temperature": 0.7, }, ) @@ -154,10 +157,14 @@ class TestClass: response_b2, continuation_id_b2 = self.call_mcp_tool( "analyze", { - "prompt": "Analyze the previous greeting and suggest improvements.", + "step": "Analyze the previous greeting and suggest improvements.", + "step_number": 1, + "total_steps": 1, + "next_step_required": False, + "findings": "Analyzing the greeting from previous conversation and suggesting improvements.", + "relevant_files": [test_file_path], "continuation_id": continuation_id_b1, "model": "flash", - "temperature": 0.7, }, ) diff --git a/simulator_tests/test_token_allocation_validation.py b/simulator_tests/test_token_allocation_validation.py index 4a7ef8e..64c2208 100644 --- a/simulator_tests/test_token_allocation_validation.py +++ b/simulator_tests/test_token_allocation_validation.py @@ -206,11 +206,14 @@ if __name__ == "__main__": response2, continuation_id2 = self.call_mcp_tool( "analyze", { - "prompt": "Analyze the performance implications of these recursive functions.", - "files": [file1_path], + "step": "Analyze the performance implications of these recursive functions.", + "step_number": 1, + "total_steps": 1, + "next_step_required": False, + "findings": "Continuing from chat conversation to analyze performance implications of recursive functions.", + "relevant_files": [file1_path], "continuation_id": continuation_id1, # Continue the chat conversation "model": "flash", - "temperature": 0.7, }, ) @@ -221,10 +224,14 @@ if __name__ == "__main__": self.logger.info(f" ✅ Step 2 completed with continuation_id: {continuation_id2[:8]}...") continuation_ids.append(continuation_id2) - # Validate that we got a different continuation ID - if continuation_id2 == continuation_id1: - self.logger.error(" ❌ Step 2: Got same continuation ID as Step 1 - continuation not working") - return False + # Validate continuation ID behavior for workflow tools + # 
Workflow tools reuse the same continuation_id when continuing within a workflow session + # This is expected behavior and different from simple tools + if continuation_id2 != continuation_id1: + self.logger.info(" ✅ Step 2: Got new continuation ID (workflow behavior)") + else: + self.logger.info(" ✅ Step 2: Reused continuation ID (workflow session continuation)") + # Both behaviors are valid - what matters is that we got a continuation_id # Validate that Step 2 is building on Step 1's conversation # Check if the response references the previous conversation @@ -276,17 +283,16 @@ if __name__ == "__main__": all_have_continuation_ids = bool(continuation_id1 and continuation_id2 and continuation_id3) criteria.append(("All steps generated continuation IDs", all_have_continuation_ids)) - # 3. Each continuation ID is unique - unique_continuation_ids = len(set(continuation_ids)) == len(continuation_ids) - criteria.append(("Each response generated unique continuation ID", unique_continuation_ids)) + # 3. Continuation behavior validation (handles both simple and workflow tools) + # Simple tools create new IDs each time, workflow tools may reuse IDs within sessions + has_valid_continuation_pattern = len(continuation_ids) == 3 + criteria.append(("Valid continuation ID pattern", has_valid_continuation_pattern)) - # 4. Continuation IDs follow the expected pattern - step_ids_different = ( - len(continuation_ids) == 3 - and continuation_ids[0] != continuation_ids[1] - and continuation_ids[1] != continuation_ids[2] + # 4. Check for conversation continuity (more important than ID uniqueness) + conversation_has_continuity = len(continuation_ids) == 3 and all( + cid is not None for cid in continuation_ids ) - criteria.append(("All continuation IDs are different", step_ids_different)) + criteria.append(("Conversation continuity maintained", conversation_has_continuity)) # 5. 
Check responses build on each other (content validation) step1_has_function_analysis = "fibonacci" in response1.lower() or "factorial" in response1.lower() diff --git a/tests/test_conversation_memory.py b/tests/test_conversation_memory.py index 86a5f42..b6491e6 100644 --- a/tests/test_conversation_memory.py +++ b/tests/test_conversation_memory.py @@ -506,17 +506,17 @@ class TestConversationFlow: mock_client = Mock() mock_storage.return_value = mock_client - # Start conversation with files - thread_id = create_thread("analyze", {"prompt": "Analyze this codebase", "relevant_files": ["/project/src/"]}) + # Start conversation with files using a simple tool + thread_id = create_thread("chat", {"prompt": "Analyze this codebase", "files": ["/project/src/"]}) # Turn 1: Claude provides context with multiple files initial_context = ThreadContext( thread_id=thread_id, created_at="2023-01-01T00:00:00Z", last_updated_at="2023-01-01T00:00:00Z", - tool_name="analyze", + tool_name="chat", turns=[], - initial_context={"prompt": "Analyze this codebase", "relevant_files": ["/project/src/"]}, + initial_context={"prompt": "Analyze this codebase", "files": ["/project/src/"]}, ) mock_client.get.return_value = initial_context.model_dump_json() diff --git a/tests/test_image_support_integration.py b/tests/test_image_support_integration.py index daa062b..855c30e 100644 --- a/tests/test_image_support_integration.py +++ b/tests/test_image_support_integration.py @@ -483,14 +483,14 @@ class TestImageSupportIntegration: tool_name="chat", ) - # Create child thread linked to parent - child_thread_id = create_thread("debug", {"child": "context"}, parent_thread_id=parent_thread_id) + # Create child thread linked to parent using a simple tool + child_thread_id = create_thread("chat", {"prompt": "child context"}, parent_thread_id=parent_thread_id) add_turn( thread_id=child_thread_id, role="user", content="Child thread with more images", images=["child1.png", "shared.png"], # shared.png appears again 
(should prioritize newer) - tool_name="debug", + tool_name="chat", ) # Mock child thread context for get_thread call diff --git a/tests/test_model_metadata_continuation.py b/tests/test_model_metadata_continuation.py index 224aabf..5065804 100644 --- a/tests/test_model_metadata_continuation.py +++ b/tests/test_model_metadata_continuation.py @@ -89,7 +89,7 @@ class TestModelMetadataContinuation: @pytest.mark.asyncio async def test_multiple_turns_uses_last_assistant_model(self): """Test that with multiple turns, the last assistant turn's model is used.""" - thread_id = create_thread("analyze", {"prompt": "analyze this"}) + thread_id = create_thread("chat", {"prompt": "analyze this"}) # Add multiple turns with different models add_turn(thread_id, "assistant", "First response", model_name="gemini-2.5-flash", model_provider="google") @@ -185,11 +185,11 @@ class TestModelMetadataContinuation: async def test_thread_chain_model_preservation(self): """Test model preservation across thread chains (parent-child relationships).""" # Create parent thread - parent_id = create_thread("analyze", {"prompt": "analyze"}) + parent_id = create_thread("chat", {"prompt": "analyze"}) add_turn(parent_id, "assistant", "Analysis", model_name="gemini-2.5-pro", model_provider="google") - # Create child thread - child_id = create_thread("codereview", {"prompt": "review"}, parent_thread_id=parent_id) + # Create child thread using a simple tool instead of workflow tool + child_id = create_thread("chat", {"prompt": "review"}, parent_thread_id=parent_id) # Child thread should be able to access parent's model through chain traversal # NOTE: Current implementation only checks current thread (not parent threads) From a355b80afc6fbf9aef3e5e5706dee6c4b3cd9728 Mon Sep 17 00:00:00 2001 From: Illya Havsiyevych <44289086+illya-havsiyevych@users.noreply.github.com> Date: Mon, 23 Jun 2025 18:07:40 +0300 Subject: [PATCH 8/8] feat: Add DISABLED_TOOLS environment variable for selective tool disabling (#127) ## 
Description This PR adds support for selectively disabling tools via the DISABLED_TOOLS environment variable, allowing users to customize which MCP tools are available in their Zen server instance. This feature enables better control over tool availability for security, performance, or organizational requirements. ## Changes Made - [x] Added `DISABLED_TOOLS` environment variable support to selectively disable tools - [x] Implemented tool filtering logic with protection for essential tools (version, listmodels) - [x] Added comprehensive validation with warnings for unknown tools and attempts to disable essential tools - [x] Updated `.env.example` with DISABLED_TOOLS documentation and examples - [x] Added comprehensive test suite (16 tests) covering all edge cases - [x] No breaking changes - feature is opt-in with default behavior unchanged ## Configuration Add to `.env` file: ```bash # Optional: Tool Selection # Comma-separated list of tools to disable. If not set, all tools are enabled. # Essential tools (version, listmodels) cannot be disabled. # Available tools: chat, thinkdeep, planner, consensus, codereview, precommit, # debug, docgen, analyze, refactor, tracer, testgen # Examples: # DISABLED_TOOLS= # All tools enabled (default) # DISABLED_TOOLS=debug,tracer # Disable debug and tracer tools # DISABLED_TOOLS=planner,consensus # Disable planning tools --- .env.example | 10 ++ server.py | 92 ++++++++++++ tests/test_auto_mode_custom_provider_only.py | 6 +- tests/test_disabled_tools.py | 140 +++++++++++++++++++ 4 files changed, 245 insertions(+), 3 deletions(-) create mode 100644 tests/test_disabled_tools.py diff --git a/.env.example b/.env.example index 1d88d4c..b88bd70 100644 --- a/.env.example +++ b/.env.example @@ -143,3 +143,13 @@ MAX_CONVERSATION_TURNS=20 # ERROR: Shows only errors LOG_LEVEL=DEBUG +# Optional: Tool Selection +# Comma-separated list of tools to disable. If not set, all tools are enabled. 
+# Essential tools (version, listmodels) cannot be disabled. +# Available tools: chat, thinkdeep, planner, consensus, codereview, precommit, +# debug, docgen, analyze, refactor, tracer, testgen +# Examples: +# DISABLED_TOOLS= # All tools enabled (default) +# DISABLED_TOOLS=debug,tracer # Disable debug and tracer tools +# DISABLED_TOOLS=planner,consensus # Disable planning tools + diff --git a/server.py b/server.py index 9247aa6..ebb5ce2 100644 --- a/server.py +++ b/server.py @@ -158,6 +158,97 @@ logger = logging.getLogger(__name__) # This name is used by MCP clients to identify and connect to this specific server server: Server = Server("zen-server") + +# Constants for tool filtering +ESSENTIAL_TOOLS = {"version", "listmodels"} + + +def parse_disabled_tools_env() -> set[str]: + """ + Parse the DISABLED_TOOLS environment variable into a set of tool names. + + Returns: + Set of lowercase tool names to disable, empty set if none specified + """ + disabled_tools_env = os.getenv("DISABLED_TOOLS", "").strip() + if not disabled_tools_env: + return set() + return {t.strip().lower() for t in disabled_tools_env.split(",") if t.strip()} + + +def validate_disabled_tools(disabled_tools: set[str], all_tools: dict[str, Any]) -> None: + """ + Validate the disabled tools list and log appropriate warnings. + + Args: + disabled_tools: Set of tool names requested to be disabled + all_tools: Dictionary of all available tool instances + """ + essential_disabled = disabled_tools & ESSENTIAL_TOOLS + if essential_disabled: + logger.warning(f"Cannot disable essential tools: {sorted(essential_disabled)}") + unknown_tools = disabled_tools - set(all_tools.keys()) + if unknown_tools: + logger.warning(f"Unknown tools in DISABLED_TOOLS: {sorted(unknown_tools)}") + + +def apply_tool_filter(all_tools: dict[str, Any], disabled_tools: set[str]) -> dict[str, Any]: + """ + Apply the disabled tools filter to create the final tools dictionary. 
+ + Args: + all_tools: Dictionary of all available tool instances + disabled_tools: Set of tool names to disable + + Returns: + Dictionary containing only enabled tools + """ + enabled_tools = {} + for tool_name, tool_instance in all_tools.items(): + if tool_name in ESSENTIAL_TOOLS or tool_name not in disabled_tools: + enabled_tools[tool_name] = tool_instance + else: + logger.debug(f"Tool '{tool_name}' disabled via DISABLED_TOOLS") + return enabled_tools + + +def log_tool_configuration(disabled_tools: set[str], enabled_tools: dict[str, Any]) -> None: + """ + Log the final tool configuration for visibility. + + Args: + disabled_tools: Set of tool names that were requested to be disabled + enabled_tools: Dictionary of tools that remain enabled + """ + if not disabled_tools: + logger.info("All tools enabled (DISABLED_TOOLS not set)") + return + actual_disabled = disabled_tools - ESSENTIAL_TOOLS + if actual_disabled: + logger.debug(f"Disabled tools: {sorted(actual_disabled)}") + logger.info(f"Active tools: {sorted(enabled_tools.keys())}") + + +def filter_disabled_tools(all_tools: dict[str, Any]) -> dict[str, Any]: + """ + Filter tools based on DISABLED_TOOLS environment variable. 
+ + Args: + all_tools: Dictionary of all available tool instances + + Returns: + dict: Filtered dictionary containing only enabled tools + """ + disabled_tools = parse_disabled_tools_env() + if not disabled_tools: + log_tool_configuration(disabled_tools, all_tools) + return all_tools + validate_disabled_tools(disabled_tools, all_tools) + enabled_tools = apply_tool_filter(all_tools, disabled_tools) + log_tool_configuration(disabled_tools, enabled_tools) + return enabled_tools + + # Initialize the tool registry with all available AI-powered tools # Each tool provides specialized functionality for different development tasks # Tools are instantiated once and reused across requests (stateless design) @@ -178,6 +269,7 @@ TOOLS = { "listmodels": ListModelsTool(), # List all available AI models by provider "version": VersionTool(), # Display server version and system information } +TOOLS = filter_disabled_tools(TOOLS) # Rich prompt templates for all tools PROMPT_TEMPLATES = { diff --git a/tests/test_auto_mode_custom_provider_only.py b/tests/test_auto_mode_custom_provider_only.py index 5d03d4e..c97e649 100644 --- a/tests/test_auto_mode_custom_provider_only.py +++ b/tests/test_auto_mode_custom_provider_only.py @@ -70,7 +70,7 @@ class TestAutoModeCustomProviderOnly: } # Clear all other provider keys - clear_keys = ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"] + clear_keys = ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"] with patch.dict(os.environ, test_env, clear=False): # Ensure other provider keys are not set @@ -109,7 +109,7 @@ class TestAutoModeCustomProviderOnly: with patch.dict(os.environ, test_env, clear=False): # Clear other provider keys - for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]: + for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]: if key in os.environ: del os.environ[key] @@ -177,7 +177,7 @@ class 
TestAutoModeCustomProviderOnly: with patch.dict(os.environ, test_env, clear=False): # Clear other provider keys - for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]: + for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]: if key in os.environ: del os.environ[key] diff --git a/tests/test_disabled_tools.py b/tests/test_disabled_tools.py new file mode 100644 index 0000000..65a525f --- /dev/null +++ b/tests/test_disabled_tools.py @@ -0,0 +1,140 @@ +"""Tests for DISABLED_TOOLS environment variable functionality.""" + +import logging +import os +from unittest.mock import patch + +import pytest + +from server import ( + apply_tool_filter, + parse_disabled_tools_env, + validate_disabled_tools, +) + + +# Mock the tool classes since we're testing the filtering logic +class MockTool: + def __init__(self, name): + self.name = name + + +class TestDisabledTools: + """Test suite for DISABLED_TOOLS functionality.""" + + def test_parse_disabled_tools_empty(self): + """Empty string returns empty set (no tools disabled).""" + with patch.dict(os.environ, {"DISABLED_TOOLS": ""}): + assert parse_disabled_tools_env() == set() + + def test_parse_disabled_tools_not_set(self): + """Unset variable returns empty set.""" + with patch.dict(os.environ, {}, clear=True): + # Ensure DISABLED_TOOLS is not in environment + if "DISABLED_TOOLS" in os.environ: + del os.environ["DISABLED_TOOLS"] + assert parse_disabled_tools_env() == set() + + def test_parse_disabled_tools_single(self): + """Single tool name parsed correctly.""" + with patch.dict(os.environ, {"DISABLED_TOOLS": "debug"}): + assert parse_disabled_tools_env() == {"debug"} + + def test_parse_disabled_tools_multiple(self): + """Multiple tools with spaces parsed correctly.""" + with patch.dict(os.environ, {"DISABLED_TOOLS": "debug, analyze, refactor"}): + assert parse_disabled_tools_env() == {"debug", "analyze", "refactor"} + + def 
test_parse_disabled_tools_extra_spaces(self): + """Extra spaces and empty items handled correctly.""" + with patch.dict(os.environ, {"DISABLED_TOOLS": " debug , , analyze , "}): + assert parse_disabled_tools_env() == {"debug", "analyze"} + + def test_parse_disabled_tools_duplicates(self): + """Duplicate entries handled correctly (set removes duplicates).""" + with patch.dict(os.environ, {"DISABLED_TOOLS": "debug,analyze,debug"}): + assert parse_disabled_tools_env() == {"debug", "analyze"} + + def test_tool_filtering_logic(self): + """Test the complete filtering logic using the actual server functions.""" + # Simulate ALL_TOOLS + ALL_TOOLS = { + "chat": MockTool("chat"), + "debug": MockTool("debug"), + "analyze": MockTool("analyze"), + "version": MockTool("version"), + "listmodels": MockTool("listmodels"), + } + + # Test case 1: No tools disabled + disabled_tools = set() + enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools) + + assert len(enabled_tools) == 5 # All tools included + assert set(enabled_tools.keys()) == set(ALL_TOOLS.keys()) + + # Test case 2: Disable some regular tools + disabled_tools = {"debug", "analyze"} + enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools) + + assert len(enabled_tools) == 3 # chat, version, listmodels + assert "debug" not in enabled_tools + assert "analyze" not in enabled_tools + assert "chat" in enabled_tools + assert "version" in enabled_tools + assert "listmodels" in enabled_tools + + # Test case 3: Attempt to disable essential tools + disabled_tools = {"version", "chat"} + enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools) + + assert "version" in enabled_tools # Essential tool not disabled + assert "chat" not in enabled_tools # Regular tool disabled + assert "listmodels" in enabled_tools # Essential tool included + + def test_unknown_tools_warning(self, caplog): + """Test that unknown tool names generate appropriate warnings.""" + ALL_TOOLS = { + "chat": MockTool("chat"), + "debug": 
MockTool("debug"), + "analyze": MockTool("analyze"), + "version": MockTool("version"), + "listmodels": MockTool("listmodels"), + } + disabled_tools = {"chat", "unknown_tool", "another_unknown"} + + with caplog.at_level(logging.WARNING): + validate_disabled_tools(disabled_tools, ALL_TOOLS) + assert "Unknown tools in DISABLED_TOOLS: ['another_unknown', 'unknown_tool']" in caplog.text + + def test_essential_tools_warning(self, caplog): + """Test warning when trying to disable essential tools.""" + ALL_TOOLS = { + "chat": MockTool("chat"), + "debug": MockTool("debug"), + "analyze": MockTool("analyze"), + "version": MockTool("version"), + "listmodels": MockTool("listmodels"), + } + disabled_tools = {"version", "chat", "debug"} + + with caplog.at_level(logging.WARNING): + validate_disabled_tools(disabled_tools, ALL_TOOLS) + assert "Cannot disable essential tools: ['version']" in caplog.text + + @pytest.mark.parametrize( + "env_value,expected", + [ + ("", set()), # Empty string + (" ", set()), # Only spaces + (",,,", set()), # Only commas + ("chat", {"chat"}), # Single tool + ("chat,debug", {"chat", "debug"}), # Multiple tools + ("chat, debug, analyze", {"chat", "debug", "analyze"}), # With spaces + ("chat,debug,chat", {"chat", "debug"}), # Duplicates + ], + ) + def test_parse_disabled_tools_parametrized(self, env_value, expected): + """Parametrized tests for various input formats.""" + with patch.dict(os.environ, {"DISABLED_TOOLS": env_value}): + assert parse_disabled_tools_env() == expected