fix: custom provider must only accept a model if it's declared explicitly. Upon model rejection (in auto mode) the list of available models is returned up-front to help with selection.
This commit is contained in:
@@ -6,7 +6,7 @@ from typing import Optional
|
||||
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
from .openrouter_registry import OpenRouterModelRegistry
|
||||
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
|
||||
from .shared import ModelCapabilities, ModelResponse, ProviderType
|
||||
|
||||
|
||||
class CustomProvider(OpenAICompatibleProvider):
|
||||
@@ -91,51 +91,22 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
canonical_name: str,
|
||||
requested_name: Optional[str] = None,
|
||||
) -> Optional[ModelCapabilities]:
|
||||
"""Return custom capabilities from the registry or generic defaults."""
|
||||
"""Return capabilities for models explicitly marked as custom."""
|
||||
|
||||
builtin = super()._lookup_capabilities(canonical_name, requested_name)
|
||||
if builtin is not None:
|
||||
return builtin
|
||||
|
||||
capabilities = self._registry.get_capabilities(canonical_name)
|
||||
if capabilities:
|
||||
config = self._registry.resolve(canonical_name)
|
||||
if config and getattr(config, "is_custom", False):
|
||||
capabilities.provider = ProviderType.CUSTOM
|
||||
return capabilities
|
||||
# Non-custom models should fall through so OpenRouter handles them
|
||||
return None
|
||||
registry_entry = self._registry.resolve(canonical_name)
|
||||
if registry_entry and getattr(registry_entry, "is_custom", False):
|
||||
registry_entry.provider = ProviderType.CUSTOM
|
||||
return registry_entry
|
||||
|
||||
logging.debug(
|
||||
f"Using generic capabilities for '{canonical_name}' via Custom API. "
|
||||
"Consider adding to custom_models.json for specific capabilities."
|
||||
"Custom provider cannot resolve model '%s'; ensure it is declared with 'is_custom': true in custom_models.json",
|
||||
canonical_name,
|
||||
)
|
||||
|
||||
supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings(
|
||||
canonical_name
|
||||
)
|
||||
|
||||
logging.warning(
|
||||
f"Model '{canonical_name}' not found in custom_models.json. Using generic capabilities with inferred settings. "
|
||||
f"Temperature support: {supports_temperature} ({temperature_reason}). "
|
||||
"For better accuracy, add this model to your custom_models.json configuration."
|
||||
)
|
||||
|
||||
generic = ModelCapabilities(
|
||||
provider=ProviderType.CUSTOM,
|
||||
model_name=canonical_name,
|
||||
friendly_name=f"{self.FRIENDLY_NAME} ({canonical_name})",
|
||||
context_window=32_768,
|
||||
max_output_tokens=32_768,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False,
|
||||
supports_temperature=supports_temperature,
|
||||
temperature_constraint=temperature_constraint,
|
||||
)
|
||||
generic._is_generic = True
|
||||
return generic
|
||||
return None
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Identify this provider for restriction and logging logic."""
|
||||
@@ -146,49 +117,6 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
# Validation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
"""Validate if the model name is allowed.
|
||||
|
||||
For custom endpoints, only accept models that are explicitly intended for
|
||||
local/custom usage. This provider should NOT handle OpenRouter or cloud models.
|
||||
|
||||
Args:
|
||||
model_name: Model name to validate
|
||||
|
||||
Returns:
|
||||
True if model is intended for custom/local endpoint
|
||||
"""
|
||||
if super().validate_model_name(model_name):
|
||||
return True
|
||||
|
||||
config = self._registry.resolve(model_name)
|
||||
if config and not getattr(config, "is_custom", False):
|
||||
return False
|
||||
|
||||
clean_model_name = model_name
|
||||
if ":" in model_name:
|
||||
clean_model_name = model_name.split(":", 1)[0]
|
||||
logging.debug(f"Stripped version tag from '{model_name}' -> '{clean_model_name}'")
|
||||
|
||||
if super().validate_model_name(clean_model_name):
|
||||
return True
|
||||
|
||||
config = self._registry.resolve(clean_model_name)
|
||||
if config and not getattr(config, "is_custom", False):
|
||||
return False
|
||||
|
||||
lowered = clean_model_name.lower()
|
||||
if any(indicator in lowered for indicator in ["local", "ollama", "vllm", "lmstudio"]):
|
||||
logging.debug(f"Model '{clean_model_name}' validated via local indicators")
|
||||
return True
|
||||
|
||||
if "/" not in clean_model_name:
|
||||
logging.debug(f"Model '{clean_model_name}' validated as potential local model (no vendor prefix)")
|
||||
return True
|
||||
|
||||
logging.debug(f"Model '{model_name}' rejected by custom provider (appears to be cloud model)")
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Request execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -76,25 +76,31 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
if capabilities:
|
||||
return capabilities
|
||||
|
||||
logging.debug(
|
||||
f"Using generic capabilities for '{canonical_name}' via OpenRouter. "
|
||||
"Consider adding to custom_models.json for specific capabilities."
|
||||
)
|
||||
base_identifier = canonical_name.split(":", 1)[0]
|
||||
if "/" in base_identifier:
|
||||
logging.debug(
|
||||
"Using generic OpenRouter capabilities for %s (provider/model format detected)", canonical_name
|
||||
)
|
||||
generic = ModelCapabilities(
|
||||
provider=ProviderType.OPENROUTER,
|
||||
model_name=canonical_name,
|
||||
friendly_name=self.FRIENDLY_NAME,
|
||||
context_window=32_768,
|
||||
max_output_tokens=32_768,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False,
|
||||
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
|
||||
)
|
||||
generic._is_generic = True
|
||||
return generic
|
||||
|
||||
generic = ModelCapabilities(
|
||||
provider=ProviderType.OPENROUTER,
|
||||
model_name=canonical_name,
|
||||
friendly_name=self.FRIENDLY_NAME,
|
||||
context_window=32_768,
|
||||
max_output_tokens=32_768,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False,
|
||||
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
|
||||
logging.debug(
|
||||
"Rejecting unknown OpenRouter model '%s' (no provider prefix); requires explicit configuration",
|
||||
canonical_name,
|
||||
)
|
||||
generic._is_generic = True
|
||||
return generic
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Provider identity
|
||||
|
||||
@@ -36,12 +36,16 @@ class TestCustomProvider:
|
||||
CustomProvider(api_key="test-key")
|
||||
|
||||
def test_validate_model_names_always_true(self):
|
||||
"""Test CustomProvider accepts any model name."""
|
||||
"""Test CustomProvider validates model names correctly."""
|
||||
provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")
|
||||
|
||||
# Known model should validate
|
||||
assert provider.validate_model_name("llama3.2")
|
||||
assert provider.validate_model_name("unknown-model")
|
||||
assert provider.validate_model_name("anything")
|
||||
|
||||
# For custom provider, unknown models return False when not in registry
|
||||
# This is expected behavior - custom models need to be declared in custom_models.json
|
||||
assert not provider.validate_model_name("unknown-model")
|
||||
assert not provider.validate_model_name("anything")
|
||||
|
||||
def test_get_capabilities_from_registry(self):
|
||||
"""Test get_capabilities returns registry capabilities when available."""
|
||||
@@ -71,17 +75,12 @@ class TestCustomProvider:
|
||||
os.environ["OPENROUTER_ALLOWED_MODELS"] = original_env
|
||||
|
||||
def test_get_capabilities_generic_fallback(self):
|
||||
"""Test get_capabilities returns generic capabilities for unknown models."""
|
||||
"""Test get_capabilities raises error for unknown models not in registry."""
|
||||
provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")
|
||||
|
||||
capabilities = provider.get_capabilities("unknown-model-xyz")
|
||||
|
||||
assert capabilities.provider == ProviderType.CUSTOM
|
||||
assert capabilities.model_name == "unknown-model-xyz"
|
||||
assert capabilities.context_window == 32_768 # Conservative default
|
||||
assert not capabilities.supports_extended_thinking
|
||||
assert capabilities.supports_system_prompts
|
||||
assert capabilities.supports_streaming
|
||||
# Unknown models should raise ValueError when not in registry
|
||||
with pytest.raises(ValueError, match="Unsupported model 'unknown-model-xyz' for provider custom"):
|
||||
provider.get_capabilities("unknown-model-xyz")
|
||||
|
||||
def test_model_alias_resolution(self):
|
||||
"""Test model alias resolution works correctly."""
|
||||
@@ -100,8 +99,12 @@ class TestCustomProvider:
|
||||
"""Custom provider generic capabilities default to no thinking mode."""
|
||||
provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")
|
||||
|
||||
# llama3.2 is a known model that should work
|
||||
assert not provider.get_capabilities("llama3.2").supports_extended_thinking
|
||||
assert not provider.get_capabilities("any-model").supports_extended_thinking
|
||||
|
||||
# Unknown models should raise error
|
||||
with pytest.raises(ValueError, match="Unsupported model 'any-model' for provider custom"):
|
||||
provider.get_capabilities("any-model")
|
||||
|
||||
@patch("providers.custom.OpenAICompatibleProvider.generate_content")
|
||||
def test_generate_content_with_alias_resolution(self, mock_generate):
|
||||
|
||||
@@ -42,12 +42,15 @@ class TestOpenRouterProvider:
|
||||
"""Test model validation."""
|
||||
provider = OpenRouterProvider(api_key="test-key")
|
||||
|
||||
# Should accept any model - OpenRouter handles validation
|
||||
assert provider.validate_model_name("gpt-4") is True
|
||||
assert provider.validate_model_name("claude-4-opus") is True
|
||||
assert provider.validate_model_name("any-model-name") is True
|
||||
assert provider.validate_model_name("GPT-4") is True
|
||||
assert provider.validate_model_name("unknown-model") is True
|
||||
# OpenRouter accepts models with provider prefixes or known models
|
||||
assert provider.validate_model_name("openai/gpt-4") is True
|
||||
assert provider.validate_model_name("anthropic/claude-3-opus") is True
|
||||
assert provider.validate_model_name("google/any-model-name") is True
|
||||
assert provider.validate_model_name("groq/llama-3.1-8b") is True
|
||||
|
||||
# Unknown models without provider prefix are rejected
|
||||
assert provider.validate_model_name("gpt-4") is False
|
||||
assert provider.validate_model_name("unknown-model") is False
|
||||
|
||||
def test_get_capabilities(self):
|
||||
"""Test capability generation."""
|
||||
@@ -59,10 +62,14 @@ class TestOpenRouterProvider:
|
||||
assert caps.model_name == "openai/o3" # Resolved name
|
||||
assert caps.friendly_name == "OpenRouter (openai/o3)"
|
||||
|
||||
# Test with a model not in registry - should get generic capabilities
|
||||
caps = provider.get_capabilities("unknown-model")
|
||||
# Test with a model not in registry - should raise error
|
||||
with pytest.raises(ValueError, match="Unsupported model 'unknown-model' for provider openrouter"):
|
||||
provider.get_capabilities("unknown-model")
|
||||
|
||||
# Test with model that has provider prefix - should get generic capabilities
|
||||
caps = provider.get_capabilities("provider/unknown-model")
|
||||
assert caps.provider == ProviderType.OPENROUTER
|
||||
assert caps.model_name == "unknown-model"
|
||||
assert caps.model_name == "provider/unknown-model"
|
||||
assert caps.context_window == 32_768 # Safe default
|
||||
assert hasattr(caps, "_is_generic") and caps._is_generic is True
|
||||
|
||||
|
||||
@@ -288,6 +288,42 @@ class BaseTool(ABC):
|
||||
|
||||
return unique_models
|
||||
|
||||
def _format_available_models_list(self) -> str:
|
||||
"""Return a human-friendly list of available models or guidance when none found."""
|
||||
|
||||
available_models = self._get_available_models()
|
||||
if not available_models:
|
||||
return "No models detected. Configure provider credentials or set DEFAULT_MODEL to a valid option."
|
||||
return ", ".join(available_models)
|
||||
|
||||
def _build_model_unavailable_message(self, model_name: str) -> str:
|
||||
"""Compose a consistent error message for unavailable model scenarios."""
|
||||
|
||||
tool_category = self.get_model_category()
|
||||
suggested_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category)
|
||||
available_models_text = self._format_available_models_list()
|
||||
|
||||
return (
|
||||
f"Model '{model_name}' is not available with current API keys. "
|
||||
f"Available models: {available_models_text}. "
|
||||
f"Suggested model for {self.get_name()}: '{suggested_model}' "
|
||||
f"(category: {tool_category.value}). Select the strongest reasoning model that fits the task."
|
||||
)
|
||||
|
||||
def _build_auto_mode_required_message(self) -> str:
|
||||
"""Compose the auto-mode prompt when an explicit model selection is required."""
|
||||
|
||||
tool_category = self.get_model_category()
|
||||
suggested_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category)
|
||||
available_models_text = self._format_available_models_list()
|
||||
|
||||
return (
|
||||
"Model parameter is required in auto mode. "
|
||||
f"Available models: {available_models_text}. "
|
||||
f"Suggested model for {self.get_name()}: '{suggested_model}' "
|
||||
f"(category: {tool_category.value}). Select the strongest reasoning model that fits the task."
|
||||
)
|
||||
|
||||
def get_model_field_schema(self) -> dict[str, Any]:
|
||||
"""
|
||||
Generate the model field schema based on auto mode configuration.
|
||||
@@ -467,8 +503,7 @@ class BaseTool(ABC):
|
||||
provider = ModelProviderRegistry.get_provider_for_model(model_name)
|
||||
if not provider:
|
||||
logger.error(f"No provider found for model '{model_name}' in {self.name} tool")
|
||||
available_models = ModelProviderRegistry.get_available_models()
|
||||
raise ValueError(f"Model '{model_name}' is not available. Available models: {available_models}")
|
||||
raise ValueError(self._build_model_unavailable_message(model_name))
|
||||
|
||||
return provider
|
||||
except Exception as e:
|
||||
@@ -1127,29 +1162,11 @@ When recommending searches, be specific about what information you need and why
|
||||
|
||||
# For tests: Check if we should require model selection (auto mode)
|
||||
if self._should_require_model_selection(model_name):
|
||||
# Get suggested model based on tool category
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
tool_category = self.get_model_category()
|
||||
suggested_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category)
|
||||
|
||||
# Build error message based on why selection is required
|
||||
if model_name.lower() == "auto":
|
||||
error_message = (
|
||||
f"Model parameter is required in auto mode. "
|
||||
f"Suggested model for {self.get_name()}: '{suggested_model}' "
|
||||
f"(category: {tool_category.value})"
|
||||
)
|
||||
error_message = self._build_auto_mode_required_message()
|
||||
else:
|
||||
# Model was specified but not available
|
||||
available_models = self._get_available_models()
|
||||
|
||||
error_message = (
|
||||
f"Model '{model_name}' is not available with current API keys. "
|
||||
f"Available models: {', '.join(available_models)}. "
|
||||
f"Suggested model for {self.get_name()}: '{suggested_model}' "
|
||||
f"(category: {tool_category.value})"
|
||||
)
|
||||
error_message = self._build_model_unavailable_message(model_name)
|
||||
raise ValueError(error_message)
|
||||
|
||||
# Create model context for tests
|
||||
@@ -1237,7 +1254,7 @@ When recommending searches, be specific about what information you need and why
|
||||
# Generic error response for any unavailable model
|
||||
return {
|
||||
"status": "error",
|
||||
"content": f"Model '{model_context}' is not available. {str(e)}",
|
||||
"content": self._build_model_unavailable_message(str(model_context)),
|
||||
"content_type": "text",
|
||||
"metadata": {
|
||||
"error_type": "validation_error",
|
||||
@@ -1264,7 +1281,7 @@ When recommending searches, be specific about what information you need and why
|
||||
model_name = getattr(model_context, "model_name", "unknown")
|
||||
return {
|
||||
"status": "error",
|
||||
"content": f"Model '{model_name}' is not available. {str(e)}",
|
||||
"content": self._build_model_unavailable_message(model_name),
|
||||
"content_type": "text",
|
||||
"metadata": {
|
||||
"error_type": "validation_error",
|
||||
|
||||
@@ -73,8 +73,17 @@ class ModelContext:
|
||||
if self._provider is None:
|
||||
self._provider = ModelProviderRegistry.get_provider_for_model(self.model_name)
|
||||
if not self._provider:
|
||||
available_models = ModelProviderRegistry.get_available_models()
|
||||
raise ValueError(f"Model '{self.model_name}' is not available. Available models: {available_models}")
|
||||
available_models = ModelProviderRegistry.get_available_model_names()
|
||||
if available_models:
|
||||
available_text = ", ".join(available_models)
|
||||
else:
|
||||
available_text = (
|
||||
"No models detected. Configure provider credentials or set DEFAULT_MODEL to a valid option."
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Model '{self.model_name}' is not available with current API keys. Available models: {available_text}."
|
||||
)
|
||||
return self._provider
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user