fix: custom provider must only accept a model if it's declared explicitly. Upon model rejection (in auto mode) the list of available models is returned up-front to help with selection.

This commit is contained in:
Fahad
2025-10-02 13:49:23 +04:00
parent 82a03ce63f
commit d285fadf4c
6 changed files with 116 additions and 146 deletions

View File

@@ -6,7 +6,7 @@ from typing import Optional
from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
from .shared import ModelCapabilities, ModelResponse, ProviderType
class CustomProvider(OpenAICompatibleProvider):
@@ -91,51 +91,22 @@ class CustomProvider(OpenAICompatibleProvider):
canonical_name: str,
requested_name: Optional[str] = None,
) -> Optional[ModelCapabilities]:
"""Return custom capabilities from the registry or generic defaults."""
"""Return capabilities for models explicitly marked as custom."""
builtin = super()._lookup_capabilities(canonical_name, requested_name)
if builtin is not None:
return builtin
capabilities = self._registry.get_capabilities(canonical_name)
if capabilities:
config = self._registry.resolve(canonical_name)
if config and getattr(config, "is_custom", False):
capabilities.provider = ProviderType.CUSTOM
return capabilities
# Non-custom models should fall through so OpenRouter handles them
return None
registry_entry = self._registry.resolve(canonical_name)
if registry_entry and getattr(registry_entry, "is_custom", False):
registry_entry.provider = ProviderType.CUSTOM
return registry_entry
logging.debug(
f"Using generic capabilities for '{canonical_name}' via Custom API. "
"Consider adding to custom_models.json for specific capabilities."
"Custom provider cannot resolve model '%s'; ensure it is declared with 'is_custom': true in custom_models.json",
canonical_name,
)
supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings(
canonical_name
)
logging.warning(
f"Model '{canonical_name}' not found in custom_models.json. Using generic capabilities with inferred settings. "
f"Temperature support: {supports_temperature} ({temperature_reason}). "
"For better accuracy, add this model to your custom_models.json configuration."
)
generic = ModelCapabilities(
provider=ProviderType.CUSTOM,
model_name=canonical_name,
friendly_name=f"{self.FRIENDLY_NAME} ({canonical_name})",
context_window=32_768,
max_output_tokens=32_768,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False,
supports_temperature=supports_temperature,
temperature_constraint=temperature_constraint,
)
generic._is_generic = True
return generic
return None
def get_provider_type(self) -> ProviderType:
"""Identify this provider for restriction and logging logic."""
@@ -146,49 +117,6 @@ class CustomProvider(OpenAICompatibleProvider):
# Validation
# ------------------------------------------------------------------
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is allowed.
For custom endpoints, only accept models that are explicitly intended for
local/custom usage. This provider should NOT handle OpenRouter or cloud models.
Args:
model_name: Model name to validate
Returns:
True if model is intended for custom/local endpoint
"""
if super().validate_model_name(model_name):
return True
config = self._registry.resolve(model_name)
if config and not getattr(config, "is_custom", False):
return False
clean_model_name = model_name
if ":" in model_name:
clean_model_name = model_name.split(":", 1)[0]
logging.debug(f"Stripped version tag from '{model_name}' -> '{clean_model_name}'")
if super().validate_model_name(clean_model_name):
return True
config = self._registry.resolve(clean_model_name)
if config and not getattr(config, "is_custom", False):
return False
lowered = clean_model_name.lower()
if any(indicator in lowered for indicator in ["local", "ollama", "vllm", "lmstudio"]):
logging.debug(f"Model '{clean_model_name}' validated via local indicators")
return True
if "/" not in clean_model_name:
logging.debug(f"Model '{clean_model_name}' validated as potential local model (no vendor prefix)")
return True
logging.debug(f"Model '{model_name}' rejected by custom provider (appears to be cloud model)")
return False
# ------------------------------------------------------------------
# Request execution
# ------------------------------------------------------------------