From d285fadf4cc1ab0189758a99fb24578fb6cc2c97 Mon Sep 17 00:00:00 2001 From: Fahad Date: Thu, 2 Oct 2025 13:49:23 +0400 Subject: [PATCH] fix: custom provider must only accept a model if it's declared explicitly. Upon model rejection (in auto mode) the list of available models is returned up-front to help with selection. --- providers/custom.py | 90 ++++--------------------------- providers/openrouter.py | 40 ++++++++------ tests/test_custom_provider.py | 29 +++++----- tests/test_openrouter_provider.py | 25 +++++---- tools/shared/base_tool.py | 65 +++++++++++++--------- utils/model_context.py | 13 ++++- 6 files changed, 116 insertions(+), 146 deletions(-) diff --git a/providers/custom.py b/providers/custom.py index 63e6f8e..4f7eb50 100644 --- a/providers/custom.py +++ b/providers/custom.py @@ -6,7 +6,7 @@ from typing import Optional from .openai_compatible import OpenAICompatibleProvider from .openrouter_registry import OpenRouterModelRegistry -from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint +from .shared import ModelCapabilities, ModelResponse, ProviderType class CustomProvider(OpenAICompatibleProvider): @@ -91,51 +91,22 @@ class CustomProvider(OpenAICompatibleProvider): canonical_name: str, requested_name: Optional[str] = None, ) -> Optional[ModelCapabilities]: - """Return custom capabilities from the registry or generic defaults.""" + """Return capabilities for models explicitly marked as custom.""" builtin = super()._lookup_capabilities(canonical_name, requested_name) if builtin is not None: return builtin - capabilities = self._registry.get_capabilities(canonical_name) - if capabilities: - config = self._registry.resolve(canonical_name) - if config and getattr(config, "is_custom", False): - capabilities.provider = ProviderType.CUSTOM - return capabilities - # Non-custom models should fall through so OpenRouter handles them - return None + registry_entry = self._registry.resolve(canonical_name) + if registry_entry and getattr(registry_entry, "is_custom", False): + registry_entry.provider = ProviderType.CUSTOM + return registry_entry logging.debug( - f"Using generic capabilities for '{canonical_name}' via Custom API. " - "Consider adding to custom_models.json for specific capabilities." + "Custom provider cannot resolve model '%s'; ensure it is declared with 'is_custom': true in custom_models.json", + canonical_name, ) - - supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings( - canonical_name - ) - - logging.warning( - f"Model '{canonical_name}' not found in custom_models.json. Using generic capabilities with inferred settings. " - f"Temperature support: {supports_temperature} ({temperature_reason}). " - "For better accuracy, add this model to your custom_models.json configuration." - ) - - generic = ModelCapabilities( - provider=ProviderType.CUSTOM, - model_name=canonical_name, - friendly_name=f"{self.FRIENDLY_NAME} ({canonical_name})", - context_window=32_768, - max_output_tokens=32_768, - supports_extended_thinking=False, - supports_system_prompts=True, - supports_streaming=True, - supports_function_calling=False, - supports_temperature=supports_temperature, - temperature_constraint=temperature_constraint, - ) - generic._is_generic = True - return generic + return None def get_provider_type(self) -> ProviderType: """Identify this provider for restriction and logging logic.""" @@ -146,49 +117,6 @@ class CustomProvider(OpenAICompatibleProvider): # Validation # ------------------------------------------------------------------ - def validate_model_name(self, model_name: str) -> bool: - """Validate if the model name is allowed. - - For custom endpoints, only accept models that are explicitly intended for - local/custom usage. This provider should NOT handle OpenRouter or cloud models. - - Args: - model_name: Model name to validate - - Returns: - True if model is intended for custom/local endpoint - """ - if super().validate_model_name(model_name): - return True - - config = self._registry.resolve(model_name) - if config and not getattr(config, "is_custom", False): - return False - - clean_model_name = model_name - if ":" in model_name: - clean_model_name = model_name.split(":", 1)[0] - logging.debug(f"Stripped version tag from '{model_name}' -> '{clean_model_name}'") - - if super().validate_model_name(clean_model_name): - return True - - config = self._registry.resolve(clean_model_name) - if config and not getattr(config, "is_custom", False): - return False - - lowered = clean_model_name.lower() - if any(indicator in lowered for indicator in ["local", "ollama", "vllm", "lmstudio"]): - logging.debug(f"Model '{clean_model_name}' validated via local indicators") - return True - - if "/" not in clean_model_name: - logging.debug(f"Model '{clean_model_name}' validated as potential local model (no vendor prefix)") - return True - - logging.debug(f"Model '{model_name}' rejected by custom provider (appears to be cloud model)") - return False - # ------------------------------------------------------------------ # Request execution # ------------------------------------------------------------------ diff --git a/providers/openrouter.py b/providers/openrouter.py index b4b9d6a..ddb7745 100644 --- a/providers/openrouter.py +++ b/providers/openrouter.py @@ -76,25 +76,31 @@ class OpenRouterProvider(OpenAICompatibleProvider): if capabilities: return capabilities - logging.debug( - f"Using generic capabilities for '{canonical_name}' via OpenRouter. " - "Consider adding to custom_models.json for specific capabilities." - ) + base_identifier = canonical_name.split(":", 1)[0] + if "/" in base_identifier: + logging.debug( + "Using generic OpenRouter capabilities for %s (provider/model format detected)", canonical_name + ) + generic = ModelCapabilities( + provider=ProviderType.OPENROUTER, + model_name=canonical_name, + friendly_name=self.FRIENDLY_NAME, + context_window=32_768, + max_output_tokens=32_768, + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, + temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0), + ) + generic._is_generic = True + return generic - generic = ModelCapabilities( - provider=ProviderType.OPENROUTER, - model_name=canonical_name, - friendly_name=self.FRIENDLY_NAME, - context_window=32_768, - max_output_tokens=32_768, - supports_extended_thinking=False, - supports_system_prompts=True, - supports_streaming=True, - supports_function_calling=False, - temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0), + logging.debug( + "Rejecting unknown OpenRouter model '%s' (no provider prefix); requires explicit configuration", + canonical_name, ) - generic._is_generic = True - return generic + return None # ------------------------------------------------------------------ # Provider identity diff --git a/tests/test_custom_provider.py b/tests/test_custom_provider.py index 4f7ca30..2733e2c 100644 --- a/tests/test_custom_provider.py +++ b/tests/test_custom_provider.py @@ -36,12 +36,16 @@ class TestCustomProvider: CustomProvider(api_key="test-key") def test_validate_model_names_always_true(self): - """Test CustomProvider accepts any model name.""" + """Test CustomProvider validates model names correctly.""" provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1") + # Known model should validate assert provider.validate_model_name("llama3.2") - assert provider.validate_model_name("unknown-model") - assert provider.validate_model_name("anything") + + # For custom provider, unknown models return False when not in registry + # This is expected behavior - custom models need to be declared in custom_models.json + assert not provider.validate_model_name("unknown-model") + assert not provider.validate_model_name("anything") def test_get_capabilities_from_registry(self): """Test get_capabilities returns registry capabilities when available.""" @@ -71,17 +75,12 @@ class TestCustomProvider: os.environ["OPENROUTER_ALLOWED_MODELS"] = original_env def test_get_capabilities_generic_fallback(self): - """Test get_capabilities returns generic capabilities for unknown models.""" + """Test get_capabilities raises error for unknown models not in registry.""" provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1") - capabilities = provider.get_capabilities("unknown-model-xyz") - - assert capabilities.provider == ProviderType.CUSTOM - assert capabilities.model_name == "unknown-model-xyz" - assert capabilities.context_window == 32_768 # Conservative default - assert not capabilities.supports_extended_thinking - assert capabilities.supports_system_prompts - assert capabilities.supports_streaming + # Unknown models should raise ValueError when not in registry + with pytest.raises(ValueError, match="Unsupported model 'unknown-model-xyz' for provider custom"): + provider.get_capabilities("unknown-model-xyz") def test_model_alias_resolution(self): """Test model alias resolution works correctly.""" @@ -100,8 +99,12 @@ class TestCustomProvider: """Custom provider generic capabilities default to no thinking mode.""" provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1") + # llama3.2 is a known model that should work assert not provider.get_capabilities("llama3.2").supports_extended_thinking - assert not provider.get_capabilities("any-model").supports_extended_thinking + + # Unknown models should raise error + with pytest.raises(ValueError, match="Unsupported model 'any-model' for provider custom"): + provider.get_capabilities("any-model") @patch("providers.custom.OpenAICompatibleProvider.generate_content") def test_generate_content_with_alias_resolution(self, mock_generate): diff --git a/tests/test_openrouter_provider.py b/tests/test_openrouter_provider.py index 0731646..1df2b2a 100644 --- a/tests/test_openrouter_provider.py +++ b/tests/test_openrouter_provider.py @@ -42,12 +42,15 @@ class TestOpenRouterProvider: """Test model validation.""" provider = OpenRouterProvider(api_key="test-key") - # Should accept any model - OpenRouter handles validation - assert provider.validate_model_name("gpt-4") is True - assert provider.validate_model_name("claude-4-opus") is True - assert provider.validate_model_name("any-model-name") is True - assert provider.validate_model_name("GPT-4") is True - assert provider.validate_model_name("unknown-model") is True + # OpenRouter accepts models with provider prefixes or known models + assert provider.validate_model_name("openai/gpt-4") is True + assert provider.validate_model_name("anthropic/claude-3-opus") is True + assert provider.validate_model_name("google/any-model-name") is True + assert provider.validate_model_name("groq/llama-3.1-8b") is True + + # Unknown models without provider prefix are rejected + assert provider.validate_model_name("gpt-4") is False + assert provider.validate_model_name("unknown-model") is False def test_get_capabilities(self): """Test capability generation.""" @@ -59,10 +62,14 @@ class TestOpenRouterProvider: assert caps.model_name == "openai/o3" # Resolved name assert caps.friendly_name == "OpenRouter (openai/o3)" - # Test with a model not in registry - should get generic capabilities - caps = provider.get_capabilities("unknown-model") + # Test with a model not in registry - should raise error + with pytest.raises(ValueError, match="Unsupported model 'unknown-model' for provider openrouter"): + provider.get_capabilities("unknown-model") + + # Test with model that has provider prefix - should get generic capabilities + caps = provider.get_capabilities("provider/unknown-model") assert caps.provider == ProviderType.OPENROUTER - assert caps.model_name == "unknown-model" + assert caps.model_name == "provider/unknown-model" assert caps.context_window == 32_768 # Safe default assert hasattr(caps, "_is_generic") and caps._is_generic is True diff --git a/tools/shared/base_tool.py b/tools/shared/base_tool.py index eb3995b..ac1f5fd 100644 --- a/tools/shared/base_tool.py +++ b/tools/shared/base_tool.py @@ -288,6 +288,42 @@ class BaseTool(ABC): return unique_models + def _format_available_models_list(self) -> str: + """Return a human-friendly list of available models or guidance when none found.""" + + available_models = self._get_available_models() + if not available_models: + return "No models detected. Configure provider credentials or set DEFAULT_MODEL to a valid option." + return ", ".join(available_models) + + def _build_model_unavailable_message(self, model_name: str) -> str: + """Compose a consistent error message for unavailable model scenarios.""" + + tool_category = self.get_model_category() + suggested_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category) + available_models_text = self._format_available_models_list() + + return ( + f"Model '{model_name}' is not available with current API keys. " + f"Available models: {available_models_text}. " + f"Suggested model for {self.get_name()}: '{suggested_model}' " + f"(category: {tool_category.value}). Select the strongest reasoning model that fits the task." + ) + + def _build_auto_mode_required_message(self) -> str: + """Compose the auto-mode prompt when an explicit model selection is required.""" + + tool_category = self.get_model_category() + suggested_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category) + available_models_text = self._format_available_models_list() + + return ( + "Model parameter is required in auto mode. " + f"Available models: {available_models_text}. " + f"Suggested model for {self.get_name()}: '{suggested_model}' " + f"(category: {tool_category.value}). Select the strongest reasoning model that fits the task." + ) + def get_model_field_schema(self) -> dict[str, Any]: """ Generate the model field schema based on auto mode configuration. @@ -467,8 +503,7 @@ class BaseTool(ABC): provider = ModelProviderRegistry.get_provider_for_model(model_name) if not provider: logger.error(f"No provider found for model '{model_name}' in {self.name} tool") - available_models = ModelProviderRegistry.get_available_models() - raise ValueError(f"Model '{model_name}' is not available. Available models: {available_models}") + raise ValueError(self._build_model_unavailable_message(model_name)) return provider except Exception as e: @@ -1127,29 +1162,11 @@ When recommending searches, be specific about what information you need and why # For tests: Check if we should require model selection (auto mode) if self._should_require_model_selection(model_name): - # Get suggested model based on tool category - from providers.registry import ModelProviderRegistry - - tool_category = self.get_model_category() - suggested_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category) - # Build error message based on why selection is required if model_name.lower() == "auto": - error_message = ( - f"Model parameter is required in auto mode. " - f"Suggested model for {self.get_name()}: '{suggested_model}' " - f"(category: {tool_category.value})" - ) + error_message = self._build_auto_mode_required_message() else: - # Model was specified but not available - available_models = self._get_available_models() - - error_message = ( - f"Model '{model_name}' is not available with current API keys. " - f"Available models: {', '.join(available_models)}. " - f"Suggested model for {self.get_name()}: '{suggested_model}' " - f"(category: {tool_category.value})" - ) + error_message = self._build_model_unavailable_message(model_name) raise ValueError(error_message) # Create model context for tests @@ -1237,7 +1254,7 @@ When recommending searches, be specific about what information you need and why # Generic error response for any unavailable model return { "status": "error", - "content": f"Model '{model_context}' is not available. {str(e)}", + "content": self._build_model_unavailable_message(str(model_context)), "content_type": "text", "metadata": { "error_type": "validation_error", @@ -1264,7 +1281,7 @@ When recommending searches, be specific about what information you need and why model_name = getattr(model_context, "model_name", "unknown") return { "status": "error", - "content": f"Model '{model_name}' is not available. {str(e)}", + "content": self._build_model_unavailable_message(model_name), "content_type": "text", "metadata": { "error_type": "validation_error", diff --git a/utils/model_context.py b/utils/model_context.py index e0f5bd5..c4015cc 100644 --- a/utils/model_context.py +++ b/utils/model_context.py @@ -73,8 +73,17 @@ class ModelContext: if self._provider is None: self._provider = ModelProviderRegistry.get_provider_for_model(self.model_name) if not self._provider: - available_models = ModelProviderRegistry.get_available_models() - raise ValueError(f"Model '{self.model_name}' is not available. Available models: {available_models}") + available_models = ModelProviderRegistry.get_available_model_names() + if available_models: + available_text = ", ".join(available_models) + else: + available_text = ( + "No models detected. Configure provider credentials or set DEFAULT_MODEL to a valid option." + ) + + raise ValueError( + f"Model '{self.model_name}' is not available with current API keys. Available models: {available_text}." + ) return self._provider @property