fix: custom provider must only accept a model if it's declared explicitly. Upon model rejection (in auto mode), the list of available models is returned up-front to help with selection.

This commit is contained in:
Fahad
2025-10-02 13:49:23 +04:00
parent 82a03ce63f
commit d285fadf4c
6 changed files with 116 additions and 146 deletions

View File

@@ -6,7 +6,7 @@ from typing import Optional
 from .openai_compatible import OpenAICompatibleProvider
 from .openrouter_registry import OpenRouterModelRegistry
-from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
+from .shared import ModelCapabilities, ModelResponse, ProviderType


 class CustomProvider(OpenAICompatibleProvider):
@@ -91,51 +91,22 @@ class CustomProvider(OpenAICompatibleProvider):
         canonical_name: str,
         requested_name: Optional[str] = None,
     ) -> Optional[ModelCapabilities]:
-        """Return custom capabilities from the registry or generic defaults."""
+        """Return capabilities for models explicitly marked as custom."""
         builtin = super()._lookup_capabilities(canonical_name, requested_name)
         if builtin is not None:
             return builtin

-        capabilities = self._registry.get_capabilities(canonical_name)
-        if capabilities:
-            config = self._registry.resolve(canonical_name)
-            if config and getattr(config, "is_custom", False):
-                capabilities.provider = ProviderType.CUSTOM
-                return capabilities
-            # Non-custom models should fall through so OpenRouter handles them
-            return None
+        registry_entry = self._registry.resolve(canonical_name)
+        if registry_entry and getattr(registry_entry, "is_custom", False):
+            registry_entry.provider = ProviderType.CUSTOM
+            return registry_entry

-        logging.debug(
-            f"Using generic capabilities for '{canonical_name}' via Custom API. "
-            "Consider adding to custom_models.json for specific capabilities."
-        )
-        supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings(
-            canonical_name
-        )
-        logging.warning(
-            f"Model '{canonical_name}' not found in custom_models.json. Using generic capabilities with inferred settings. "
-            f"Temperature support: {supports_temperature} ({temperature_reason}). "
-            "For better accuracy, add this model to your custom_models.json configuration."
-        )
-        generic = ModelCapabilities(
-            provider=ProviderType.CUSTOM,
-            model_name=canonical_name,
-            friendly_name=f"{self.FRIENDLY_NAME} ({canonical_name})",
-            context_window=32_768,
-            max_output_tokens=32_768,
-            supports_extended_thinking=False,
-            supports_system_prompts=True,
-            supports_streaming=True,
-            supports_function_calling=False,
-            supports_temperature=supports_temperature,
-            temperature_constraint=temperature_constraint,
-        )
-        generic._is_generic = True
-        return generic
+        logging.debug(
+            "Custom provider cannot resolve model '%s'; ensure it is declared with 'is_custom': true in custom_models.json",
+            canonical_name,
+        )
+        return None

     def get_provider_type(self) -> ProviderType:
         """Identify this provider for restriction and logging logic."""
@@ -146,49 +117,6 @@ class CustomProvider(OpenAICompatibleProvider):
     # Validation
     # ------------------------------------------------------------------

-    def validate_model_name(self, model_name: str) -> bool:
-        """Validate if the model name is allowed.
-
-        For custom endpoints, only accept models that are explicitly intended for
-        local/custom usage. This provider should NOT handle OpenRouter or cloud models.
-
-        Args:
-            model_name: Model name to validate
-
-        Returns:
-            True if model is intended for custom/local endpoint
-        """
-        if super().validate_model_name(model_name):
-            return True
-
-        config = self._registry.resolve(model_name)
-        if config and not getattr(config, "is_custom", False):
-            return False
-
-        clean_model_name = model_name
-        if ":" in model_name:
-            clean_model_name = model_name.split(":", 1)[0]
-            logging.debug(f"Stripped version tag from '{model_name}' -> '{clean_model_name}'")
-
-        if super().validate_model_name(clean_model_name):
-            return True
-
-        config = self._registry.resolve(clean_model_name)
-        if config and not getattr(config, "is_custom", False):
-            return False
-
-        lowered = clean_model_name.lower()
-        if any(indicator in lowered for indicator in ["local", "ollama", "vllm", "lmstudio"]):
-            logging.debug(f"Model '{clean_model_name}' validated via local indicators")
-            return True
-
-        if "/" not in clean_model_name:
-            logging.debug(f"Model '{clean_model_name}' validated as potential local model (no vendor prefix)")
-            return True
-
-        logging.debug(f"Model '{model_name}' rejected by custom provider (appears to be cloud model)")
-        return False
-
     # ------------------------------------------------------------------
     # Request execution
     # ------------------------------------------------------------------
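With the heuristic `validate_model_name` override deleted, validation falls back to the base class, which succeeds only when `_lookup_capabilities` resolves — i.e. when the model is declared. The new debug message points operators at custom_models.json; a minimal sketch of what such a declaration might look like follows. Only the `is_custom` flag is confirmed by this diff; the remaining field names are assumptions mirroring the `ModelCapabilities` attributes the deleted fallback used to populate:

```python
import json

# Hypothetical custom_models.json entry shape (field names other than
# "is_custom" are assumed, not confirmed by this commit):
entry = {
    "model_name": "llama3.2",
    "is_custom": True,  # required: CustomProvider now rejects models without it
    "context_window": 32_768,
    "max_output_tokens": 32_768,
    "supports_function_calling": False,
}

print(json.dumps({"models": [entry]}, indent=2))
```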

View File

@@ -76,25 +76,31 @@ class OpenRouterProvider(OpenAICompatibleProvider):
         if capabilities:
             return capabilities

-        logging.debug(
-            f"Using generic capabilities for '{canonical_name}' via OpenRouter. "
-            "Consider adding to custom_models.json for specific capabilities."
-        )
-        generic = ModelCapabilities(
-            provider=ProviderType.OPENROUTER,
-            model_name=canonical_name,
-            friendly_name=self.FRIENDLY_NAME,
-            context_window=32_768,
-            max_output_tokens=32_768,
-            supports_extended_thinking=False,
-            supports_system_prompts=True,
-            supports_streaming=True,
-            supports_function_calling=False,
-            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
-        )
-        generic._is_generic = True
-        return generic
+        base_identifier = canonical_name.split(":", 1)[0]
+        if "/" in base_identifier:
+            logging.debug(
+                "Using generic OpenRouter capabilities for %s (provider/model format detected)", canonical_name
+            )
+            generic = ModelCapabilities(
+                provider=ProviderType.OPENROUTER,
+                model_name=canonical_name,
+                friendly_name=self.FRIENDLY_NAME,
+                context_window=32_768,
+                max_output_tokens=32_768,
+                supports_extended_thinking=False,
+                supports_system_prompts=True,
+                supports_streaming=True,
+                supports_function_calling=False,
+                temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
+            )
+            generic._is_generic = True
+            return generic
+
+        logging.debug(
+            "Rejecting unknown OpenRouter model '%s' (no provider prefix); requires explicit configuration",
+            canonical_name,
+        )
+        return None

     # ------------------------------------------------------------------
     # Provider identity
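Read together, the new branch means OpenRouter only mints generic capabilities for names that still contain a vendor prefix after any `:version` suffix is stripped; bare names now fall through to rejection. A standalone sketch of just that gate, lifted from the diff for illustration:

```python
def has_provider_prefix(canonical_name: str) -> bool:
    # Strip a ":version" / ":free" style suffix, then require "vendor/model".
    base_identifier = canonical_name.split(":", 1)[0]
    return "/" in base_identifier

assert has_provider_prefix("openai/gpt-4")
assert has_provider_prefix("provider/unknown-model:free")
assert not has_provider_prefix("unknown-model")  # rejected: no vendor prefix
```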

View File

@@ -36,12 +36,16 @@ class TestCustomProvider:
             CustomProvider(api_key="test-key")

     def test_validate_model_names_always_true(self):
-        """Test CustomProvider accepts any model name."""
+        """Test CustomProvider validates model names correctly."""
         provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

+        # Known model should validate
         assert provider.validate_model_name("llama3.2")
-        assert provider.validate_model_name("unknown-model")
-        assert provider.validate_model_name("anything")
+
+        # For custom provider, unknown models return False when not in registry
+        # This is expected behavior - custom models need to be declared in custom_models.json
+        assert not provider.validate_model_name("unknown-model")
+        assert not provider.validate_model_name("anything")

     def test_get_capabilities_from_registry(self):
         """Test get_capabilities returns registry capabilities when available."""

@@ -71,17 +75,12 @@ class TestCustomProvider:
             os.environ["OPENROUTER_ALLOWED_MODELS"] = original_env

     def test_get_capabilities_generic_fallback(self):
-        """Test get_capabilities returns generic capabilities for unknown models."""
+        """Test get_capabilities raises error for unknown models not in registry."""
         provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

-        capabilities = provider.get_capabilities("unknown-model-xyz")
-
-        assert capabilities.provider == ProviderType.CUSTOM
-        assert capabilities.model_name == "unknown-model-xyz"
-        assert capabilities.context_window == 32_768  # Conservative default
-        assert not capabilities.supports_extended_thinking
-        assert capabilities.supports_system_prompts
-        assert capabilities.supports_streaming
+        # Unknown models should raise ValueError when not in registry
+        with pytest.raises(ValueError, match="Unsupported model 'unknown-model-xyz' for provider custom"):
+            provider.get_capabilities("unknown-model-xyz")

     def test_model_alias_resolution(self):
         """Test model alias resolution works correctly."""

@@ -100,8 +99,12 @@ class TestCustomProvider:
         """Custom provider generic capabilities default to no thinking mode."""
         provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

+        # llama3.2 is a known model that should work
         assert not provider.get_capabilities("llama3.2").supports_extended_thinking
-        assert not provider.get_capabilities("any-model").supports_extended_thinking
+
+        # Unknown models should raise error
+        with pytest.raises(ValueError, match="Unsupported model 'any-model' for provider custom"):
+            provider.get_capabilities("any-model")

     @patch("providers.custom.OpenAICompatibleProvider.generate_content")
     def test_generate_content_with_alias_resolution(self, mock_generate):

View File

@@ -42,12 +42,15 @@ class TestOpenRouterProvider:
         """Test model validation."""
         provider = OpenRouterProvider(api_key="test-key")

-        # Should accept any model - OpenRouter handles validation
-        assert provider.validate_model_name("gpt-4") is True
-        assert provider.validate_model_name("claude-4-opus") is True
-        assert provider.validate_model_name("any-model-name") is True
-        assert provider.validate_model_name("GPT-4") is True
-        assert provider.validate_model_name("unknown-model") is True
+        # OpenRouter accepts models with provider prefixes or known models
+        assert provider.validate_model_name("openai/gpt-4") is True
+        assert provider.validate_model_name("anthropic/claude-3-opus") is True
+        assert provider.validate_model_name("google/any-model-name") is True
+        assert provider.validate_model_name("groq/llama-3.1-8b") is True
+
+        # Unknown models without provider prefix are rejected
+        assert provider.validate_model_name("gpt-4") is False
+        assert provider.validate_model_name("unknown-model") is False

     def test_get_capabilities(self):
         """Test capability generation."""

@@ -59,10 +62,14 @@ class TestOpenRouterProvider:
         assert caps.model_name == "openai/o3"  # Resolved name
         assert caps.friendly_name == "OpenRouter (openai/o3)"

-        # Test with a model not in registry - should get generic capabilities
-        caps = provider.get_capabilities("unknown-model")
+        # Test with a model not in registry - should raise error
+        with pytest.raises(ValueError, match="Unsupported model 'unknown-model' for provider openrouter"):
+            provider.get_capabilities("unknown-model")
+
+        # Test with model that has provider prefix - should get generic capabilities
+        caps = provider.get_capabilities("provider/unknown-model")
         assert caps.provider == ProviderType.OPENROUTER
-        assert caps.model_name == "unknown-model"
+        assert caps.model_name == "provider/unknown-model"
         assert caps.context_window == 32_768  # Safe default
         assert hasattr(caps, "_is_generic") and caps._is_generic is True

View File

@@ -288,6 +288,42 @@ class BaseTool(ABC):
         return unique_models

+    def _format_available_models_list(self) -> str:
+        """Return a human-friendly list of available models or guidance when none found."""
+        available_models = self._get_available_models()
+        if not available_models:
+            return "No models detected. Configure provider credentials or set DEFAULT_MODEL to a valid option."
+        return ", ".join(available_models)
+
+    def _build_model_unavailable_message(self, model_name: str) -> str:
+        """Compose a consistent error message for unavailable model scenarios."""
+        tool_category = self.get_model_category()
+        suggested_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category)
+        available_models_text = self._format_available_models_list()
+
+        return (
+            f"Model '{model_name}' is not available with current API keys. "
+            f"Available models: {available_models_text}. "
+            f"Suggested model for {self.get_name()}: '{suggested_model}' "
+            f"(category: {tool_category.value}). Select the strongest reasoning model that fits the task."
+        )
+
+    def _build_auto_mode_required_message(self) -> str:
+        """Compose the auto-mode prompt when an explicit model selection is required."""
+        tool_category = self.get_model_category()
+        suggested_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category)
+        available_models_text = self._format_available_models_list()
+
+        return (
+            "Model parameter is required in auto mode. "
+            f"Available models: {available_models_text}. "
+            f"Suggested model for {self.get_name()}: '{suggested_model}' "
+            f"(category: {tool_category.value}). Select the strongest reasoning model that fits the task."
+        )
+
     def get_model_field_schema(self) -> dict[str, Any]:
         """
         Generate the model field schema based on auto mode configuration.
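Both helpers share the same trailing suggestion-and-guidance text, so the caller sees the available-model list up-front whichever branch fired. Assuming a hypothetical tool named `chat`, a fallback of `openai/o3`, and two detected models, the composed message would look roughly like this:

```python
# Illustrative output only - the tool name, category value, and model list
# are hypothetical; actual values depend on the configured providers.
message = (
    "Model 'gpt-9' is not available with current API keys. "
    "Available models: llama3.2, openai/o3. "
    "Suggested model for chat: 'openai/o3' "
    "(category: fast_response). Select the strongest reasoning model that fits the task."
)
print(message)
```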
@@ -467,8 +503,7 @@ class BaseTool(ABC):
             provider = ModelProviderRegistry.get_provider_for_model(model_name)
             if not provider:
                 logger.error(f"No provider found for model '{model_name}' in {self.name} tool")
-                available_models = ModelProviderRegistry.get_available_models()
-                raise ValueError(f"Model '{model_name}' is not available. Available models: {available_models}")
+                raise ValueError(self._build_model_unavailable_message(model_name))

             return provider
         except Exception as e:
@@ -1127,29 +1162,11 @@ When recommending searches, be specific about what information you need and why
         # For tests: Check if we should require model selection (auto mode)
         if self._should_require_model_selection(model_name):
-            # Get suggested model based on tool category
-            from providers.registry import ModelProviderRegistry
-
-            tool_category = self.get_model_category()
-            suggested_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category)
-
-            # Build error message based on why selection is required
             if model_name.lower() == "auto":
-                error_message = (
-                    f"Model parameter is required in auto mode. "
-                    f"Suggested model for {self.get_name()}: '{suggested_model}' "
-                    f"(category: {tool_category.value})"
-                )
+                error_message = self._build_auto_mode_required_message()
             else:
                 # Model was specified but not available
-                available_models = self._get_available_models()
-                error_message = (
-                    f"Model '{model_name}' is not available with current API keys. "
-                    f"Available models: {', '.join(available_models)}. "
-                    f"Suggested model for {self.get_name()}: '{suggested_model}' "
-                    f"(category: {tool_category.value})"
-                )
+                error_message = self._build_model_unavailable_message(model_name)
             raise ValueError(error_message)

         # Create model context for tests
@@ -1237,7 +1254,7 @@ When recommending searches, be specific about what information you need and why
             # Generic error response for any unavailable model
             return {
                 "status": "error",
-                "content": f"Model '{model_context}' is not available. {str(e)}",
+                "content": self._build_model_unavailable_message(str(model_context)),
                 "content_type": "text",
                 "metadata": {
                     "error_type": "validation_error",
@@ -1264,7 +1281,7 @@ When recommending searches, be specific about what information you need and why
             model_name = getattr(model_context, "model_name", "unknown")
             return {
                 "status": "error",
-                "content": f"Model '{model_name}' is not available. {str(e)}",
+                "content": self._build_model_unavailable_message(model_name),
                 "content_type": "text",
                 "metadata": {
                     "error_type": "validation_error",

View File

@@ -73,8 +73,17 @@ class ModelContext:
         if self._provider is None:
             self._provider = ModelProviderRegistry.get_provider_for_model(self.model_name)
             if not self._provider:
-                available_models = ModelProviderRegistry.get_available_models()
-                raise ValueError(f"Model '{self.model_name}' is not available. Available models: {available_models}")
+                available_models = ModelProviderRegistry.get_available_model_names()
+                if available_models:
+                    available_text = ", ".join(available_models)
+                else:
+                    available_text = (
+                        "No models detected. Configure provider credentials or set DEFAULT_MODEL to a valid option."
+                    )
+                raise ValueError(
+                    f"Model '{self.model_name}' is not available with current API keys. Available models: {available_text}."
+                )
         return self._provider

     @property
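A hedged usage sketch of the revised failure mode; `ModelContext`, the registry call, and the error text come from the diff, while the import path, constructor signature, and model name are assumptions:

```python
# Assumed import path and constructor signature - both hypothetical.
from utils.model_context import ModelContext

try:
    # Assumes no configured provider can serve this name.
    ModelContext("nonexistent-model").provider
except ValueError as exc:
    # The message now lists available models (or setup guidance) up-front.
    print(exc)
```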