refactor: cleanup provider base class; cleanup shared responsibilities; cleanup public contract
docs: document provider base class refactor: cleanup custom provider, it should only deal with `is_custom` model configurations fix: make sure openrouter provider does not load `is_custom` models fix: listmodels tool cleanup
This commit is contained in:
@@ -43,128 +43,37 @@ class ModelProvider(ABC):
|
||||
self.api_key = api_key
|
||||
self.config = kwargs
|
||||
|
||||
@abstractmethod
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||
"""Get capabilities for a specific model."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def generate_content(
|
||||
self,
|
||||
prompt: str,
|
||||
model_name: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
temperature: float = 0.3,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> ModelResponse:
|
||||
"""Generate content using the model.
|
||||
|
||||
Args:
|
||||
prompt: User prompt to send to the model
|
||||
model_name: Name of the model to use
|
||||
system_prompt: Optional system prompt for model behavior
|
||||
temperature: Sampling temperature (0-2)
|
||||
max_output_tokens: Maximum tokens to generate
|
||||
**kwargs: Provider-specific parameters
|
||||
|
||||
Returns:
|
||||
ModelResponse with generated content and metadata
|
||||
"""
|
||||
pass
|
||||
|
||||
def count_tokens(self, text: str, model_name: str) -> int:
|
||||
"""Estimate token usage for a piece of text.
|
||||
|
||||
Providers can rely on this shared implementation or override it when
|
||||
they expose a more accurate tokenizer. This default uses a simple
|
||||
character-based heuristic so it works even without provider-specific
|
||||
tooling.
|
||||
"""
|
||||
|
||||
resolved_model = self._resolve_model_name(model_name)
|
||||
|
||||
if not text:
|
||||
return 0
|
||||
|
||||
# Rough estimation: ~4 characters per token for English text
|
||||
estimated = max(1, len(text) // 4)
|
||||
logger.debug("Estimating %s tokens for model %s via character heuristic", estimated, resolved_model)
|
||||
return estimated
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Provider identity & capability surface
|
||||
# ------------------------------------------------------------------
|
||||
@abstractmethod
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Get the provider type."""
|
||||
pass
|
||||
"""Return the concrete provider identity."""
|
||||
|
||||
@abstractmethod
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
"""Validate if the model name is supported by this provider."""
|
||||
pass
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||
"""Resolve capability metadata for a model name.
|
||||
|
||||
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
|
||||
"""Validate model parameters against capabilities.
|
||||
|
||||
Raises:
|
||||
ValueError: If parameters are invalid
|
||||
This centralises the alias resolution → lookup → restriction check
|
||||
pipeline so providers only override the pieces they genuinely need to
|
||||
customise. Subclasses usually only override ``_lookup_capabilities`` to
|
||||
integrate a registry or dynamic source, or ``_finalise_capabilities`` to
|
||||
tweak the returned object.
|
||||
"""
|
||||
capabilities = self.get_capabilities(model_name)
|
||||
|
||||
# Validate temperature using constraint
|
||||
if not capabilities.temperature_constraint.validate(temperature):
|
||||
constraint_desc = capabilities.temperature_constraint.get_description()
|
||||
raise ValueError(f"Temperature {temperature} is invalid for model {model_name}. {constraint_desc}")
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
capabilities = self._lookup_capabilities(resolved_name, model_name)
|
||||
|
||||
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
|
||||
"""Get model configurations for this provider.
|
||||
if capabilities is None:
|
||||
self._raise_unsupported_model(model_name)
|
||||
|
||||
This is a hook method that subclasses can override to provide
|
||||
their model configurations from different sources.
|
||||
self._ensure_model_allowed(capabilities, resolved_name, model_name)
|
||||
return self._finalise_capabilities(capabilities, resolved_name, model_name)
|
||||
|
||||
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||
"""Return the provider's statically declared model capabilities."""
|
||||
|
||||
Returns:
|
||||
Dictionary mapping model names to their ModelCapabilities objects
|
||||
"""
|
||||
model_map = getattr(self, "MODEL_CAPABILITIES", None)
|
||||
if isinstance(model_map, dict) and model_map:
|
||||
return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)}
|
||||
return {}
|
||||
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve model shorthand to full name.
|
||||
|
||||
This implementation uses the hook methods to support different
|
||||
model configuration sources.
|
||||
|
||||
Args:
|
||||
model_name: Model name that may be an alias
|
||||
|
||||
Returns:
|
||||
Resolved model name
|
||||
"""
|
||||
# Get model configurations from the hook method
|
||||
model_configs = self.get_model_configurations()
|
||||
|
||||
# First check if it's already a base model name (case-sensitive exact match)
|
||||
if model_name in model_configs:
|
||||
return model_name
|
||||
|
||||
# Check case-insensitively for both base models and aliases
|
||||
model_name_lower = model_name.lower()
|
||||
|
||||
# Check base model names case-insensitively
|
||||
for base_model in model_configs:
|
||||
if base_model.lower() == model_name_lower:
|
||||
return base_model
|
||||
|
||||
# Check aliases from the model configurations
|
||||
alias_map = ModelCapabilities.collect_aliases(model_configs)
|
||||
for base_model, aliases in alias_map.items():
|
||||
if any(alias.lower() == model_name_lower for alias in aliases):
|
||||
return base_model
|
||||
|
||||
# If not found, return as-is
|
||||
return model_name
|
||||
|
||||
def list_models(
|
||||
self,
|
||||
*,
|
||||
@@ -175,7 +84,7 @@ class ModelProvider(ABC):
|
||||
) -> list[str]:
|
||||
"""Return formatted model names supported by this provider."""
|
||||
|
||||
model_configs = self.get_model_configurations()
|
||||
model_configs = self.get_all_model_capabilities()
|
||||
if not model_configs:
|
||||
return []
|
||||
|
||||
@@ -202,36 +111,155 @@ class ModelProvider(ABC):
|
||||
unique=unique,
|
||||
)
|
||||
|
||||
def close(self):
|
||||
"""Clean up any resources held by the provider.
|
||||
# ------------------------------------------------------------------
|
||||
# Request execution
|
||||
# ------------------------------------------------------------------
|
||||
@abstractmethod
|
||||
def generate_content(
|
||||
self,
|
||||
prompt: str,
|
||||
model_name: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
temperature: float = 0.3,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> ModelResponse:
|
||||
"""Generate content using the model."""
|
||||
|
||||
def count_tokens(self, text: str, model_name: str) -> int:
|
||||
"""Estimate token usage for a piece of text."""
|
||||
|
||||
resolved_model = self._resolve_model_name(model_name)
|
||||
|
||||
if not text:
|
||||
return 0
|
||||
|
||||
estimated = max(1, len(text) // 4)
|
||||
logger.debug("Estimating %s tokens for model %s via character heuristic", estimated, resolved_model)
|
||||
return estimated
|
||||
|
||||
def close(self) -> None:
|
||||
"""Clean up any resources held by the provider."""
|
||||
|
||||
Default implementation does nothing.
|
||||
Subclasses should override if they hold resources that need cleanup.
|
||||
"""
|
||||
# Base implementation: no resources to clean up
|
||||
return
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Validation hooks
|
||||
# ------------------------------------------------------------------
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
"""Return ``True`` when the model resolves to an allowed capability."""
|
||||
|
||||
try:
|
||||
self.get_capabilities(model_name)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
|
||||
"""Validate model parameters against capabilities."""
|
||||
|
||||
capabilities = self.get_capabilities(model_name)
|
||||
|
||||
if not capabilities.temperature_constraint.validate(temperature):
|
||||
constraint_desc = capabilities.temperature_constraint.get_description()
|
||||
raise ValueError(f"Temperature {temperature} is invalid for model {model_name}. {constraint_desc}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Preference / registry hooks
|
||||
# ------------------------------------------------------------------
|
||||
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
|
||||
"""Get the preferred model from this provider for a given category.
|
||||
"""Get the preferred model from this provider for a given category."""
|
||||
|
||||
Args:
|
||||
category: The tool category requiring a model
|
||||
allowed_models: Pre-filtered list of model names that are allowed by restrictions
|
||||
|
||||
Returns:
|
||||
Model name if this provider has a preference, None otherwise
|
||||
"""
|
||||
# Default implementation - providers can override with specific logic
|
||||
return None
|
||||
|
||||
def get_model_registry(self) -> Optional[dict[str, Any]]:
|
||||
"""Get the model registry for providers that maintain one.
|
||||
"""Return the model registry backing this provider, if any."""
|
||||
|
||||
This is a hook method for providers like CustomProvider that maintain
|
||||
a dynamic model registry.
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Capability lookup pipeline
|
||||
# ------------------------------------------------------------------
|
||||
def _lookup_capabilities(
|
||||
self,
|
||||
canonical_name: str,
|
||||
requested_name: Optional[str] = None,
|
||||
) -> Optional[ModelCapabilities]:
|
||||
"""Return ``ModelCapabilities`` for the canonical model name."""
|
||||
|
||||
return self.get_all_model_capabilities().get(canonical_name)
|
||||
|
||||
def _ensure_model_allowed(
|
||||
self,
|
||||
capabilities: ModelCapabilities,
|
||||
canonical_name: str,
|
||||
requested_name: str,
|
||||
) -> None:
|
||||
"""Raise ``ValueError`` if the model violates restriction policy."""
|
||||
|
||||
try:
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
except Exception: # pragma: no cover - only triggered if service import breaks
|
||||
return
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
if not restriction_service:
|
||||
return
|
||||
|
||||
if restriction_service.is_allowed(self.get_provider_type(), canonical_name, requested_name):
|
||||
return
|
||||
|
||||
raise ValueError(
|
||||
f"{self.get_provider_type().value} model '{canonical_name}' is not allowed by restriction policy."
|
||||
)
|
||||
|
||||
def _finalise_capabilities(
|
||||
self,
|
||||
capabilities: ModelCapabilities,
|
||||
canonical_name: str,
|
||||
requested_name: str,
|
||||
) -> ModelCapabilities:
|
||||
"""Allow subclasses to adjust capability metadata before returning."""
|
||||
|
||||
return capabilities
|
||||
|
||||
def _raise_unsupported_model(self, model_name: str) -> None:
|
||||
"""Raise the canonical unsupported-model error."""
|
||||
|
||||
raise ValueError(f"Unsupported model '{model_name}' for provider {self.get_provider_type().value}.")
|
||||
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve model shorthand to full name.
|
||||
|
||||
This implementation uses the hook methods to support different
|
||||
model configuration sources.
|
||||
|
||||
Args:
|
||||
model_name: Model name that may be an alias
|
||||
|
||||
Returns:
|
||||
Model registry dict or None if not applicable
|
||||
Resolved model name
|
||||
"""
|
||||
# Default implementation - most providers don't have a registry
|
||||
return None
|
||||
# Get model configurations from the hook method
|
||||
model_configs = self.get_all_model_capabilities()
|
||||
|
||||
# First check if it's already a base model name (case-sensitive exact match)
|
||||
if model_name in model_configs:
|
||||
return model_name
|
||||
|
||||
# Check case-insensitively for both base models and aliases
|
||||
model_name_lower = model_name.lower()
|
||||
|
||||
# Check base model names case-insensitively
|
||||
for base_model in model_configs:
|
||||
if base_model.lower() == model_name_lower:
|
||||
return base_model
|
||||
|
||||
# Check aliases from the model configurations
|
||||
alias_map = ModelCapabilities.collect_aliases(model_configs)
|
||||
for base_model, aliases in alias_map.items():
|
||||
if any(alias.lower() == model_name_lower for alias in aliases):
|
||||
return base_model
|
||||
|
||||
# If not found, return as-is
|
||||
return model_name
|
||||
|
||||
@@ -83,117 +83,69 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
aliases = self._registry.list_aliases()
|
||||
logging.info(f"Custom provider loaded {len(models)} models with {len(aliases)} aliases")
|
||||
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve model aliases to actual model names.
|
||||
# ------------------------------------------------------------------
|
||||
# Capability surface
|
||||
# ------------------------------------------------------------------
|
||||
def _lookup_capabilities(
|
||||
self,
|
||||
canonical_name: str,
|
||||
requested_name: Optional[str] = None,
|
||||
) -> Optional[ModelCapabilities]:
|
||||
"""Return custom capabilities from the registry or generic defaults."""
|
||||
|
||||
For Ollama-style models, strips version tags (e.g., 'llama3.2:latest' -> 'llama3.2')
|
||||
since the base model name is what's typically used in API calls.
|
||||
|
||||
Args:
|
||||
model_name: Input model name or alias
|
||||
|
||||
Returns:
|
||||
Resolved model name with version tags stripped if applicable
|
||||
"""
|
||||
# First, try to resolve through registry as-is
|
||||
config = self._registry.resolve(model_name)
|
||||
|
||||
if config:
|
||||
if config.model_name != model_name:
|
||||
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
|
||||
return config.model_name
|
||||
else:
|
||||
# If not found in registry, handle version tags for local models
|
||||
# Strip version tags (anything after ':') for Ollama-style models
|
||||
if ":" in model_name:
|
||||
base_model = model_name.split(":")[0]
|
||||
logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'")
|
||||
|
||||
# Try to resolve the base model through registry
|
||||
base_config = self._registry.resolve(base_model)
|
||||
if base_config:
|
||||
logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
|
||||
return base_config.model_name
|
||||
else:
|
||||
return base_model
|
||||
else:
|
||||
# If not found in registry and no version tag, return as-is
|
||||
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
||||
return model_name
|
||||
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||
"""Get capabilities for a custom model.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model (or alias)
|
||||
|
||||
Returns:
|
||||
ModelCapabilities from registry or generic defaults
|
||||
"""
|
||||
# Try to get from registry first
|
||||
capabilities = self._registry.get_capabilities(model_name)
|
||||
builtin = super()._lookup_capabilities(canonical_name, requested_name)
|
||||
if builtin is not None:
|
||||
return builtin
|
||||
|
||||
capabilities = self._registry.get_capabilities(canonical_name)
|
||||
if capabilities:
|
||||
# Check if this is an OpenRouter model and apply restrictions
|
||||
config = self._registry.resolve(model_name)
|
||||
if config and not config.is_custom:
|
||||
# This is an OpenRouter model, check restrictions
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
if not restriction_service.is_allowed(ProviderType.OPENROUTER, config.model_name, model_name):
|
||||
raise ValueError(f"OpenRouter model '{model_name}' is not allowed by restriction policy.")
|
||||
|
||||
# Update provider type to OPENROUTER for OpenRouter models
|
||||
capabilities.provider = ProviderType.OPENROUTER
|
||||
else:
|
||||
# Update provider type to CUSTOM for local custom models
|
||||
config = self._registry.resolve(canonical_name)
|
||||
if config and getattr(config, "is_custom", False):
|
||||
capabilities.provider = ProviderType.CUSTOM
|
||||
return capabilities
|
||||
else:
|
||||
# Resolve any potential aliases and create generic capabilities
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
return capabilities
|
||||
# Non-custom models should fall through so OpenRouter handles them
|
||||
return None
|
||||
|
||||
logging.debug(
|
||||
f"Using generic capabilities for '{resolved_name}' via Custom API. "
|
||||
"Consider adding to custom_models.json for specific capabilities."
|
||||
)
|
||||
logging.debug(
|
||||
f"Using generic capabilities for '{canonical_name}' via Custom API. "
|
||||
"Consider adding to custom_models.json for specific capabilities."
|
||||
)
|
||||
|
||||
# Infer temperature behaviour for generic capability fallback
|
||||
supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings(
|
||||
resolved_name
|
||||
)
|
||||
supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings(
|
||||
canonical_name
|
||||
)
|
||||
|
||||
logging.warning(
|
||||
f"Model '{resolved_name}' not found in custom_models.json. Using generic capabilities with inferred settings. "
|
||||
f"Temperature support: {supports_temperature} ({temperature_reason}). "
|
||||
"For better accuracy, add this model to your custom_models.json configuration."
|
||||
)
|
||||
logging.warning(
|
||||
f"Model '{canonical_name}' not found in custom_models.json. Using generic capabilities with inferred settings. "
|
||||
f"Temperature support: {supports_temperature} ({temperature_reason}). "
|
||||
"For better accuracy, add this model to your custom_models.json configuration."
|
||||
)
|
||||
|
||||
# Create generic capabilities with inferred defaults
|
||||
capabilities = ModelCapabilities(
|
||||
provider=ProviderType.CUSTOM,
|
||||
model_name=resolved_name,
|
||||
friendly_name=f"{self.FRIENDLY_NAME} ({resolved_name})",
|
||||
context_window=32_768, # Conservative default
|
||||
max_output_tokens=32_768, # Conservative default max output
|
||||
supports_extended_thinking=False, # Most custom models don't support this
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # Conservative default
|
||||
supports_temperature=supports_temperature,
|
||||
temperature_constraint=temperature_constraint,
|
||||
)
|
||||
|
||||
# Mark as generic for validation purposes
|
||||
capabilities._is_generic = True
|
||||
|
||||
return capabilities
|
||||
generic = ModelCapabilities(
|
||||
provider=ProviderType.CUSTOM,
|
||||
model_name=canonical_name,
|
||||
friendly_name=f"{self.FRIENDLY_NAME} ({canonical_name})",
|
||||
context_window=32_768,
|
||||
max_output_tokens=32_768,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False,
|
||||
supports_temperature=supports_temperature,
|
||||
temperature_constraint=temperature_constraint,
|
||||
)
|
||||
generic._is_generic = True
|
||||
return generic
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Get the provider type."""
|
||||
"""Identify this provider for restriction and logging logic."""
|
||||
|
||||
return ProviderType.CUSTOM
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Validation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
"""Validate if the model name is allowed.
|
||||
|
||||
@@ -206,49 +158,41 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
Returns:
|
||||
True if model is intended for custom/local endpoint
|
||||
"""
|
||||
# logging.debug(f"Custom provider validating model: '{model_name}'")
|
||||
if super().validate_model_name(model_name):
|
||||
return True
|
||||
|
||||
# Try to resolve through registry first
|
||||
config = self._registry.resolve(model_name)
|
||||
if config:
|
||||
model_id = config.model_name
|
||||
# Use explicit is_custom flag for clean validation
|
||||
if config.is_custom:
|
||||
logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' validated via registry")
|
||||
return True
|
||||
else:
|
||||
# This is a cloud/OpenRouter model - CustomProvider should NOT handle these
|
||||
# Let OpenRouter provider handle them instead
|
||||
# logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' not custom (defer to OpenRouter)")
|
||||
return False
|
||||
if config and not getattr(config, "is_custom", False):
|
||||
return False
|
||||
|
||||
# Handle version tags for unknown models (e.g., "my-model:latest")
|
||||
clean_model_name = model_name
|
||||
if ":" in model_name:
|
||||
clean_model_name = model_name.split(":")[0]
|
||||
clean_model_name = model_name.split(":", 1)[0]
|
||||
logging.debug(f"Stripped version tag from '{model_name}' -> '{clean_model_name}'")
|
||||
# Try to resolve the clean name
|
||||
|
||||
if super().validate_model_name(clean_model_name):
|
||||
return True
|
||||
|
||||
config = self._registry.resolve(clean_model_name)
|
||||
if config:
|
||||
return self.validate_model_name(clean_model_name) # Recursively validate clean name
|
||||
if config and not getattr(config, "is_custom", False):
|
||||
return False
|
||||
|
||||
# For unknown models (not in registry), only accept if they look like local models
|
||||
# This maintains backward compatibility for custom models not yet in the registry
|
||||
|
||||
# Accept models with explicit local indicators in the name
|
||||
if any(indicator in clean_model_name.lower() for indicator in ["local", "ollama", "vllm", "lmstudio"]):
|
||||
lowered = clean_model_name.lower()
|
||||
if any(indicator in lowered for indicator in ["local", "ollama", "vllm", "lmstudio"]):
|
||||
logging.debug(f"Model '{clean_model_name}' validated via local indicators")
|
||||
return True
|
||||
|
||||
# Accept simple model names without vendor prefix (likely local/custom models)
|
||||
if "/" not in clean_model_name:
|
||||
logging.debug(f"Model '{clean_model_name}' validated as potential local model (no vendor prefix)")
|
||||
return True
|
||||
|
||||
# Reject everything else (likely cloud models not in registry)
|
||||
logging.debug(f"Model '{model_name}' rejected by custom provider (appears to be cloud model)")
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Request execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def generate_content(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -284,25 +228,41 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
|
||||
"""Get model configurations from the registry.
|
||||
# ------------------------------------------------------------------
|
||||
# Registry helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
For CustomProvider, we convert registry configurations to ModelCapabilities objects.
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve registry aliases and strip version tags for local models."""
|
||||
|
||||
Returns:
|
||||
Dictionary mapping model names to their ModelCapabilities objects
|
||||
"""
|
||||
config = self._registry.resolve(model_name)
|
||||
if config:
|
||||
if config.model_name != model_name:
|
||||
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
|
||||
return config.model_name
|
||||
|
||||
configs = {}
|
||||
if ":" in model_name:
|
||||
base_model = model_name.split(":")[0]
|
||||
logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'")
|
||||
|
||||
if self._registry:
|
||||
# Get all models from registry
|
||||
for model_name in self._registry.list_models():
|
||||
# Only include custom models that this provider validates
|
||||
if self.validate_model_name(model_name):
|
||||
config = self._registry.resolve(model_name)
|
||||
if config and config.is_custom:
|
||||
# Use ModelCapabilities directly from registry
|
||||
configs[model_name] = config
|
||||
base_config = self._registry.resolve(base_model)
|
||||
if base_config:
|
||||
logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
|
||||
return base_config.model_name
|
||||
return base_model
|
||||
|
||||
return configs
|
||||
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
||||
return model_name
|
||||
|
||||
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||
"""Expose registry capabilities for models marked as custom."""
|
||||
|
||||
if not self._registry:
|
||||
return {}
|
||||
|
||||
capabilities: dict[str, ModelCapabilities] = {}
|
||||
for model_name in self._registry.list_models():
|
||||
config = self._registry.resolve(model_name)
|
||||
if config and getattr(config, "is_custom", False):
|
||||
capabilities[model_name] = config
|
||||
return capabilities
|
||||
|
||||
@@ -261,68 +261,10 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
|
||||
logger.info(f"Initialized DIAL provider with host: {dial_host} and api-version: {self.api_version}")
|
||||
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||
"""Get capabilities for a specific model.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model (can be shorthand)
|
||||
|
||||
Returns:
|
||||
ModelCapabilities object
|
||||
|
||||
Raises:
|
||||
ValueError: If model is not supported or not allowed
|
||||
"""
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
if resolved_name not in self.MODEL_CAPABILITIES:
|
||||
raise ValueError(f"Unsupported DIAL model: {model_name}")
|
||||
|
||||
# Check restrictions
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
|
||||
raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")
|
||||
|
||||
# Return the ModelCapabilities object directly from MODEL_CAPABILITIES
|
||||
return self.MODEL_CAPABILITIES[resolved_name]
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Get the provider type."""
|
||||
return ProviderType.DIAL
|
||||
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
"""Validate if the model name is supported.
|
||||
|
||||
Args:
|
||||
model_name: Model name to validate
|
||||
|
||||
Returns:
|
||||
True if model is supported and allowed, False otherwise
|
||||
"""
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
if resolved_name not in self.MODEL_CAPABILITIES:
|
||||
return False
|
||||
|
||||
# Check against base class allowed_models if configured
|
||||
if self.allowed_models is not None:
|
||||
# Check both original and resolved names (case-insensitive)
|
||||
if model_name.lower() not in self.allowed_models and resolved_name.lower() not in self.allowed_models:
|
||||
logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' not in allowed_models list")
|
||||
return False
|
||||
|
||||
# Also check restrictions via ModelRestrictionService
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
|
||||
logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' blocked by restrictions")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _get_deployment_client(self, deployment: str):
|
||||
"""Get or create a cached client for a specific deployment.
|
||||
|
||||
@@ -504,7 +446,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
|
||||
)
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Clean up HTTP clients when provider is closed."""
|
||||
logger.info("Closing DIAL provider HTTP clients...")
|
||||
|
||||
|
||||
@@ -131,6 +131,19 @@ class GeminiModelProvider(ModelProvider):
|
||||
self._token_counters = {} # Cache for token counting
|
||||
self._base_url = kwargs.get("base_url", None) # Optional custom endpoint
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Capability surface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||
"""Return statically defined Gemini capabilities."""
|
||||
|
||||
return dict(self.MODEL_CAPABILITIES)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Client access
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""Lazy initialization of Gemini client."""
|
||||
@@ -146,25 +159,9 @@ class GeminiModelProvider(ModelProvider):
|
||||
self._client = genai.Client(api_key=self.api_key)
|
||||
return self._client
|
||||
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||
"""Get capabilities for a specific Gemini model."""
|
||||
# Resolve shorthand
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
if resolved_name not in self.MODEL_CAPABILITIES:
|
||||
raise ValueError(f"Unsupported Gemini model: {model_name}")
|
||||
|
||||
# Check if model is allowed by restrictions
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
# IMPORTANT: Parameter order is (provider_type, model_name, original_name)
|
||||
# resolved_name is the canonical model name, model_name is the user input
|
||||
if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
|
||||
raise ValueError(f"Gemini model '{resolved_name}' is not allowed by restriction policy.")
|
||||
|
||||
# Return the ModelCapabilities object directly from MODEL_CAPABILITIES
|
||||
return self.MODEL_CAPABILITIES[resolved_name]
|
||||
# ------------------------------------------------------------------
|
||||
# Request execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def generate_content(
|
||||
self,
|
||||
@@ -365,26 +362,6 @@ class GeminiModelProvider(ModelProvider):
|
||||
"""Get the provider type."""
|
||||
return ProviderType.GOOGLE
|
||||
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
"""Validate if the model name is supported and allowed."""
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
# First check if model is supported
|
||||
if resolved_name not in self.MODEL_CAPABILITIES:
|
||||
return False
|
||||
|
||||
# Then check if model is allowed by restrictions
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
# IMPORTANT: Parameter order is (provider_type, model_name, original_name)
|
||||
# resolved_name is the canonical model name, model_name is the user input
|
||||
if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
|
||||
logger.debug(f"Gemini model '{model_name}' -> '{resolved_name}' blocked by restrictions")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int:
|
||||
"""Get actual thinking token budget for a model and thinking mode."""
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
@@ -5,7 +5,6 @@ import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -61,6 +60,33 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
"This may be insecure. Consider setting an API key for authentication."
|
||||
)
|
||||
|
||||
def _ensure_model_allowed(
|
||||
self,
|
||||
capabilities: ModelCapabilities,
|
||||
canonical_name: str,
|
||||
requested_name: str,
|
||||
) -> None:
|
||||
"""Respect provider-specific allowlists before default restriction checks."""
|
||||
|
||||
super()._ensure_model_allowed(capabilities, canonical_name, requested_name)
|
||||
|
||||
if self.allowed_models is not None:
|
||||
requested = requested_name.lower()
|
||||
canonical = canonical_name.lower()
|
||||
|
||||
if requested not in self.allowed_models and canonical not in self.allowed_models:
|
||||
raise ValueError(
|
||||
f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
|
||||
)
|
||||
|
||||
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||
"""Return statically declared capabilities for OpenAI-compatible providers."""
|
||||
|
||||
model_map = getattr(self, "MODEL_CAPABILITIES", None)
|
||||
if isinstance(model_map, dict):
|
||||
return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)}
|
||||
return {}
|
||||
|
||||
def _parse_allowed_models(self) -> Optional[set[str]]:
|
||||
"""Parse allowed models from environment variable.
|
||||
|
||||
@@ -686,30 +712,6 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
|
||||
return super().count_tokens(text, model_name)
|
||||
|
||||
@abstractmethod
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||
"""Get capabilities for a specific model.
|
||||
|
||||
Must be implemented by subclasses.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Get the provider type.
|
||||
|
||||
Must be implemented by subclasses.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
"""Validate if the model name is supported.
|
||||
|
||||
Must be implemented by subclasses.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _is_error_retryable(self, error: Exception) -> bool:
|
||||
"""Determine if an error should be retried based on structured error codes.
|
||||
|
||||
|
||||
@@ -174,106 +174,61 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
kwargs.setdefault("base_url", "https://api.openai.com/v1")
|
||||
super().__init__(api_key, **kwargs)
|
||||
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||
"""Get capabilities for a specific OpenAI model."""
|
||||
# First check if it's a key in MODEL_CAPABILITIES
|
||||
if model_name in self.MODEL_CAPABILITIES:
|
||||
self._check_model_restrictions(model_name, model_name)
|
||||
return self.MODEL_CAPABILITIES[model_name]
|
||||
# ------------------------------------------------------------------
|
||||
# Capability surface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# Try resolving as alias
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
def _lookup_capabilities(
|
||||
self,
|
||||
canonical_name: str,
|
||||
requested_name: Optional[str] = None,
|
||||
) -> Optional[ModelCapabilities]:
|
||||
"""Look up OpenAI capabilities from built-ins or the custom registry."""
|
||||
|
||||
# Check if resolved name is a key
|
||||
if resolved_name in self.MODEL_CAPABILITIES:
|
||||
self._check_model_restrictions(resolved_name, model_name)
|
||||
return self.MODEL_CAPABILITIES[resolved_name]
|
||||
builtin = super()._lookup_capabilities(canonical_name, requested_name)
|
||||
if builtin is not None:
|
||||
return builtin
|
||||
|
||||
# Finally check if resolved name matches any API model name
|
||||
for key, capabilities in self.MODEL_CAPABILITIES.items():
|
||||
if resolved_name == capabilities.model_name:
|
||||
self._check_model_restrictions(key, model_name)
|
||||
return capabilities
|
||||
|
||||
# Check custom models registry for user-configured OpenAI models
|
||||
try:
|
||||
from .openrouter_registry import OpenRouterModelRegistry
|
||||
|
||||
registry = OpenRouterModelRegistry()
|
||||
config = registry.get_model_config(resolved_name)
|
||||
config = registry.get_model_config(canonical_name)
|
||||
|
||||
if config and config.provider == ProviderType.OPENAI:
|
||||
self._check_model_restrictions(config.model_name, model_name)
|
||||
|
||||
# Update provider type to ensure consistency
|
||||
config.provider = ProviderType.OPENAI
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
# Log but don't fail - registry might not be available
|
||||
logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}")
|
||||
except Exception as exc: # pragma: no cover - registry failures are non-critical
|
||||
logger.debug(f"Could not resolve custom OpenAI model '{canonical_name}': {exc}")
|
||||
|
||||
return None
|
||||
|
||||
def _finalise_capabilities(
|
||||
self,
|
||||
capabilities: ModelCapabilities,
|
||||
canonical_name: str,
|
||||
requested_name: str,
|
||||
) -> ModelCapabilities:
|
||||
"""Ensure registry-sourced models report the correct provider type."""
|
||||
|
||||
if capabilities.provider != ProviderType.OPENAI:
|
||||
capabilities.provider = ProviderType.OPENAI
|
||||
return capabilities
|
||||
|
||||
def _raise_unsupported_model(self, model_name: str) -> None:
|
||||
raise ValueError(f"Unsupported OpenAI model: {model_name}")
|
||||
|
||||
def _check_model_restrictions(self, provider_model_name: str, user_model_name: str) -> None:
|
||||
"""Check if a model is allowed by restriction policy.
|
||||
|
||||
Args:
|
||||
provider_model_name: The model name used by the provider
|
||||
user_model_name: The model name requested by the user
|
||||
|
||||
Raises:
|
||||
ValueError: If the model is not allowed by restriction policy
|
||||
"""
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
if not restriction_service.is_allowed(ProviderType.OPENAI, provider_model_name, user_model_name):
|
||||
raise ValueError(f"OpenAI model '{user_model_name}' is not allowed by restriction policy.")
|
||||
# ------------------------------------------------------------------
|
||||
# Provider identity
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Get the provider type."""
|
||||
return ProviderType.OPENAI
|
||||
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
"""Validate if the model name is supported and allowed."""
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
# First, determine which model name to check against restrictions.
|
||||
model_to_check = None
|
||||
is_custom_model = False
|
||||
|
||||
if resolved_name in self.MODEL_CAPABILITIES:
|
||||
model_to_check = resolved_name
|
||||
else:
|
||||
# If not a built-in model, check the custom models registry.
|
||||
try:
|
||||
from .openrouter_registry import OpenRouterModelRegistry
|
||||
|
||||
registry = OpenRouterModelRegistry()
|
||||
config = registry.get_model_config(resolved_name)
|
||||
|
||||
if config and config.provider == ProviderType.OPENAI:
|
||||
model_to_check = config.model_name
|
||||
is_custom_model = True
|
||||
except Exception as e:
|
||||
# Log but don't fail - registry might not be available.
|
||||
logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}")
|
||||
|
||||
# If no model was found (neither built-in nor custom), it's invalid.
|
||||
if not model_to_check:
|
||||
return False
|
||||
|
||||
# Now, perform the restriction check once.
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
if not restriction_service.is_allowed(ProviderType.OPENAI, model_to_check, model_name):
|
||||
model_type = "custom " if is_custom_model else ""
|
||||
logger.debug(f"OpenAI {model_type}model '{model_name}' -> '{resolved_name}' blocked by restrictions")
|
||||
return False
|
||||
|
||||
return True
|
||||
# ------------------------------------------------------------------
|
||||
# Request execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def generate_content(
|
||||
self,
|
||||
@@ -298,6 +253,10 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Provider preferences
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
|
||||
"""Get OpenAI's preferred model for a given category from allowed models.
|
||||
|
||||
|
||||
@@ -61,108 +61,52 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
aliases = self._registry.list_aliases()
|
||||
logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases")
|
||||
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve model aliases to OpenRouter model names.
|
||||
# ------------------------------------------------------------------
|
||||
# Capability surface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
Args:
|
||||
model_name: Input model name or alias
|
||||
|
||||
Returns:
|
||||
Resolved OpenRouter model name
|
||||
"""
|
||||
# Try to resolve through registry
|
||||
config = self._registry.resolve(model_name)
|
||||
|
||||
if config:
|
||||
if config.model_name != model_name:
|
||||
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
|
||||
return config.model_name
|
||||
else:
|
||||
# If not found in registry, return as-is
|
||||
# This allows using models not in our config file
|
||||
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
||||
return model_name
|
||||
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||
"""Get capabilities for a model.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model (or alias)
|
||||
|
||||
Returns:
|
||||
ModelCapabilities from registry or generic defaults
|
||||
"""
|
||||
# Try to get from registry first
|
||||
capabilities = self._registry.get_capabilities(model_name)
|
||||
def _lookup_capabilities(
|
||||
self,
|
||||
canonical_name: str,
|
||||
requested_name: Optional[str] = None,
|
||||
) -> Optional[ModelCapabilities]:
|
||||
"""Fetch OpenRouter capabilities from the registry or build a generic fallback."""
|
||||
|
||||
capabilities = self._registry.get_capabilities(canonical_name)
|
||||
if capabilities:
|
||||
return capabilities
|
||||
else:
|
||||
# Resolve any potential aliases and create generic capabilities
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
logging.debug(
|
||||
f"Using generic capabilities for '{resolved_name}' via OpenRouter. "
|
||||
"Consider adding to custom_models.json for specific capabilities."
|
||||
)
|
||||
logging.debug(
|
||||
f"Using generic capabilities for '{canonical_name}' via OpenRouter. "
|
||||
"Consider adding to custom_models.json for specific capabilities."
|
||||
)
|
||||
|
||||
# Create generic capabilities with conservative defaults
|
||||
capabilities = ModelCapabilities(
|
||||
provider=ProviderType.OPENROUTER,
|
||||
model_name=resolved_name,
|
||||
friendly_name=self.FRIENDLY_NAME,
|
||||
context_window=32_768, # Conservative default context window
|
||||
max_output_tokens=32_768,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False,
|
||||
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
|
||||
)
|
||||
generic = ModelCapabilities(
|
||||
provider=ProviderType.OPENROUTER,
|
||||
model_name=canonical_name,
|
||||
friendly_name=self.FRIENDLY_NAME,
|
||||
context_window=32_768,
|
||||
max_output_tokens=32_768,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False,
|
||||
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
|
||||
)
|
||||
generic._is_generic = True
|
||||
return generic
|
||||
|
||||
# Mark as generic for validation purposes
|
||||
capabilities._is_generic = True
|
||||
|
||||
return capabilities
|
||||
# ------------------------------------------------------------------
|
||||
# Provider identity
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Get the provider type."""
|
||||
"""Identify this provider for restrictions and logging."""
|
||||
return ProviderType.OPENROUTER
|
||||
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
"""Validate if the model name is allowed.
|
||||
|
||||
As the catch-all provider, OpenRouter accepts any model name that wasn't
|
||||
handled by higher-priority providers. OpenRouter will validate based on
|
||||
the API key's permissions and local restrictions.
|
||||
|
||||
Args:
|
||||
model_name: Model name to validate
|
||||
|
||||
Returns:
|
||||
True if model is allowed, False if restricted
|
||||
"""
|
||||
# Check model restrictions if configured
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
if restriction_service:
|
||||
# Check if model name itself is allowed
|
||||
if restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||
return True
|
||||
|
||||
# Also check aliases - model_name might be an alias
|
||||
model_config = self._registry.resolve(model_name)
|
||||
if model_config and model_config.aliases:
|
||||
for alias in model_config.aliases:
|
||||
if restriction_service.is_allowed(self.get_provider_type(), alias):
|
||||
return True
|
||||
|
||||
# If restrictions are configured and model/alias not in allowed list, reject
|
||||
return False
|
||||
|
||||
# No restrictions configured - accept any model name as the fallback provider
|
||||
return True
|
||||
# ------------------------------------------------------------------
|
||||
# Request execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def generate_content(
|
||||
self,
|
||||
@@ -204,6 +148,10 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Registry helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def list_models(
|
||||
self,
|
||||
*,
|
||||
@@ -227,6 +175,12 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
if not config:
|
||||
continue
|
||||
|
||||
# Custom models belong to CustomProvider; skip them here so the two
|
||||
# providers don't race over the same registrations (important for tests
|
||||
# that stub the registry with minimal objects lacking attrs).
|
||||
if hasattr(config, "is_custom") and config.is_custom is True:
|
||||
continue
|
||||
|
||||
if restriction_service:
|
||||
allowed = restriction_service.is_allowed(self.get_provider_type(), model_name)
|
||||
|
||||
@@ -255,24 +209,37 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
unique=unique,
|
||||
)
|
||||
|
||||
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
|
||||
"""Get model configurations from the registry.
|
||||
# ------------------------------------------------------------------
|
||||
# Registry helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
For OpenRouter, we convert registry configurations to ModelCapabilities objects.
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve aliases defined in the OpenRouter registry."""
|
||||
|
||||
Returns:
|
||||
Dictionary mapping model names to their ModelCapabilities objects
|
||||
"""
|
||||
configs = {}
|
||||
config = self._registry.resolve(model_name)
|
||||
if config:
|
||||
if config.model_name != model_name:
|
||||
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
|
||||
return config.model_name
|
||||
|
||||
if self._registry:
|
||||
# Get all models from registry
|
||||
for model_name in self._registry.list_models():
|
||||
# Only include models that this provider validates
|
||||
if self.validate_model_name(model_name):
|
||||
config = self._registry.resolve(model_name)
|
||||
if config and not config.is_custom: # Only OpenRouter models, not custom ones
|
||||
# Use ModelCapabilities directly from registry
|
||||
configs[model_name] = config
|
||||
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
||||
return model_name
|
||||
|
||||
return configs
|
||||
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||
"""Expose registry-backed OpenRouter capabilities."""
|
||||
|
||||
if not self._registry:
|
||||
return {}
|
||||
|
||||
capabilities: dict[str, ModelCapabilities] = {}
|
||||
for model_name in self._registry.list_models():
|
||||
config = self._registry.resolve(model_name)
|
||||
if not config:
|
||||
continue
|
||||
|
||||
# See note in list_models: respect the CustomProvider boundary.
|
||||
if hasattr(config, "is_custom") and config.is_custom is True:
|
||||
continue
|
||||
|
||||
capabilities[model_name] = config
|
||||
return capabilities
|
||||
|
||||
@@ -64,6 +64,8 @@ class ModelProviderRegistry:
|
||||
"""
|
||||
instance = cls()
|
||||
instance._providers[provider_type] = provider_class
|
||||
# Invalidate any cached instance so subsequent lookups use the new registration
|
||||
instance._initialized_providers.pop(provider_type, None)
|
||||
|
||||
@classmethod
|
||||
def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]:
|
||||
|
||||
@@ -85,46 +85,10 @@ class XAIModelProvider(OpenAICompatibleProvider):
|
||||
kwargs.setdefault("base_url", "https://api.x.ai/v1")
|
||||
super().__init__(api_key, **kwargs)
|
||||
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||
"""Get capabilities for a specific X.AI model."""
|
||||
# Resolve shorthand
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
if resolved_name not in self.MODEL_CAPABILITIES:
|
||||
raise ValueError(f"Unsupported X.AI model: {model_name}")
|
||||
|
||||
# Check if model is allowed by restrictions
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
|
||||
raise ValueError(f"X.AI model '{model_name}' is not allowed by restriction policy.")
|
||||
|
||||
# Return the ModelCapabilities object directly from MODEL_CAPABILITIES
|
||||
return self.MODEL_CAPABILITIES[resolved_name]
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Get the provider type."""
|
||||
return ProviderType.XAI
|
||||
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
"""Validate if the model name is supported and allowed."""
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
# First check if model is supported
|
||||
if resolved_name not in self.MODEL_CAPABILITIES:
|
||||
return False
|
||||
|
||||
# Then check if model is allowed by restrictions
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
|
||||
logger.debug(f"X.AI model '{model_name}' -> '{resolved_name}' blocked by restrictions")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def generate_content(
|
||||
self,
|
||||
prompt: str,
|
||||
|
||||
Reference in New Issue
Block a user