fix: custom provider must only accept a model if it's declared explicitly. Upon model rejection (in auto mode), the list of available models is returned up front to help with selection.
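For reference, a model opts in to this provider through its entry in custom_models.json. Only the "is_custom" flag is confirmed by the code below; the other fields in this sketch are illustrative:

    {
      "models": [
        {
          "model_name": "llama3.2",
          "aliases": ["local-llama"],
          "context_window": 128000,
          "is_custom": true
        }
      ]
    }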
@@ -6,7 +6,7 @@ from typing import Optional
 from .openai_compatible import OpenAICompatibleProvider
 from .openrouter_registry import OpenRouterModelRegistry
-from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
+from .shared import ModelCapabilities, ModelResponse, ProviderType


 class CustomProvider(OpenAICompatibleProvider):
@@ -91,51 +91,22 @@ class CustomProvider(OpenAICompatibleProvider):
         canonical_name: str,
         requested_name: Optional[str] = None,
     ) -> Optional[ModelCapabilities]:
-        """Return custom capabilities from the registry or generic defaults."""
+        """Return capabilities for models explicitly marked as custom."""

         builtin = super()._lookup_capabilities(canonical_name, requested_name)
         if builtin is not None:
             return builtin

-        capabilities = self._registry.get_capabilities(canonical_name)
-        if capabilities:
-            config = self._registry.resolve(canonical_name)
-            if config and getattr(config, "is_custom", False):
-                capabilities.provider = ProviderType.CUSTOM
-                return capabilities
-            # Non-custom models should fall through so OpenRouter handles them
-            return None
+        registry_entry = self._registry.resolve(canonical_name)
+        if registry_entry and getattr(registry_entry, "is_custom", False):
+            registry_entry.provider = ProviderType.CUSTOM
+            return registry_entry

         logging.debug(
-            f"Using generic capabilities for '{canonical_name}' via Custom API. "
-            "Consider adding to custom_models.json for specific capabilities."
+            "Custom provider cannot resolve model '%s'; ensure it is declared with 'is_custom': true in custom_models.json",
+            canonical_name,
         )

-        supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings(
-            canonical_name
-        )
-
-        logging.warning(
-            f"Model '{canonical_name}' not found in custom_models.json. Using generic capabilities with inferred settings. "
-            f"Temperature support: {supports_temperature} ({temperature_reason}). "
-            "For better accuracy, add this model to your custom_models.json configuration."
-        )
-
-        generic = ModelCapabilities(
-            provider=ProviderType.CUSTOM,
-            model_name=canonical_name,
-            friendly_name=f"{self.FRIENDLY_NAME} ({canonical_name})",
-            context_window=32_768,
-            max_output_tokens=32_768,
-            supports_extended_thinking=False,
-            supports_system_prompts=True,
-            supports_streaming=True,
-            supports_function_calling=False,
-            supports_temperature=supports_temperature,
-            temperature_constraint=temperature_constraint,
-        )
-        generic._is_generic = True
-        return generic
+        return None

     def get_provider_type(self) -> ProviderType:
         """Identify this provider for restriction and logging logic."""
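The net effect of the rewritten lookup: a registry hit only counts when the entry explicitly opts in, and everything else now falls through so another provider can claim it instead of receiving generic defaults. A hypothetical pytest sketch of that contract (the import paths, constructor arguments, and the declared model name are assumptions, not shown in this diff):

    from providers.custom import CustomProvider  # import path assumed
    from providers.shared import ProviderType    # import path assumed


    def test_lookup_only_accepts_declared_custom():
        # Constructor arguments are assumed; adjust to the real signature.
        provider = CustomProvider(api_key="", base_url="http://localhost:11434/v1")

        # A model declared with "is_custom": true resolves as a CUSTOM model...
        caps = provider._lookup_capabilities("llama3.2")
        assert caps is not None
        assert caps.provider is ProviderType.CUSTOM

        # ...while an undeclared or cloud model returns None, leaving it to
        # other providers (e.g. OpenRouter) rather than generic capabilities.
        assert provider._lookup_capabilities("gpt-4") is None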
@@ -146,49 +117,6 @@ class CustomProvider(OpenAICompatibleProvider):
     # Validation
     # ------------------------------------------------------------------

-    def validate_model_name(self, model_name: str) -> bool:
-        """Validate if the model name is allowed.
-
-        For custom endpoints, only accept models that are explicitly intended for
-        local/custom usage. This provider should NOT handle OpenRouter or cloud models.
-
-        Args:
-            model_name: Model name to validate
-
-        Returns:
-            True if model is intended for custom/local endpoint
-        """
-        if super().validate_model_name(model_name):
-            return True
-
-        config = self._registry.resolve(model_name)
-        if config and not getattr(config, "is_custom", False):
-            return False
-
-        clean_model_name = model_name
-        if ":" in model_name:
-            clean_model_name = model_name.split(":", 1)[0]
-            logging.debug(f"Stripped version tag from '{model_name}' -> '{clean_model_name}'")
-
-        if super().validate_model_name(clean_model_name):
-            return True
-
-        config = self._registry.resolve(clean_model_name)
-        if config and not getattr(config, "is_custom", False):
-            return False
-
-        lowered = clean_model_name.lower()
-        if any(indicator in lowered for indicator in ["local", "ollama", "vllm", "lmstudio"]):
-            logging.debug(f"Model '{clean_model_name}' validated via local indicators")
-            return True
-
-        if "/" not in clean_model_name:
-            logging.debug(f"Model '{clean_model_name}' validated as potential local model (no vendor prefix)")
-            return True
-
-        logging.debug(f"Model '{model_name}' rejected by custom provider (appears to be cloud model)")
-        return False
-
     # ------------------------------------------------------------------
     # Request execution
     # ------------------------------------------------------------------
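Note that the heuristic validate_model_name above is deleted outright: the name sniffing for local indicators ("ollama", "vllm", "lmstudio") and the no-vendor-prefix guess are gone, matching the commit's intent that a model is only accepted when declared explicitly. The replacement is outside this hunk; presumably it reduces to the registry declaration, along these lines (a hypothetical sketch, not the committed code):

    def validate_model_name(self, model_name: str) -> bool:
        # Hypothetical: accept only models the registry marks "is_custom";
        # version tags such as "llama3.2:latest" would still need stripping.
        entry = self._registry.resolve(model_name.split(":", 1)[0])
        return bool(entry and getattr(entry, "is_custom", False))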