refactor: clean up provider base class, shared responsibilities, and public contract

docs: document provider base class
refactor: clean up custom provider; it should only deal with `is_custom` model configurations
fix: make sure openrouter provider does not load `is_custom` models
fix: listmodels tool cleanup
This commit is contained in:
Fahad
2025-10-02 12:59:45 +04:00
parent 6ec2033f34
commit 693b84db2b
15 changed files with 509 additions and 751 deletions

View File

@@ -83,117 +83,69 @@ class CustomProvider(OpenAICompatibleProvider):
aliases = self._registry.list_aliases()
logging.info(f"Custom provider loaded {len(models)} models with {len(aliases)} aliases")
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model aliases to actual model names.
# ------------------------------------------------------------------
# Capability surface
# ------------------------------------------------------------------
def _lookup_capabilities(
self,
canonical_name: str,
requested_name: Optional[str] = None,
) -> Optional[ModelCapabilities]:
"""Return custom capabilities from the registry or generic defaults."""
For Ollama-style models, strips version tags (e.g., 'llama3.2:latest' -> 'llama3.2')
since the base model name is what's typically used in API calls.
Args:
model_name: Input model name or alias
Returns:
Resolved model name with version tags stripped if applicable
"""
# First, try to resolve through registry as-is
config = self._registry.resolve(model_name)
if config:
if config.model_name != model_name:
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
return config.model_name
else:
# If not found in registry, handle version tags for local models
# Strip version tags (anything after ':') for Ollama-style models
if ":" in model_name:
base_model = model_name.split(":")[0]
logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'")
# Try to resolve the base model through registry
base_config = self._registry.resolve(base_model)
if base_config:
logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
return base_config.model_name
else:
return base_model
else:
# If not found in registry and no version tag, return as-is
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
return model_name
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a custom model.
Args:
model_name: Name of the model (or alias)
Returns:
ModelCapabilities from registry or generic defaults
"""
# Try to get from registry first
capabilities = self._registry.get_capabilities(model_name)
builtin = super()._lookup_capabilities(canonical_name, requested_name)
if builtin is not None:
return builtin
capabilities = self._registry.get_capabilities(canonical_name)
if capabilities:
# Check if this is an OpenRouter model and apply restrictions
config = self._registry.resolve(model_name)
if config and not config.is_custom:
# This is an OpenRouter model, check restrictions
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENROUTER, config.model_name, model_name):
raise ValueError(f"OpenRouter model '{model_name}' is not allowed by restriction policy.")
# Update provider type to OPENROUTER for OpenRouter models
capabilities.provider = ProviderType.OPENROUTER
else:
# Update provider type to CUSTOM for local custom models
config = self._registry.resolve(canonical_name)
if config and getattr(config, "is_custom", False):
capabilities.provider = ProviderType.CUSTOM
return capabilities
else:
# Resolve any potential aliases and create generic capabilities
resolved_name = self._resolve_model_name(model_name)
return capabilities
# Non-custom models should fall through so OpenRouter handles them
return None
logging.debug(
f"Using generic capabilities for '{resolved_name}' via Custom API. "
"Consider adding to custom_models.json for specific capabilities."
)
logging.debug(
f"Using generic capabilities for '{canonical_name}' via Custom API. "
"Consider adding to custom_models.json for specific capabilities."
)
# Infer temperature behaviour for generic capability fallback
supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings(
resolved_name
)
supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings(
canonical_name
)
logging.warning(
f"Model '{resolved_name}' not found in custom_models.json. Using generic capabilities with inferred settings. "
f"Temperature support: {supports_temperature} ({temperature_reason}). "
"For better accuracy, add this model to your custom_models.json configuration."
)
logging.warning(
f"Model '{canonical_name}' not found in custom_models.json. Using generic capabilities with inferred settings. "
f"Temperature support: {supports_temperature} ({temperature_reason}). "
"For better accuracy, add this model to your custom_models.json configuration."
)
# Create generic capabilities with inferred defaults
capabilities = ModelCapabilities(
provider=ProviderType.CUSTOM,
model_name=resolved_name,
friendly_name=f"{self.FRIENDLY_NAME} ({resolved_name})",
context_window=32_768, # Conservative default
max_output_tokens=32_768, # Conservative default max output
supports_extended_thinking=False, # Most custom models don't support this
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # Conservative default
supports_temperature=supports_temperature,
temperature_constraint=temperature_constraint,
)
# Mark as generic for validation purposes
capabilities._is_generic = True
return capabilities
generic = ModelCapabilities(
provider=ProviderType.CUSTOM,
model_name=canonical_name,
friendly_name=f"{self.FRIENDLY_NAME} ({canonical_name})",
context_window=32_768,
max_output_tokens=32_768,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False,
supports_temperature=supports_temperature,
temperature_constraint=temperature_constraint,
)
generic._is_generic = True
return generic
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
"""Identify this provider for restriction and logging logic."""
return ProviderType.CUSTOM
# ------------------------------------------------------------------
# Validation
# ------------------------------------------------------------------
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is allowed.
@@ -206,49 +158,41 @@ class CustomProvider(OpenAICompatibleProvider):
Returns:
True if model is intended for custom/local endpoint
"""
# logging.debug(f"Custom provider validating model: '{model_name}'")
if super().validate_model_name(model_name):
return True
# Try to resolve through registry first
config = self._registry.resolve(model_name)
if config:
model_id = config.model_name
# Use explicit is_custom flag for clean validation
if config.is_custom:
logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' validated via registry")
return True
else:
# This is a cloud/OpenRouter model - CustomProvider should NOT handle these
# Let OpenRouter provider handle them instead
# logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' not custom (defer to OpenRouter)")
return False
if config and not getattr(config, "is_custom", False):
return False
# Handle version tags for unknown models (e.g., "my-model:latest")
clean_model_name = model_name
if ":" in model_name:
clean_model_name = model_name.split(":")[0]
clean_model_name = model_name.split(":", 1)[0]
logging.debug(f"Stripped version tag from '{model_name}' -> '{clean_model_name}'")
# Try to resolve the clean name
if super().validate_model_name(clean_model_name):
return True
config = self._registry.resolve(clean_model_name)
if config:
return self.validate_model_name(clean_model_name) # Recursively validate clean name
if config and not getattr(config, "is_custom", False):
return False
# For unknown models (not in registry), only accept if they look like local models
# This maintains backward compatibility for custom models not yet in the registry
# Accept models with explicit local indicators in the name
if any(indicator in clean_model_name.lower() for indicator in ["local", "ollama", "vllm", "lmstudio"]):
lowered = clean_model_name.lower()
if any(indicator in lowered for indicator in ["local", "ollama", "vllm", "lmstudio"]):
logging.debug(f"Model '{clean_model_name}' validated via local indicators")
return True
# Accept simple model names without vendor prefix (likely local/custom models)
if "/" not in clean_model_name:
logging.debug(f"Model '{clean_model_name}' validated as potential local model (no vendor prefix)")
return True
# Reject everything else (likely cloud models not in registry)
logging.debug(f"Model '{model_name}' rejected by custom provider (appears to be cloud model)")
return False
# ------------------------------------------------------------------
# Request execution
# ------------------------------------------------------------------
def generate_content(
self,
prompt: str,
@@ -284,25 +228,41 @@ class CustomProvider(OpenAICompatibleProvider):
**kwargs,
)
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
"""Get model configurations from the registry.
# ------------------------------------------------------------------
# Registry helpers
# ------------------------------------------------------------------
For CustomProvider, we convert registry configurations to ModelCapabilities objects.
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve registry aliases and strip version tags for local models."""
Returns:
Dictionary mapping model names to their ModelCapabilities objects
"""
config = self._registry.resolve(model_name)
if config:
if config.model_name != model_name:
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
return config.model_name
configs = {}
if ":" in model_name:
base_model = model_name.split(":")[0]
logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'")
if self._registry:
# Get all models from registry
for model_name in self._registry.list_models():
# Only include custom models that this provider validates
if self.validate_model_name(model_name):
config = self._registry.resolve(model_name)
if config and config.is_custom:
# Use ModelCapabilities directly from registry
configs[model_name] = config
base_config = self._registry.resolve(base_model)
if base_config:
logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
return base_config.model_name
return base_model
return configs
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
return model_name
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
    """Return the registry entries that are flagged as custom models.

    Every name reported by ``self._registry.list_models()`` is resolved
    through the registry; an entry is kept only when it resolves to a
    truthy object whose ``is_custom`` attribute is truthy (a missing
    attribute counts as not custom).

    Returns:
        Mapping of listed model name to its resolved registry entry, or an
        empty dict when no registry is attached to this provider.
    """
    registry = self._registry
    if not registry:
        return {}

    # Pair each listed name with its resolved entry, then filter in one pass.
    resolved = ((name, registry.resolve(name)) for name in registry.list_models())
    return {
        name: entry
        for name, entry in resolved
        if entry and getattr(entry, "is_custom", False)
    }