fix: custom provider must only accept a model if it's declared explicitly. Upon model rejection (in auto mode) the list of available models is returned up-front to help with selection.
This commit is contained in:
@@ -6,7 +6,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from .openai_compatible import OpenAICompatibleProvider
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
from .openrouter_registry import OpenRouterModelRegistry
|
from .openrouter_registry import OpenRouterModelRegistry
|
||||||
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
|
from .shared import ModelCapabilities, ModelResponse, ProviderType
|
||||||
|
|
||||||
|
|
||||||
class CustomProvider(OpenAICompatibleProvider):
|
class CustomProvider(OpenAICompatibleProvider):
|
||||||
@@ -91,51 +91,22 @@ class CustomProvider(OpenAICompatibleProvider):
|
|||||||
canonical_name: str,
|
canonical_name: str,
|
||||||
requested_name: Optional[str] = None,
|
requested_name: Optional[str] = None,
|
||||||
) -> Optional[ModelCapabilities]:
|
) -> Optional[ModelCapabilities]:
|
||||||
"""Return custom capabilities from the registry or generic defaults."""
|
"""Return capabilities for models explicitly marked as custom."""
|
||||||
|
|
||||||
builtin = super()._lookup_capabilities(canonical_name, requested_name)
|
builtin = super()._lookup_capabilities(canonical_name, requested_name)
|
||||||
if builtin is not None:
|
if builtin is not None:
|
||||||
return builtin
|
return builtin
|
||||||
|
|
||||||
capabilities = self._registry.get_capabilities(canonical_name)
|
registry_entry = self._registry.resolve(canonical_name)
|
||||||
if capabilities:
|
if registry_entry and getattr(registry_entry, "is_custom", False):
|
||||||
config = self._registry.resolve(canonical_name)
|
registry_entry.provider = ProviderType.CUSTOM
|
||||||
if config and getattr(config, "is_custom", False):
|
return registry_entry
|
||||||
capabilities.provider = ProviderType.CUSTOM
|
|
||||||
return capabilities
|
|
||||||
# Non-custom models should fall through so OpenRouter handles them
|
|
||||||
return None
|
|
||||||
|
|
||||||
logging.debug(
|
logging.debug(
|
||||||
f"Using generic capabilities for '{canonical_name}' via Custom API. "
|
"Custom provider cannot resolve model '%s'; ensure it is declared with 'is_custom': true in custom_models.json",
|
||||||
"Consider adding to custom_models.json for specific capabilities."
|
canonical_name,
|
||||||
)
|
)
|
||||||
|
return None
|
||||||
supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings(
|
|
||||||
canonical_name
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.warning(
|
|
||||||
f"Model '{canonical_name}' not found in custom_models.json. Using generic capabilities with inferred settings. "
|
|
||||||
f"Temperature support: {supports_temperature} ({temperature_reason}). "
|
|
||||||
"For better accuracy, add this model to your custom_models.json configuration."
|
|
||||||
)
|
|
||||||
|
|
||||||
generic = ModelCapabilities(
|
|
||||||
provider=ProviderType.CUSTOM,
|
|
||||||
model_name=canonical_name,
|
|
||||||
friendly_name=f"{self.FRIENDLY_NAME} ({canonical_name})",
|
|
||||||
context_window=32_768,
|
|
||||||
max_output_tokens=32_768,
|
|
||||||
supports_extended_thinking=False,
|
|
||||||
supports_system_prompts=True,
|
|
||||||
supports_streaming=True,
|
|
||||||
supports_function_calling=False,
|
|
||||||
supports_temperature=supports_temperature,
|
|
||||||
temperature_constraint=temperature_constraint,
|
|
||||||
)
|
|
||||||
generic._is_generic = True
|
|
||||||
return generic
|
|
||||||
|
|
||||||
def get_provider_type(self) -> ProviderType:
|
def get_provider_type(self) -> ProviderType:
|
||||||
"""Identify this provider for restriction and logging logic."""
|
"""Identify this provider for restriction and logging logic."""
|
||||||
@@ -146,49 +117,6 @@ class CustomProvider(OpenAICompatibleProvider):
|
|||||||
# Validation
|
# Validation
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
|
||||||
"""Validate if the model name is allowed.
|
|
||||||
|
|
||||||
For custom endpoints, only accept models that are explicitly intended for
|
|
||||||
local/custom usage. This provider should NOT handle OpenRouter or cloud models.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: Model name to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if model is intended for custom/local endpoint
|
|
||||||
"""
|
|
||||||
if super().validate_model_name(model_name):
|
|
||||||
return True
|
|
||||||
|
|
||||||
config = self._registry.resolve(model_name)
|
|
||||||
if config and not getattr(config, "is_custom", False):
|
|
||||||
return False
|
|
||||||
|
|
||||||
clean_model_name = model_name
|
|
||||||
if ":" in model_name:
|
|
||||||
clean_model_name = model_name.split(":", 1)[0]
|
|
||||||
logging.debug(f"Stripped version tag from '{model_name}' -> '{clean_model_name}'")
|
|
||||||
|
|
||||||
if super().validate_model_name(clean_model_name):
|
|
||||||
return True
|
|
||||||
|
|
||||||
config = self._registry.resolve(clean_model_name)
|
|
||||||
if config and not getattr(config, "is_custom", False):
|
|
||||||
return False
|
|
||||||
|
|
||||||
lowered = clean_model_name.lower()
|
|
||||||
if any(indicator in lowered for indicator in ["local", "ollama", "vllm", "lmstudio"]):
|
|
||||||
logging.debug(f"Model '{clean_model_name}' validated via local indicators")
|
|
||||||
return True
|
|
||||||
|
|
||||||
if "/" not in clean_model_name:
|
|
||||||
logging.debug(f"Model '{clean_model_name}' validated as potential local model (no vendor prefix)")
|
|
||||||
return True
|
|
||||||
|
|
||||||
logging.debug(f"Model '{model_name}' rejected by custom provider (appears to be cloud model)")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Request execution
|
# Request execution
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|||||||
@@ -76,25 +76,31 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
|||||||
if capabilities:
|
if capabilities:
|
||||||
return capabilities
|
return capabilities
|
||||||
|
|
||||||
logging.debug(
|
base_identifier = canonical_name.split(":", 1)[0]
|
||||||
f"Using generic capabilities for '{canonical_name}' via OpenRouter. "
|
if "/" in base_identifier:
|
||||||
"Consider adding to custom_models.json for specific capabilities."
|
logging.debug(
|
||||||
)
|
"Using generic OpenRouter capabilities for %s (provider/model format detected)", canonical_name
|
||||||
|
)
|
||||||
|
generic = ModelCapabilities(
|
||||||
|
provider=ProviderType.OPENROUTER,
|
||||||
|
model_name=canonical_name,
|
||||||
|
friendly_name=self.FRIENDLY_NAME,
|
||||||
|
context_window=32_768,
|
||||||
|
max_output_tokens=32_768,
|
||||||
|
supports_extended_thinking=False,
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=False,
|
||||||
|
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
|
||||||
|
)
|
||||||
|
generic._is_generic = True
|
||||||
|
return generic
|
||||||
|
|
||||||
generic = ModelCapabilities(
|
logging.debug(
|
||||||
provider=ProviderType.OPENROUTER,
|
"Rejecting unknown OpenRouter model '%s' (no provider prefix); requires explicit configuration",
|
||||||
model_name=canonical_name,
|
canonical_name,
|
||||||
friendly_name=self.FRIENDLY_NAME,
|
|
||||||
context_window=32_768,
|
|
||||||
max_output_tokens=32_768,
|
|
||||||
supports_extended_thinking=False,
|
|
||||||
supports_system_prompts=True,
|
|
||||||
supports_streaming=True,
|
|
||||||
supports_function_calling=False,
|
|
||||||
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
|
|
||||||
)
|
)
|
||||||
generic._is_generic = True
|
return None
|
||||||
return generic
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Provider identity
|
# Provider identity
|
||||||
|
|||||||
@@ -36,12 +36,16 @@ class TestCustomProvider:
|
|||||||
CustomProvider(api_key="test-key")
|
CustomProvider(api_key="test-key")
|
||||||
|
|
||||||
def test_validate_model_names_always_true(self):
|
def test_validate_model_names_always_true(self):
|
||||||
"""Test CustomProvider accepts any model name."""
|
"""Test CustomProvider validates model names correctly."""
|
||||||
provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")
|
provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")
|
||||||
|
|
||||||
|
# Known model should validate
|
||||||
assert provider.validate_model_name("llama3.2")
|
assert provider.validate_model_name("llama3.2")
|
||||||
assert provider.validate_model_name("unknown-model")
|
|
||||||
assert provider.validate_model_name("anything")
|
# For custom provider, unknown models return False when not in registry
|
||||||
|
# This is expected behavior - custom models need to be declared in custom_models.json
|
||||||
|
assert not provider.validate_model_name("unknown-model")
|
||||||
|
assert not provider.validate_model_name("anything")
|
||||||
|
|
||||||
def test_get_capabilities_from_registry(self):
|
def test_get_capabilities_from_registry(self):
|
||||||
"""Test get_capabilities returns registry capabilities when available."""
|
"""Test get_capabilities returns registry capabilities when available."""
|
||||||
@@ -71,17 +75,12 @@ class TestCustomProvider:
|
|||||||
os.environ["OPENROUTER_ALLOWED_MODELS"] = original_env
|
os.environ["OPENROUTER_ALLOWED_MODELS"] = original_env
|
||||||
|
|
||||||
def test_get_capabilities_generic_fallback(self):
|
def test_get_capabilities_generic_fallback(self):
|
||||||
"""Test get_capabilities returns generic capabilities for unknown models."""
|
"""Test get_capabilities raises error for unknown models not in registry."""
|
||||||
provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")
|
provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")
|
||||||
|
|
||||||
capabilities = provider.get_capabilities("unknown-model-xyz")
|
# Unknown models should raise ValueError when not in registry
|
||||||
|
with pytest.raises(ValueError, match="Unsupported model 'unknown-model-xyz' for provider custom"):
|
||||||
assert capabilities.provider == ProviderType.CUSTOM
|
provider.get_capabilities("unknown-model-xyz")
|
||||||
assert capabilities.model_name == "unknown-model-xyz"
|
|
||||||
assert capabilities.context_window == 32_768 # Conservative default
|
|
||||||
assert not capabilities.supports_extended_thinking
|
|
||||||
assert capabilities.supports_system_prompts
|
|
||||||
assert capabilities.supports_streaming
|
|
||||||
|
|
||||||
def test_model_alias_resolution(self):
|
def test_model_alias_resolution(self):
|
||||||
"""Test model alias resolution works correctly."""
|
"""Test model alias resolution works correctly."""
|
||||||
@@ -100,8 +99,12 @@ class TestCustomProvider:
|
|||||||
"""Custom provider generic capabilities default to no thinking mode."""
|
"""Custom provider generic capabilities default to no thinking mode."""
|
||||||
provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")
|
provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")
|
||||||
|
|
||||||
|
# llama3.2 is a known model that should work
|
||||||
assert not provider.get_capabilities("llama3.2").supports_extended_thinking
|
assert not provider.get_capabilities("llama3.2").supports_extended_thinking
|
||||||
assert not provider.get_capabilities("any-model").supports_extended_thinking
|
|
||||||
|
# Unknown models should raise error
|
||||||
|
with pytest.raises(ValueError, match="Unsupported model 'any-model' for provider custom"):
|
||||||
|
provider.get_capabilities("any-model")
|
||||||
|
|
||||||
@patch("providers.custom.OpenAICompatibleProvider.generate_content")
|
@patch("providers.custom.OpenAICompatibleProvider.generate_content")
|
||||||
def test_generate_content_with_alias_resolution(self, mock_generate):
|
def test_generate_content_with_alias_resolution(self, mock_generate):
|
||||||
|
|||||||
@@ -42,12 +42,15 @@ class TestOpenRouterProvider:
|
|||||||
"""Test model validation."""
|
"""Test model validation."""
|
||||||
provider = OpenRouterProvider(api_key="test-key")
|
provider = OpenRouterProvider(api_key="test-key")
|
||||||
|
|
||||||
# Should accept any model - OpenRouter handles validation
|
# OpenRouter accepts models with provider prefixes or known models
|
||||||
assert provider.validate_model_name("gpt-4") is True
|
assert provider.validate_model_name("openai/gpt-4") is True
|
||||||
assert provider.validate_model_name("claude-4-opus") is True
|
assert provider.validate_model_name("anthropic/claude-3-opus") is True
|
||||||
assert provider.validate_model_name("any-model-name") is True
|
assert provider.validate_model_name("google/any-model-name") is True
|
||||||
assert provider.validate_model_name("GPT-4") is True
|
assert provider.validate_model_name("groq/llama-3.1-8b") is True
|
||||||
assert provider.validate_model_name("unknown-model") is True
|
|
||||||
|
# Unknown models without provider prefix are rejected
|
||||||
|
assert provider.validate_model_name("gpt-4") is False
|
||||||
|
assert provider.validate_model_name("unknown-model") is False
|
||||||
|
|
||||||
def test_get_capabilities(self):
|
def test_get_capabilities(self):
|
||||||
"""Test capability generation."""
|
"""Test capability generation."""
|
||||||
@@ -59,10 +62,14 @@ class TestOpenRouterProvider:
|
|||||||
assert caps.model_name == "openai/o3" # Resolved name
|
assert caps.model_name == "openai/o3" # Resolved name
|
||||||
assert caps.friendly_name == "OpenRouter (openai/o3)"
|
assert caps.friendly_name == "OpenRouter (openai/o3)"
|
||||||
|
|
||||||
# Test with a model not in registry - should get generic capabilities
|
# Test with a model not in registry - should raise error
|
||||||
caps = provider.get_capabilities("unknown-model")
|
with pytest.raises(ValueError, match="Unsupported model 'unknown-model' for provider openrouter"):
|
||||||
|
provider.get_capabilities("unknown-model")
|
||||||
|
|
||||||
|
# Test with model that has provider prefix - should get generic capabilities
|
||||||
|
caps = provider.get_capabilities("provider/unknown-model")
|
||||||
assert caps.provider == ProviderType.OPENROUTER
|
assert caps.provider == ProviderType.OPENROUTER
|
||||||
assert caps.model_name == "unknown-model"
|
assert caps.model_name == "provider/unknown-model"
|
||||||
assert caps.context_window == 32_768 # Safe default
|
assert caps.context_window == 32_768 # Safe default
|
||||||
assert hasattr(caps, "_is_generic") and caps._is_generic is True
|
assert hasattr(caps, "_is_generic") and caps._is_generic is True
|
||||||
|
|
||||||
|
|||||||
@@ -288,6 +288,42 @@ class BaseTool(ABC):
|
|||||||
|
|
||||||
return unique_models
|
return unique_models
|
||||||
|
|
||||||
|
def _format_available_models_list(self) -> str:
|
||||||
|
"""Return a human-friendly list of available models or guidance when none found."""
|
||||||
|
|
||||||
|
available_models = self._get_available_models()
|
||||||
|
if not available_models:
|
||||||
|
return "No models detected. Configure provider credentials or set DEFAULT_MODEL to a valid option."
|
||||||
|
return ", ".join(available_models)
|
||||||
|
|
||||||
|
def _build_model_unavailable_message(self, model_name: str) -> str:
|
||||||
|
"""Compose a consistent error message for unavailable model scenarios."""
|
||||||
|
|
||||||
|
tool_category = self.get_model_category()
|
||||||
|
suggested_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category)
|
||||||
|
available_models_text = self._format_available_models_list()
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"Model '{model_name}' is not available with current API keys. "
|
||||||
|
f"Available models: {available_models_text}. "
|
||||||
|
f"Suggested model for {self.get_name()}: '{suggested_model}' "
|
||||||
|
f"(category: {tool_category.value}). Select the strongest reasoning model that fits the task."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_auto_mode_required_message(self) -> str:
|
||||||
|
"""Compose the auto-mode prompt when an explicit model selection is required."""
|
||||||
|
|
||||||
|
tool_category = self.get_model_category()
|
||||||
|
suggested_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category)
|
||||||
|
available_models_text = self._format_available_models_list()
|
||||||
|
|
||||||
|
return (
|
||||||
|
"Model parameter is required in auto mode. "
|
||||||
|
f"Available models: {available_models_text}. "
|
||||||
|
f"Suggested model for {self.get_name()}: '{suggested_model}' "
|
||||||
|
f"(category: {tool_category.value}). Select the strongest reasoning model that fits the task."
|
||||||
|
)
|
||||||
|
|
||||||
def get_model_field_schema(self) -> dict[str, Any]:
|
def get_model_field_schema(self) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Generate the model field schema based on auto mode configuration.
|
Generate the model field schema based on auto mode configuration.
|
||||||
@@ -467,8 +503,7 @@ class BaseTool(ABC):
|
|||||||
provider = ModelProviderRegistry.get_provider_for_model(model_name)
|
provider = ModelProviderRegistry.get_provider_for_model(model_name)
|
||||||
if not provider:
|
if not provider:
|
||||||
logger.error(f"No provider found for model '{model_name}' in {self.name} tool")
|
logger.error(f"No provider found for model '{model_name}' in {self.name} tool")
|
||||||
available_models = ModelProviderRegistry.get_available_models()
|
raise ValueError(self._build_model_unavailable_message(model_name))
|
||||||
raise ValueError(f"Model '{model_name}' is not available. Available models: {available_models}")
|
|
||||||
|
|
||||||
return provider
|
return provider
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1127,29 +1162,11 @@ When recommending searches, be specific about what information you need and why
|
|||||||
|
|
||||||
# For tests: Check if we should require model selection (auto mode)
|
# For tests: Check if we should require model selection (auto mode)
|
||||||
if self._should_require_model_selection(model_name):
|
if self._should_require_model_selection(model_name):
|
||||||
# Get suggested model based on tool category
|
|
||||||
from providers.registry import ModelProviderRegistry
|
|
||||||
|
|
||||||
tool_category = self.get_model_category()
|
|
||||||
suggested_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category)
|
|
||||||
|
|
||||||
# Build error message based on why selection is required
|
# Build error message based on why selection is required
|
||||||
if model_name.lower() == "auto":
|
if model_name.lower() == "auto":
|
||||||
error_message = (
|
error_message = self._build_auto_mode_required_message()
|
||||||
f"Model parameter is required in auto mode. "
|
|
||||||
f"Suggested model for {self.get_name()}: '{suggested_model}' "
|
|
||||||
f"(category: {tool_category.value})"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Model was specified but not available
|
error_message = self._build_model_unavailable_message(model_name)
|
||||||
available_models = self._get_available_models()
|
|
||||||
|
|
||||||
error_message = (
|
|
||||||
f"Model '{model_name}' is not available with current API keys. "
|
|
||||||
f"Available models: {', '.join(available_models)}. "
|
|
||||||
f"Suggested model for {self.get_name()}: '{suggested_model}' "
|
|
||||||
f"(category: {tool_category.value})"
|
|
||||||
)
|
|
||||||
raise ValueError(error_message)
|
raise ValueError(error_message)
|
||||||
|
|
||||||
# Create model context for tests
|
# Create model context for tests
|
||||||
@@ -1237,7 +1254,7 @@ When recommending searches, be specific about what information you need and why
|
|||||||
# Generic error response for any unavailable model
|
# Generic error response for any unavailable model
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"content": f"Model '{model_context}' is not available. {str(e)}",
|
"content": self._build_model_unavailable_message(str(model_context)),
|
||||||
"content_type": "text",
|
"content_type": "text",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"error_type": "validation_error",
|
"error_type": "validation_error",
|
||||||
@@ -1264,7 +1281,7 @@ When recommending searches, be specific about what information you need and why
|
|||||||
model_name = getattr(model_context, "model_name", "unknown")
|
model_name = getattr(model_context, "model_name", "unknown")
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"content": f"Model '{model_name}' is not available. {str(e)}",
|
"content": self._build_model_unavailable_message(model_name),
|
||||||
"content_type": "text",
|
"content_type": "text",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"error_type": "validation_error",
|
"error_type": "validation_error",
|
||||||
|
|||||||
@@ -73,8 +73,17 @@ class ModelContext:
|
|||||||
if self._provider is None:
|
if self._provider is None:
|
||||||
self._provider = ModelProviderRegistry.get_provider_for_model(self.model_name)
|
self._provider = ModelProviderRegistry.get_provider_for_model(self.model_name)
|
||||||
if not self._provider:
|
if not self._provider:
|
||||||
available_models = ModelProviderRegistry.get_available_models()
|
available_models = ModelProviderRegistry.get_available_model_names()
|
||||||
raise ValueError(f"Model '{self.model_name}' is not available. Available models: {available_models}")
|
if available_models:
|
||||||
|
available_text = ", ".join(available_models)
|
||||||
|
else:
|
||||||
|
available_text = (
|
||||||
|
"No models detected. Configure provider credentials or set DEFAULT_MODEL to a valid option."
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Model '{self.model_name}' is not available with current API keys. Available models: {available_text}."
|
||||||
|
)
|
||||||
return self._provider
|
return self._provider
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
Reference in New Issue
Block a user