refactor: cleanup provider base class; cleanup shared responsibilities; cleanup public contract
docs: document provider base class refactor: cleanup custom provider, it should only deal with `is_custom` model configurations fix: make sure openrouter provider does not load `is_custom` models fix: listmodels tool cleanup
This commit is contained in:
@@ -83,117 +83,69 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
aliases = self._registry.list_aliases()
|
||||
logging.info(f"Custom provider loaded {len(models)} models with {len(aliases)} aliases")
|
||||
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve model aliases to actual model names.
|
||||
# ------------------------------------------------------------------
|
||||
# Capability surface
|
||||
# ------------------------------------------------------------------
|
||||
def _lookup_capabilities(
|
||||
self,
|
||||
canonical_name: str,
|
||||
requested_name: Optional[str] = None,
|
||||
) -> Optional[ModelCapabilities]:
|
||||
"""Return custom capabilities from the registry or generic defaults."""
|
||||
|
||||
For Ollama-style models, strips version tags (e.g., 'llama3.2:latest' -> 'llama3.2')
|
||||
since the base model name is what's typically used in API calls.
|
||||
|
||||
Args:
|
||||
model_name: Input model name or alias
|
||||
|
||||
Returns:
|
||||
Resolved model name with version tags stripped if applicable
|
||||
"""
|
||||
# First, try to resolve through registry as-is
|
||||
config = self._registry.resolve(model_name)
|
||||
|
||||
if config:
|
||||
if config.model_name != model_name:
|
||||
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
|
||||
return config.model_name
|
||||
else:
|
||||
# If not found in registry, handle version tags for local models
|
||||
# Strip version tags (anything after ':') for Ollama-style models
|
||||
if ":" in model_name:
|
||||
base_model = model_name.split(":")[0]
|
||||
logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'")
|
||||
|
||||
# Try to resolve the base model through registry
|
||||
base_config = self._registry.resolve(base_model)
|
||||
if base_config:
|
||||
logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
|
||||
return base_config.model_name
|
||||
else:
|
||||
return base_model
|
||||
else:
|
||||
# If not found in registry and no version tag, return as-is
|
||||
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
||||
return model_name
|
||||
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||
"""Get capabilities for a custom model.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model (or alias)
|
||||
|
||||
Returns:
|
||||
ModelCapabilities from registry or generic defaults
|
||||
"""
|
||||
# Try to get from registry first
|
||||
capabilities = self._registry.get_capabilities(model_name)
|
||||
builtin = super()._lookup_capabilities(canonical_name, requested_name)
|
||||
if builtin is not None:
|
||||
return builtin
|
||||
|
||||
capabilities = self._registry.get_capabilities(canonical_name)
|
||||
if capabilities:
|
||||
# Check if this is an OpenRouter model and apply restrictions
|
||||
config = self._registry.resolve(model_name)
|
||||
if config and not config.is_custom:
|
||||
# This is an OpenRouter model, check restrictions
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
if not restriction_service.is_allowed(ProviderType.OPENROUTER, config.model_name, model_name):
|
||||
raise ValueError(f"OpenRouter model '{model_name}' is not allowed by restriction policy.")
|
||||
|
||||
# Update provider type to OPENROUTER for OpenRouter models
|
||||
capabilities.provider = ProviderType.OPENROUTER
|
||||
else:
|
||||
# Update provider type to CUSTOM for local custom models
|
||||
config = self._registry.resolve(canonical_name)
|
||||
if config and getattr(config, "is_custom", False):
|
||||
capabilities.provider = ProviderType.CUSTOM
|
||||
return capabilities
|
||||
else:
|
||||
# Resolve any potential aliases and create generic capabilities
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
return capabilities
|
||||
# Non-custom models should fall through so OpenRouter handles them
|
||||
return None
|
||||
|
||||
logging.debug(
|
||||
f"Using generic capabilities for '{resolved_name}' via Custom API. "
|
||||
"Consider adding to custom_models.json for specific capabilities."
|
||||
)
|
||||
logging.debug(
|
||||
f"Using generic capabilities for '{canonical_name}' via Custom API. "
|
||||
"Consider adding to custom_models.json for specific capabilities."
|
||||
)
|
||||
|
||||
# Infer temperature behaviour for generic capability fallback
|
||||
supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings(
|
||||
resolved_name
|
||||
)
|
||||
supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings(
|
||||
canonical_name
|
||||
)
|
||||
|
||||
logging.warning(
|
||||
f"Model '{resolved_name}' not found in custom_models.json. Using generic capabilities with inferred settings. "
|
||||
f"Temperature support: {supports_temperature} ({temperature_reason}). "
|
||||
"For better accuracy, add this model to your custom_models.json configuration."
|
||||
)
|
||||
logging.warning(
|
||||
f"Model '{canonical_name}' not found in custom_models.json. Using generic capabilities with inferred settings. "
|
||||
f"Temperature support: {supports_temperature} ({temperature_reason}). "
|
||||
"For better accuracy, add this model to your custom_models.json configuration."
|
||||
)
|
||||
|
||||
# Create generic capabilities with inferred defaults
|
||||
capabilities = ModelCapabilities(
|
||||
provider=ProviderType.CUSTOM,
|
||||
model_name=resolved_name,
|
||||
friendly_name=f"{self.FRIENDLY_NAME} ({resolved_name})",
|
||||
context_window=32_768, # Conservative default
|
||||
max_output_tokens=32_768, # Conservative default max output
|
||||
supports_extended_thinking=False, # Most custom models don't support this
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # Conservative default
|
||||
supports_temperature=supports_temperature,
|
||||
temperature_constraint=temperature_constraint,
|
||||
)
|
||||
|
||||
# Mark as generic for validation purposes
|
||||
capabilities._is_generic = True
|
||||
|
||||
return capabilities
|
||||
generic = ModelCapabilities(
|
||||
provider=ProviderType.CUSTOM,
|
||||
model_name=canonical_name,
|
||||
friendly_name=f"{self.FRIENDLY_NAME} ({canonical_name})",
|
||||
context_window=32_768,
|
||||
max_output_tokens=32_768,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False,
|
||||
supports_temperature=supports_temperature,
|
||||
temperature_constraint=temperature_constraint,
|
||||
)
|
||||
generic._is_generic = True
|
||||
return generic
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Get the provider type."""
|
||||
"""Identify this provider for restriction and logging logic."""
|
||||
|
||||
return ProviderType.CUSTOM
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Validation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
"""Validate if the model name is allowed.
|
||||
|
||||
@@ -206,49 +158,41 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
Returns:
|
||||
True if model is intended for custom/local endpoint
|
||||
"""
|
||||
# logging.debug(f"Custom provider validating model: '{model_name}'")
|
||||
if super().validate_model_name(model_name):
|
||||
return True
|
||||
|
||||
# Try to resolve through registry first
|
||||
config = self._registry.resolve(model_name)
|
||||
if config:
|
||||
model_id = config.model_name
|
||||
# Use explicit is_custom flag for clean validation
|
||||
if config.is_custom:
|
||||
logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' validated via registry")
|
||||
return True
|
||||
else:
|
||||
# This is a cloud/OpenRouter model - CustomProvider should NOT handle these
|
||||
# Let OpenRouter provider handle them instead
|
||||
# logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' not custom (defer to OpenRouter)")
|
||||
return False
|
||||
if config and not getattr(config, "is_custom", False):
|
||||
return False
|
||||
|
||||
# Handle version tags for unknown models (e.g., "my-model:latest")
|
||||
clean_model_name = model_name
|
||||
if ":" in model_name:
|
||||
clean_model_name = model_name.split(":")[0]
|
||||
clean_model_name = model_name.split(":", 1)[0]
|
||||
logging.debug(f"Stripped version tag from '{model_name}' -> '{clean_model_name}'")
|
||||
# Try to resolve the clean name
|
||||
|
||||
if super().validate_model_name(clean_model_name):
|
||||
return True
|
||||
|
||||
config = self._registry.resolve(clean_model_name)
|
||||
if config:
|
||||
return self.validate_model_name(clean_model_name) # Recursively validate clean name
|
||||
if config and not getattr(config, "is_custom", False):
|
||||
return False
|
||||
|
||||
# For unknown models (not in registry), only accept if they look like local models
|
||||
# This maintains backward compatibility for custom models not yet in the registry
|
||||
|
||||
# Accept models with explicit local indicators in the name
|
||||
if any(indicator in clean_model_name.lower() for indicator in ["local", "ollama", "vllm", "lmstudio"]):
|
||||
lowered = clean_model_name.lower()
|
||||
if any(indicator in lowered for indicator in ["local", "ollama", "vllm", "lmstudio"]):
|
||||
logging.debug(f"Model '{clean_model_name}' validated via local indicators")
|
||||
return True
|
||||
|
||||
# Accept simple model names without vendor prefix (likely local/custom models)
|
||||
if "/" not in clean_model_name:
|
||||
logging.debug(f"Model '{clean_model_name}' validated as potential local model (no vendor prefix)")
|
||||
return True
|
||||
|
||||
# Reject everything else (likely cloud models not in registry)
|
||||
logging.debug(f"Model '{model_name}' rejected by custom provider (appears to be cloud model)")
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Request execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def generate_content(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -284,25 +228,41 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
|
||||
"""Get model configurations from the registry.
|
||||
# ------------------------------------------------------------------
|
||||
# Registry helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
For CustomProvider, we convert registry configurations to ModelCapabilities objects.
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve registry aliases and strip version tags for local models."""
|
||||
|
||||
Returns:
|
||||
Dictionary mapping model names to their ModelCapabilities objects
|
||||
"""
|
||||
config = self._registry.resolve(model_name)
|
||||
if config:
|
||||
if config.model_name != model_name:
|
||||
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
|
||||
return config.model_name
|
||||
|
||||
configs = {}
|
||||
if ":" in model_name:
|
||||
base_model = model_name.split(":")[0]
|
||||
logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'")
|
||||
|
||||
if self._registry:
|
||||
# Get all models from registry
|
||||
for model_name in self._registry.list_models():
|
||||
# Only include custom models that this provider validates
|
||||
if self.validate_model_name(model_name):
|
||||
config = self._registry.resolve(model_name)
|
||||
if config and config.is_custom:
|
||||
# Use ModelCapabilities directly from registry
|
||||
configs[model_name] = config
|
||||
base_config = self._registry.resolve(base_model)
|
||||
if base_config:
|
||||
logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
|
||||
return base_config.model_name
|
||||
return base_model
|
||||
|
||||
return configs
|
||||
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
||||
return model_name
|
||||
|
||||
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||
"""Expose registry capabilities for models marked as custom."""
|
||||
|
||||
if not self._registry:
|
||||
return {}
|
||||
|
||||
capabilities: dict[str, ModelCapabilities] = {}
|
||||
for model_name in self._registry.list_models():
|
||||
config = self._registry.resolve(model_name)
|
||||
if config and getattr(config, "is_custom", False):
|
||||
capabilities[model_name] = config
|
||||
return capabilities
|
||||
|
||||
Reference in New Issue
Block a user