diff --git a/docs/adding_providers.md b/docs/adding_providers.md index 0a62c8c..21abd53 100644 --- a/docs/adding_providers.md +++ b/docs/adding_providers.md @@ -7,7 +7,7 @@ This guide explains how to add support for a new AI model provider to the Zen MC Each provider: - Inherits from `ModelProvider` (base class) or `OpenAICompatibleProvider` (for OpenAI-compatible APIs) - Defines supported models using `ModelCapabilities` objects -- Implements a few core abstract methods +- Implements the minimal abstract hooks (`get_provider_type()` and `generate_content()`) - Gets registered automatically via environment variables ## Choose Your Implementation Path @@ -15,11 +15,11 @@ Each provider: **Option A: Full Provider (`ModelProvider`)** - For APIs with unique features or custom authentication - Complete control over API calls and response handling -- Required methods: `generate_content()`, `get_capabilities()`, `validate_model_name()`, `get_provider_type()` (override `count_tokens()` only when you have a provider-accurate tokenizer) +- Implement `generate_content()` and `get_provider_type()`; override `get_all_model_capabilities()` to expose your catalogue and extend `_lookup_capabilities()` / `_ensure_model_allowed()` only when you need registry lookups or custom restriction rules (override `count_tokens()` only when you have a provider-accurate tokenizer) **Option B: OpenAI-Compatible (`OpenAICompatibleProvider`)** - For APIs that follow OpenAI's chat completion format -- Only need to define: model configurations, capabilities, and validation +- Supply `MODEL_CAPABILITIES`, override `get_provider_type()`, and optionally adjust configuration (the base class handles alias resolution, validation, and request wiring) - Inherits all API handling automatically -⚠️ **Important**: If using aliases (like `"gpt"` → `"gpt-4"`), override `generate_content()` to resolve them before API calls. +⚠️ **Note**: Aliases (like `"gpt"` → `"gpt-4"`) are resolved automatically by the shared capability pipeline; override `generate_content()` only when you need alias handling beyond the base behaviour.
@@ -62,8 +62,7 @@ logger = logging.getLogger(__name__) class ExampleModelProvider(ModelProvider): """Example model provider implementation.""" - - # Define models using ModelCapabilities objects (like Gemini provider) + MODEL_CAPABILITIES = { "example-large": ModelCapabilities( provider=ProviderType.EXAMPLE, @@ -87,51 +86,47 @@ class ExampleModelProvider(ModelProvider): aliases=["small", "fast"], ), } - + def __init__(self, api_key: str, **kwargs): super().__init__(api_key, **kwargs) # Initialize your API client here - - def get_capabilities(self, model_name: str) -> ModelCapabilities: + + def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]: + return dict(self.MODEL_CAPABILITIES) + + def get_provider_type(self) -> ProviderType: + return ProviderType.EXAMPLE + + def generate_content( + self, + prompt: str, + model_name: str, + system_prompt: Optional[str] = None, + temperature: float = 0.7, + max_output_tokens: Optional[int] = None, + **kwargs, + ) -> ModelResponse: resolved_name = self._resolve_model_name(model_name) - - if resolved_name not in self.MODEL_CAPABILITIES: - raise ValueError(f"Unsupported model: {model_name}") - - # Apply restrictions if needed - from utils.model_restrictions import get_restriction_service - restriction_service = get_restriction_service() - if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name): - raise ValueError(f"Model '{model_name}' is not allowed.") - - return self.MODEL_CAPABILITIES[resolved_name] - - def generate_content(self, prompt: str, model_name: str, system_prompt: Optional[str] = None, - temperature: float = 0.7, max_output_tokens: Optional[int] = None, **kwargs) -> ModelResponse: - resolved_name = self._resolve_model_name(model_name) - + # Your API call logic here # response = your_api_client.generate(...) - + return ModelResponse( - content="Generated response", # From your API + content="Generated response", usage={"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}, model_name=resolved_name, friendly_name="Example", provider=ProviderType.EXAMPLE, ) - def get_provider_type(self) -> ProviderType: - return ProviderType.EXAMPLE - - def validate_model_name(self, model_name: str) -> bool: - resolved_name = self._resolve_model_name(model_name) - return resolved_name in self.MODEL_CAPABILITIES ``` -`ModelProvider.count_tokens()` uses a simple 4-characters-per-token estimate so -providers work out of the box. Override the method only when you can call into -the provider's real tokenizer (for example, the OpenAI-compatible base class -already integrates `tiktoken`). +`ModelProvider.get_capabilities()` automatically resolves aliases, enforces the +shared restriction service, and returns the correct `ModelCapabilities` +instance. Override `_lookup_capabilities()` only when you source capabilities +from a registry or remote API. `ModelProvider.count_tokens()` uses a simple +4-characters-per-token estimate so providers work out of the box—override it +only when you can call the provider's real tokenizer (for example, the +OpenAI-compatible base class integrates `tiktoken`). 
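+A minimal sketch of a registry-backed lookup (illustrative only; `self._registry` is a hypothetical helper exposing `get_capabilities()`, the same pattern `CustomProvider` and `OpenRouterProvider` use):
+
+```python
+class RegistryBackedProvider(ExampleModelProvider):
+    def _lookup_capabilities(
+        self,
+        canonical_name: str,
+        requested_name: Optional[str] = None,
+    ) -> Optional[ModelCapabilities]:
+        # Statically declared models win, mirroring the base implementation.
+        builtin = super()._lookup_capabilities(canonical_name, requested_name)
+        if builtin is not None:
+            return builtin
+        # Hypothetical dynamic source; returning None lets get_capabilities()
+        # raise the canonical unsupported-model error.
+        return self._registry.get_capabilities(canonical_name)
+```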
#### Option B: OpenAI-Compatible Provider (Simplified) @@ -172,26 +167,16 @@ class ExampleProvider(OpenAICompatibleProvider): def __init__(self, api_key: str, **kwargs): kwargs.setdefault("base_url", "https://api.example.com/v1") super().__init__(api_key, **kwargs) - - def get_capabilities(self, model_name: str) -> ModelCapabilities: - resolved_name = self._resolve_model_name(model_name) - if resolved_name not in self.MODEL_CAPABILITIES: - raise ValueError(f"Unsupported model: {model_name}") - return self.MODEL_CAPABILITIES[resolved_name] - + def get_provider_type(self) -> ProviderType: return ProviderType.EXAMPLE - - def validate_model_name(self, model_name: str) -> bool: - resolved_name = self._resolve_model_name(model_name) - return resolved_name in self.MODEL_CAPABILITIES - - def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse: - # IMPORTANT: Resolve aliases before API call - resolved_model_name = self._resolve_model_name(model_name) - return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs) ``` +`OpenAICompatibleProvider` already exposes the declared models via +`MODEL_CAPABILITIES`, resolves aliases through the shared base pipeline, and +enforces restrictions. Most subclasses only need to provide the class metadata +shown above. + ### 3. Register Your Provider Add environment variable mapping in `providers/registry.py`: @@ -237,15 +222,11 @@ DISABLED_TOOLS=debug,tracer Create basic tests to verify your implementation: ```python -# Test model validation -provider = ExampleModelProvider("test-key") -assert provider.validate_model_name("large") == True -assert provider.validate_model_name("unknown") == False - # Test capabilities -caps = provider.get_capabilities("large") -assert caps.context_window > 0 -assert caps.provider == ProviderType.EXAMPLE +provider = ExampleModelProvider("test-key") +capabilities = provider.get_capabilities("large") +assert capabilities.context_window > 0 +assert capabilities.provider == ProviderType.EXAMPLE ``` @@ -259,31 +240,19 @@ When a user requests a model, providers are checked in priority order: 3. **OpenRouter** - catch-all for everything else ### Model Validation -Your `validate_model_name()` should **only** return `True` for models you explicitly support: - -```python -def validate_model_name(self, model_name: str) -> bool: - resolved_name = self._resolve_model_name(model_name) - return resolved_name in self.MODEL_CAPABILITIES # Be specific! -``` +`ModelProvider.validate_model_name()` delegates to `get_capabilities()` so most +providers can rely on the shared implementation. Override it only when you need +to opt out of that pipeline—for example, `CustomProvider` declines OpenRouter +models so they fall through to the dedicated OpenRouter provider. ### Model Aliases -The base class handles alias resolution automatically via the `aliases` field in `ModelCapabilities`. +Aliases declared on `ModelCapabilities` are applied automatically via +`_resolve_model_name()`, and both the validation and request flows call it +before touching your SDK. Override `generate_content()` only when your provider +needs additional alias handling beyond the shared behaviour. 
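+Illustrative check of the shared pipeline (assuming `"large"` is declared as an alias of `"example-large"`, as in the Option A example):
+
+```python
+provider = ExampleModelProvider("test-key")
+
+# Aliases resolve case-insensitively before any lookup happens.
+assert provider._resolve_model_name("large") == "example-large"
+
+# validate_model_name() calls get_capabilities() and catches ValueError.
+assert provider.validate_model_name("large")
+assert not provider.validate_model_name("unknown-model")
+```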
## Important Notes -### Alias Resolution in OpenAI-Compatible Providers -If using `OpenAICompatibleProvider` with aliases, **you must override `generate_content()`** to resolve aliases before API calls: - -```python -def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse: - # Resolve alias before API call - resolved_model_name = self._resolve_model_name(model_name) - return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs) -``` - -Without this, API calls with aliases like `"large"` will fail because your API doesn't recognize the alias. - ## Best Practices - **Be specific in model validation** - only accept models you actually support diff --git a/providers/base.py b/providers/base.py index d959832..05b688a 100644 --- a/providers/base.py +++ b/providers/base.py @@ -43,128 +43,37 @@ class ModelProvider(ABC): self.api_key = api_key self.config = kwargs - @abstractmethod - def get_capabilities(self, model_name: str) -> ModelCapabilities: - """Get capabilities for a specific model.""" - pass - - @abstractmethod - def generate_content( - self, - prompt: str, - model_name: str, - system_prompt: Optional[str] = None, - temperature: float = 0.3, - max_output_tokens: Optional[int] = None, - **kwargs, - ) -> ModelResponse: - """Generate content using the model. - - Args: - prompt: User prompt to send to the model - model_name: Name of the model to use - system_prompt: Optional system prompt for model behavior - temperature: Sampling temperature (0-2) - max_output_tokens: Maximum tokens to generate - **kwargs: Provider-specific parameters - - Returns: - ModelResponse with generated content and metadata - """ - pass - - def count_tokens(self, text: str, model_name: str) -> int: - """Estimate token usage for a piece of text. - - Providers can rely on this shared implementation or override it when - they expose a more accurate tokenizer. This default uses a simple - character-based heuristic so it works even without provider-specific - tooling. - """ - - resolved_model = self._resolve_model_name(model_name) - - if not text: - return 0 - - # Rough estimation: ~4 characters per token for English text - estimated = max(1, len(text) // 4) - logger.debug("Estimating %s tokens for model %s via character heuristic", estimated, resolved_model) - return estimated - + # ------------------------------------------------------------------ + # Provider identity & capability surface + # ------------------------------------------------------------------ @abstractmethod def get_provider_type(self) -> ProviderType: - """Get the provider type.""" - pass + """Return the concrete provider identity.""" - @abstractmethod - def validate_model_name(self, model_name: str) -> bool: - """Validate if the model name is supported by this provider.""" - pass + def get_capabilities(self, model_name: str) -> ModelCapabilities: + """Resolve capability metadata for a model name. - def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None: - """Validate model parameters against capabilities. - - Raises: - ValueError: If parameters are invalid + This centralises the alias resolution → lookup → restriction check + pipeline so providers only override the pieces they genuinely need to + customise. Subclasses usually only override ``_lookup_capabilities`` to + integrate a registry or dynamic source, or ``_finalise_capabilities`` to + tweak the returned object. 
""" - capabilities = self.get_capabilities(model_name) - # Validate temperature using constraint - if not capabilities.temperature_constraint.validate(temperature): - constraint_desc = capabilities.temperature_constraint.get_description() - raise ValueError(f"Temperature {temperature} is invalid for model {model_name}. {constraint_desc}") + resolved_name = self._resolve_model_name(model_name) + capabilities = self._lookup_capabilities(resolved_name, model_name) - def get_model_configurations(self) -> dict[str, ModelCapabilities]: - """Get model configurations for this provider. + if capabilities is None: + self._raise_unsupported_model(model_name) - This is a hook method that subclasses can override to provide - their model configurations from different sources. + self._ensure_model_allowed(capabilities, resolved_name, model_name) + return self._finalise_capabilities(capabilities, resolved_name, model_name) + + def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]: + """Return the provider's statically declared model capabilities.""" - Returns: - Dictionary mapping model names to their ModelCapabilities objects - """ - model_map = getattr(self, "MODEL_CAPABILITIES", None) - if isinstance(model_map, dict) and model_map: - return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)} return {} - def _resolve_model_name(self, model_name: str) -> str: - """Resolve model shorthand to full name. - - This implementation uses the hook methods to support different - model configuration sources. - - Args: - model_name: Model name that may be an alias - - Returns: - Resolved model name - """ - # Get model configurations from the hook method - model_configs = self.get_model_configurations() - - # First check if it's already a base model name (case-sensitive exact match) - if model_name in model_configs: - return model_name - - # Check case-insensitively for both base models and aliases - model_name_lower = model_name.lower() - - # Check base model names case-insensitively - for base_model in model_configs: - if base_model.lower() == model_name_lower: - return base_model - - # Check aliases from the model configurations - alias_map = ModelCapabilities.collect_aliases(model_configs) - for base_model, aliases in alias_map.items(): - if any(alias.lower() == model_name_lower for alias in aliases): - return base_model - - # If not found, return as-is - return model_name - def list_models( self, *, @@ -175,7 +84,7 @@ class ModelProvider(ABC): ) -> list[str]: """Return formatted model names supported by this provider.""" - model_configs = self.get_model_configurations() + model_configs = self.get_all_model_capabilities() if not model_configs: return [] @@ -202,36 +111,155 @@ class ModelProvider(ABC): unique=unique, ) - def close(self): - """Clean up any resources held by the provider. 
+ # ------------------------------------------------------------------ + # Request execution + # ------------------------------------------------------------------ + @abstractmethod + def generate_content( + self, + prompt: str, + model_name: str, + system_prompt: Optional[str] = None, + temperature: float = 0.3, + max_output_tokens: Optional[int] = None, + **kwargs, + ) -> ModelResponse: + """Generate content using the model.""" + + def count_tokens(self, text: str, model_name: str) -> int: + """Estimate token usage for a piece of text.""" + + resolved_model = self._resolve_model_name(model_name) + + if not text: + return 0 + + estimated = max(1, len(text) // 4) + logger.debug("Estimating %s tokens for model %s via character heuristic", estimated, resolved_model) + return estimated + + def close(self) -> None: + """Clean up any resources held by the provider.""" - Default implementation does nothing. - Subclasses should override if they hold resources that need cleanup. - """ - # Base implementation: no resources to clean up return + # ------------------------------------------------------------------ + # Validation hooks + # ------------------------------------------------------------------ + def validate_model_name(self, model_name: str) -> bool: + """Return ``True`` when the model resolves to an allowed capability.""" + + try: + self.get_capabilities(model_name) + except ValueError: + return False + return True + + def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None: + """Validate model parameters against capabilities.""" + + capabilities = self.get_capabilities(model_name) + + if not capabilities.temperature_constraint.validate(temperature): + constraint_desc = capabilities.temperature_constraint.get_description() + raise ValueError(f"Temperature {temperature} is invalid for model {model_name}. {constraint_desc}") + + # ------------------------------------------------------------------ + # Preference / registry hooks + # ------------------------------------------------------------------ def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]: - """Get the preferred model from this provider for a given category. + """Get the preferred model from this provider for a given category.""" - Args: - category: The tool category requiring a model - allowed_models: Pre-filtered list of model names that are allowed by restrictions - - Returns: - Model name if this provider has a preference, None otherwise - """ - # Default implementation - providers can override with specific logic return None def get_model_registry(self) -> Optional[dict[str, Any]]: - """Get the model registry for providers that maintain one. + """Return the model registry backing this provider, if any.""" - This is a hook method for providers like CustomProvider that maintain - a dynamic model registry. 
+ return None + + # ------------------------------------------------------------------ + # Capability lookup pipeline + # ------------------------------------------------------------------ + def _lookup_capabilities( + self, + canonical_name: str, + requested_name: Optional[str] = None, + ) -> Optional[ModelCapabilities]: + """Return ``ModelCapabilities`` for the canonical model name.""" + + return self.get_all_model_capabilities().get(canonical_name) + + def _ensure_model_allowed( + self, + capabilities: ModelCapabilities, + canonical_name: str, + requested_name: str, + ) -> None: + """Raise ``ValueError`` if the model violates restriction policy.""" + + try: + from utils.model_restrictions import get_restriction_service + except Exception: # pragma: no cover - only triggered if service import breaks + return + + restriction_service = get_restriction_service() + if not restriction_service: + return + + if restriction_service.is_allowed(self.get_provider_type(), canonical_name, requested_name): + return + + raise ValueError( + f"{self.get_provider_type().value} model '{canonical_name}' is not allowed by restriction policy." + ) + + def _finalise_capabilities( + self, + capabilities: ModelCapabilities, + canonical_name: str, + requested_name: str, + ) -> ModelCapabilities: + """Allow subclasses to adjust capability metadata before returning.""" + + return capabilities + + def _raise_unsupported_model(self, model_name: str) -> None: + """Raise the canonical unsupported-model error.""" + + raise ValueError(f"Unsupported model '{model_name}' for provider {self.get_provider_type().value}.") + + def _resolve_model_name(self, model_name: str) -> str: + """Resolve model shorthand to full name. + + This implementation uses the hook methods to support different + model configuration sources. + + Args: + model_name: Model name that may be an alias Returns: - Model registry dict or None if not applicable + Resolved model name """ - # Default implementation - most providers don't have a registry - return None + # Get model configurations from the hook method + model_configs = self.get_all_model_capabilities() + + # First check if it's already a base model name (case-sensitive exact match) + if model_name in model_configs: + return model_name + + # Check case-insensitively for both base models and aliases + model_name_lower = model_name.lower() + + # Check base model names case-insensitively + for base_model in model_configs: + if base_model.lower() == model_name_lower: + return base_model + + # Check aliases from the model configurations + alias_map = ModelCapabilities.collect_aliases(model_configs) + for base_model, aliases in alias_map.items(): + if any(alias.lower() == model_name_lower for alias in aliases): + return base_model + + # If not found, return as-is + return model_name diff --git a/providers/custom.py b/providers/custom.py index 6a06457..63e6f8e 100644 --- a/providers/custom.py +++ b/providers/custom.py @@ -83,117 +83,69 @@ class CustomProvider(OpenAICompatibleProvider): aliases = self._registry.list_aliases() logging.info(f"Custom provider loaded {len(models)} models with {len(aliases)} aliases") - def _resolve_model_name(self, model_name: str) -> str: - """Resolve model aliases to actual model names. 
+ # ------------------------------------------------------------------ + # Capability surface + # ------------------------------------------------------------------ + def _lookup_capabilities( + self, + canonical_name: str, + requested_name: Optional[str] = None, + ) -> Optional[ModelCapabilities]: + """Return custom capabilities from the registry or generic defaults.""" - For Ollama-style models, strips version tags (e.g., 'llama3.2:latest' -> 'llama3.2') - since the base model name is what's typically used in API calls. - - Args: - model_name: Input model name or alias - - Returns: - Resolved model name with version tags stripped if applicable - """ - # First, try to resolve through registry as-is - config = self._registry.resolve(model_name) - - if config: - if config.model_name != model_name: - logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'") - return config.model_name - else: - # If not found in registry, handle version tags for local models - # Strip version tags (anything after ':') for Ollama-style models - if ":" in model_name: - base_model = model_name.split(":")[0] - logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'") - - # Try to resolve the base model through registry - base_config = self._registry.resolve(base_model) - if base_config: - logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'") - return base_config.model_name - else: - return base_model - else: - # If not found in registry and no version tag, return as-is - logging.debug(f"Model '{model_name}' not found in registry, using as-is") - return model_name - - def get_capabilities(self, model_name: str) -> ModelCapabilities: - """Get capabilities for a custom model. - - Args: - model_name: Name of the model (or alias) - - Returns: - ModelCapabilities from registry or generic defaults - """ - # Try to get from registry first - capabilities = self._registry.get_capabilities(model_name) + builtin = super()._lookup_capabilities(canonical_name, requested_name) + if builtin is not None: + return builtin + capabilities = self._registry.get_capabilities(canonical_name) if capabilities: - # Check if this is an OpenRouter model and apply restrictions - config = self._registry.resolve(model_name) - if config and not config.is_custom: - # This is an OpenRouter model, check restrictions - from utils.model_restrictions import get_restriction_service - - restriction_service = get_restriction_service() - if not restriction_service.is_allowed(ProviderType.OPENROUTER, config.model_name, model_name): - raise ValueError(f"OpenRouter model '{model_name}' is not allowed by restriction policy.") - - # Update provider type to OPENROUTER for OpenRouter models - capabilities.provider = ProviderType.OPENROUTER - else: - # Update provider type to CUSTOM for local custom models + config = self._registry.resolve(canonical_name) + if config and getattr(config, "is_custom", False): capabilities.provider = ProviderType.CUSTOM - return capabilities - else: - # Resolve any potential aliases and create generic capabilities - resolved_name = self._resolve_model_name(model_name) + return capabilities + # Non-custom models should fall through so OpenRouter handles them + return None - logging.debug( - f"Using generic capabilities for '{resolved_name}' via Custom API. " - "Consider adding to custom_models.json for specific capabilities." - ) + logging.debug( + f"Using generic capabilities for '{canonical_name}' via Custom API. 
" + "Consider adding to custom_models.json for specific capabilities." + ) - # Infer temperature behaviour for generic capability fallback - supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings( - resolved_name - ) + supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings( + canonical_name + ) - logging.warning( - f"Model '{resolved_name}' not found in custom_models.json. Using generic capabilities with inferred settings. " - f"Temperature support: {supports_temperature} ({temperature_reason}). " - "For better accuracy, add this model to your custom_models.json configuration." - ) + logging.warning( + f"Model '{canonical_name}' not found in custom_models.json. Using generic capabilities with inferred settings. " + f"Temperature support: {supports_temperature} ({temperature_reason}). " + "For better accuracy, add this model to your custom_models.json configuration." + ) - # Create generic capabilities with inferred defaults - capabilities = ModelCapabilities( - provider=ProviderType.CUSTOM, - model_name=resolved_name, - friendly_name=f"{self.FRIENDLY_NAME} ({resolved_name})", - context_window=32_768, # Conservative default - max_output_tokens=32_768, # Conservative default max output - supports_extended_thinking=False, # Most custom models don't support this - supports_system_prompts=True, - supports_streaming=True, - supports_function_calling=False, # Conservative default - supports_temperature=supports_temperature, - temperature_constraint=temperature_constraint, - ) - - # Mark as generic for validation purposes - capabilities._is_generic = True - - return capabilities + generic = ModelCapabilities( + provider=ProviderType.CUSTOM, + model_name=canonical_name, + friendly_name=f"{self.FRIENDLY_NAME} ({canonical_name})", + context_window=32_768, + max_output_tokens=32_768, + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, + supports_temperature=supports_temperature, + temperature_constraint=temperature_constraint, + ) + generic._is_generic = True + return generic def get_provider_type(self) -> ProviderType: - """Get the provider type.""" + """Identify this provider for restriction and logging logic.""" + return ProviderType.CUSTOM + # ------------------------------------------------------------------ + # Validation + # ------------------------------------------------------------------ + def validate_model_name(self, model_name: str) -> bool: """Validate if the model name is allowed. @@ -206,49 +158,41 @@ class CustomProvider(OpenAICompatibleProvider): Returns: True if model is intended for custom/local endpoint """ - # logging.debug(f"Custom provider validating model: '{model_name}'") + if super().validate_model_name(model_name): + return True - # Try to resolve through registry first config = self._registry.resolve(model_name) - if config: - model_id = config.model_name - # Use explicit is_custom flag for clean validation - if config.is_custom: - logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' validated via registry") - return True - else: - # This is a cloud/OpenRouter model - CustomProvider should NOT handle these - # Let OpenRouter provider handle them instead - # logging.debug(f"... 
[Custom] Model '{model_name}' -> '{model_id}' not custom (defer to OpenRouter)") - return False + if config and not getattr(config, "is_custom", False): + return False - # Handle version tags for unknown models (e.g., "my-model:latest") clean_model_name = model_name if ":" in model_name: - clean_model_name = model_name.split(":")[0] + clean_model_name = model_name.split(":", 1)[0] logging.debug(f"Stripped version tag from '{model_name}' -> '{clean_model_name}'") - # Try to resolve the clean name + + if super().validate_model_name(clean_model_name): + return True + config = self._registry.resolve(clean_model_name) - if config: - return self.validate_model_name(clean_model_name) # Recursively validate clean name + if config and not getattr(config, "is_custom", False): + return False - # For unknown models (not in registry), only accept if they look like local models - # This maintains backward compatibility for custom models not yet in the registry - - # Accept models with explicit local indicators in the name - if any(indicator in clean_model_name.lower() for indicator in ["local", "ollama", "vllm", "lmstudio"]): + lowered = clean_model_name.lower() + if any(indicator in lowered for indicator in ["local", "ollama", "vllm", "lmstudio"]): logging.debug(f"Model '{clean_model_name}' validated via local indicators") return True - # Accept simple model names without vendor prefix (likely local/custom models) if "/" not in clean_model_name: logging.debug(f"Model '{clean_model_name}' validated as potential local model (no vendor prefix)") return True - # Reject everything else (likely cloud models not in registry) logging.debug(f"Model '{model_name}' rejected by custom provider (appears to be cloud model)") return False + # ------------------------------------------------------------------ + # Request execution + # ------------------------------------------------------------------ + def generate_content( self, prompt: str, @@ -284,25 +228,41 @@ class CustomProvider(OpenAICompatibleProvider): **kwargs, ) - def get_model_configurations(self) -> dict[str, ModelCapabilities]: - """Get model configurations from the registry. + # ------------------------------------------------------------------ + # Registry helpers + # ------------------------------------------------------------------ - For CustomProvider, we convert registry configurations to ModelCapabilities objects. 
+ def _resolve_model_name(self, model_name: str) -> str: + """Resolve registry aliases and strip version tags for local models.""" - Returns: - Dictionary mapping model names to their ModelCapabilities objects - """ + config = self._registry.resolve(model_name) + if config: + if config.model_name != model_name: + logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'") + return config.model_name - configs = {} + if ":" in model_name: + base_model = model_name.split(":")[0] + logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'") - if self._registry: - # Get all models from registry - for model_name in self._registry.list_models(): - # Only include custom models that this provider validates - if self.validate_model_name(model_name): - config = self._registry.resolve(model_name) - if config and config.is_custom: - # Use ModelCapabilities directly from registry - configs[model_name] = config + base_config = self._registry.resolve(base_model) + if base_config: + logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'") + return base_config.model_name + return base_model - return configs + logging.debug(f"Model '{model_name}' not found in registry, using as-is") + return model_name + + def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]: + """Expose registry capabilities for models marked as custom.""" + + if not self._registry: + return {} + + capabilities: dict[str, ModelCapabilities] = {} + for model_name in self._registry.list_models(): + config = self._registry.resolve(model_name) + if config and getattr(config, "is_custom", False): + capabilities[model_name] = config + return capabilities diff --git a/providers/dial.py b/providers/dial.py index 1c0b885..db11417 100644 --- a/providers/dial.py +++ b/providers/dial.py @@ -261,68 +261,10 @@ class DIALModelProvider(OpenAICompatibleProvider): logger.info(f"Initialized DIAL provider with host: {dial_host} and api-version: {self.api_version}") - def get_capabilities(self, model_name: str) -> ModelCapabilities: - """Get capabilities for a specific model. - - Args: - model_name: Name of the model (can be shorthand) - - Returns: - ModelCapabilities object - - Raises: - ValueError: If model is not supported or not allowed - """ - resolved_name = self._resolve_model_name(model_name) - - if resolved_name not in self.MODEL_CAPABILITIES: - raise ValueError(f"Unsupported DIAL model: {model_name}") - - # Check restrictions - from utils.model_restrictions import get_restriction_service - - restriction_service = get_restriction_service() - if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name): - raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.") - - # Return the ModelCapabilities object directly from MODEL_CAPABILITIES - return self.MODEL_CAPABILITIES[resolved_name] - def get_provider_type(self) -> ProviderType: """Get the provider type.""" return ProviderType.DIAL - def validate_model_name(self, model_name: str) -> bool: - """Validate if the model name is supported. 
- - Args: - model_name: Model name to validate - - Returns: - True if model is supported and allowed, False otherwise - """ - resolved_name = self._resolve_model_name(model_name) - - if resolved_name not in self.MODEL_CAPABILITIES: - return False - - # Check against base class allowed_models if configured - if self.allowed_models is not None: - # Check both original and resolved names (case-insensitive) - if model_name.lower() not in self.allowed_models and resolved_name.lower() not in self.allowed_models: - logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' not in allowed_models list") - return False - - # Also check restrictions via ModelRestrictionService - from utils.model_restrictions import get_restriction_service - - restriction_service = get_restriction_service() - if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name): - logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' blocked by restrictions") - return False - - return True - def _get_deployment_client(self, deployment: str): """Get or create a cached client for a specific deployment. @@ -504,7 +446,7 @@ class DIALModelProvider(OpenAICompatibleProvider): f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}" ) - def close(self): + def close(self) -> None: """Clean up HTTP clients when provider is closed.""" logger.info("Closing DIAL provider HTTP clients...") diff --git a/providers/gemini.py b/providers/gemini.py index 952aab4..f333cf2 100644 --- a/providers/gemini.py +++ b/providers/gemini.py @@ -131,6 +131,19 @@ class GeminiModelProvider(ModelProvider): self._token_counters = {} # Cache for token counting self._base_url = kwargs.get("base_url", None) # Optional custom endpoint + # ------------------------------------------------------------------ + # Capability surface + # ------------------------------------------------------------------ + + def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]: + """Return statically defined Gemini capabilities.""" + + return dict(self.MODEL_CAPABILITIES) + + # ------------------------------------------------------------------ + # Client access + # ------------------------------------------------------------------ + @property def client(self): """Lazy initialization of Gemini client.""" @@ -146,25 +159,9 @@ class GeminiModelProvider(ModelProvider): self._client = genai.Client(api_key=self.api_key) return self._client - def get_capabilities(self, model_name: str) -> ModelCapabilities: - """Get capabilities for a specific Gemini model.""" - # Resolve shorthand - resolved_name = self._resolve_model_name(model_name) - - if resolved_name not in self.MODEL_CAPABILITIES: - raise ValueError(f"Unsupported Gemini model: {model_name}") - - # Check if model is allowed by restrictions - from utils.model_restrictions import get_restriction_service - - restriction_service = get_restriction_service() - # IMPORTANT: Parameter order is (provider_type, model_name, original_name) - # resolved_name is the canonical model name, model_name is the user input - if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name): - raise ValueError(f"Gemini model '{resolved_name}' is not allowed by restriction policy.") - - # Return the ModelCapabilities object directly from MODEL_CAPABILITIES - return self.MODEL_CAPABILITIES[resolved_name] + # ------------------------------------------------------------------ + # Request execution + # 
------------------------------------------------------------------ def generate_content( self, @@ -365,26 +362,6 @@ class GeminiModelProvider(ModelProvider): """Get the provider type.""" return ProviderType.GOOGLE - def validate_model_name(self, model_name: str) -> bool: - """Validate if the model name is supported and allowed.""" - resolved_name = self._resolve_model_name(model_name) - - # First check if model is supported - if resolved_name not in self.MODEL_CAPABILITIES: - return False - - # Then check if model is allowed by restrictions - from utils.model_restrictions import get_restriction_service - - restriction_service = get_restriction_service() - # IMPORTANT: Parameter order is (provider_type, model_name, original_name) - # resolved_name is the canonical model name, model_name is the user input - if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name): - logger.debug(f"Gemini model '{model_name}' -> '{resolved_name}' blocked by restrictions") - return False - - return True - def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int: """Get actual thinking token budget for a model and thinking mode.""" resolved_name = self._resolve_model_name(model_name) diff --git a/providers/openai_compatible.py b/providers/openai_compatible.py index 98de109..8714186 100644 --- a/providers/openai_compatible.py +++ b/providers/openai_compatible.py @@ -5,7 +5,6 @@ import ipaddress import logging import os import time -from abc import abstractmethod from typing import Optional from urllib.parse import urlparse @@ -61,6 +60,33 @@ class OpenAICompatibleProvider(ModelProvider): "This may be insecure. Consider setting an API key for authentication." ) + def _ensure_model_allowed( + self, + capabilities: ModelCapabilities, + canonical_name: str, + requested_name: str, + ) -> None: + """Respect provider-specific allowlists before default restriction checks.""" + + super()._ensure_model_allowed(capabilities, canonical_name, requested_name) + + if self.allowed_models is not None: + requested = requested_name.lower() + canonical = canonical_name.lower() + + if requested not in self.allowed_models and canonical not in self.allowed_models: + raise ValueError( + f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}" + ) + + def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]: + """Return statically declared capabilities for OpenAI-compatible providers.""" + + model_map = getattr(self, "MODEL_CAPABILITIES", None) + if isinstance(model_map, dict): + return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)} + return {} + def _parse_allowed_models(self) -> Optional[set[str]]: """Parse allowed models from environment variable. @@ -686,30 +712,6 @@ class OpenAICompatibleProvider(ModelProvider): return super().count_tokens(text, model_name) - @abstractmethod - def get_capabilities(self, model_name: str) -> ModelCapabilities: - """Get capabilities for a specific model. - - Must be implemented by subclasses. - """ - pass - - @abstractmethod - def get_provider_type(self) -> ProviderType: - """Get the provider type. - - Must be implemented by subclasses. - """ - pass - - @abstractmethod - def validate_model_name(self, model_name: str) -> bool: - """Validate if the model name is supported. - - Must be implemented by subclasses. - """ - pass - def _is_error_retryable(self, error: Exception) -> bool: """Determine if an error should be retried based on structured error codes. 
diff --git a/providers/openai_provider.py b/providers/openai_provider.py index 29071f0..a032756 100644 --- a/providers/openai_provider.py +++ b/providers/openai_provider.py @@ -174,106 +174,61 @@ class OpenAIModelProvider(OpenAICompatibleProvider): kwargs.setdefault("base_url", "https://api.openai.com/v1") super().__init__(api_key, **kwargs) - def get_capabilities(self, model_name: str) -> ModelCapabilities: - """Get capabilities for a specific OpenAI model.""" - # First check if it's a key in MODEL_CAPABILITIES - if model_name in self.MODEL_CAPABILITIES: - self._check_model_restrictions(model_name, model_name) - return self.MODEL_CAPABILITIES[model_name] + # ------------------------------------------------------------------ + # Capability surface + # ------------------------------------------------------------------ - # Try resolving as alias - resolved_name = self._resolve_model_name(model_name) + def _lookup_capabilities( + self, + canonical_name: str, + requested_name: Optional[str] = None, + ) -> Optional[ModelCapabilities]: + """Look up OpenAI capabilities from built-ins or the custom registry.""" - # Check if resolved name is a key - if resolved_name in self.MODEL_CAPABILITIES: - self._check_model_restrictions(resolved_name, model_name) - return self.MODEL_CAPABILITIES[resolved_name] + builtin = super()._lookup_capabilities(canonical_name, requested_name) + if builtin is not None: + return builtin - # Finally check if resolved name matches any API model name - for key, capabilities in self.MODEL_CAPABILITIES.items(): - if resolved_name == capabilities.model_name: - self._check_model_restrictions(key, model_name) - return capabilities - - # Check custom models registry for user-configured OpenAI models try: from .openrouter_registry import OpenRouterModelRegistry registry = OpenRouterModelRegistry() - config = registry.get_model_config(resolved_name) + config = registry.get_model_config(canonical_name) if config and config.provider == ProviderType.OPENAI: - self._check_model_restrictions(config.model_name, model_name) - - # Update provider type to ensure consistency - config.provider = ProviderType.OPENAI return config - except Exception as e: - # Log but don't fail - registry might not be available - logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}") + except Exception as exc: # pragma: no cover - registry failures are non-critical + logger.debug(f"Could not resolve custom OpenAI model '{canonical_name}': {exc}") + return None + + def _finalise_capabilities( + self, + capabilities: ModelCapabilities, + canonical_name: str, + requested_name: str, + ) -> ModelCapabilities: + """Ensure registry-sourced models report the correct provider type.""" + + if capabilities.provider != ProviderType.OPENAI: + capabilities.provider = ProviderType.OPENAI + return capabilities + + def _raise_unsupported_model(self, model_name: str) -> None: raise ValueError(f"Unsupported OpenAI model: {model_name}") - def _check_model_restrictions(self, provider_model_name: str, user_model_name: str) -> None: - """Check if a model is allowed by restriction policy. 
- - Args: - provider_model_name: The model name used by the provider - user_model_name: The model name requested by the user - - Raises: - ValueError: If the model is not allowed by restriction policy - """ - from utils.model_restrictions import get_restriction_service - - restriction_service = get_restriction_service() - if not restriction_service.is_allowed(ProviderType.OPENAI, provider_model_name, user_model_name): - raise ValueError(f"OpenAI model '{user_model_name}' is not allowed by restriction policy.") + # ------------------------------------------------------------------ + # Provider identity + # ------------------------------------------------------------------ def get_provider_type(self) -> ProviderType: """Get the provider type.""" return ProviderType.OPENAI - def validate_model_name(self, model_name: str) -> bool: - """Validate if the model name is supported and allowed.""" - resolved_name = self._resolve_model_name(model_name) - - # First, determine which model name to check against restrictions. - model_to_check = None - is_custom_model = False - - if resolved_name in self.MODEL_CAPABILITIES: - model_to_check = resolved_name - else: - # If not a built-in model, check the custom models registry. - try: - from .openrouter_registry import OpenRouterModelRegistry - - registry = OpenRouterModelRegistry() - config = registry.get_model_config(resolved_name) - - if config and config.provider == ProviderType.OPENAI: - model_to_check = config.model_name - is_custom_model = True - except Exception as e: - # Log but don't fail - registry might not be available. - logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}") - - # If no model was found (neither built-in nor custom), it's invalid. - if not model_to_check: - return False - - # Now, perform the restriction check once. - from utils.model_restrictions import get_restriction_service - - restriction_service = get_restriction_service() - if not restriction_service.is_allowed(ProviderType.OPENAI, model_to_check, model_name): - model_type = "custom " if is_custom_model else "" - logger.debug(f"OpenAI {model_type}model '{model_name}' -> '{resolved_name}' blocked by restrictions") - return False - - return True + # ------------------------------------------------------------------ + # Request execution + # ------------------------------------------------------------------ def generate_content( self, @@ -298,6 +253,10 @@ class OpenAIModelProvider(OpenAICompatibleProvider): **kwargs, ) + # ------------------------------------------------------------------ + # Provider preferences + # ------------------------------------------------------------------ + def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]: """Get OpenAI's preferred model for a given category from allowed models. diff --git a/providers/openrouter.py b/providers/openrouter.py index 1a60542..b4b9d6a 100644 --- a/providers/openrouter.py +++ b/providers/openrouter.py @@ -61,108 +61,52 @@ class OpenRouterProvider(OpenAICompatibleProvider): aliases = self._registry.list_aliases() logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases") - def _resolve_model_name(self, model_name: str) -> str: - """Resolve model aliases to OpenRouter model names. 
+ # ------------------------------------------------------------------ + # Capability surface + # ------------------------------------------------------------------ - Args: - model_name: Input model name or alias - - Returns: - Resolved OpenRouter model name - """ - # Try to resolve through registry - config = self._registry.resolve(model_name) - - if config: - if config.model_name != model_name: - logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'") - return config.model_name - else: - # If not found in registry, return as-is - # This allows using models not in our config file - logging.debug(f"Model '{model_name}' not found in registry, using as-is") - return model_name - - def get_capabilities(self, model_name: str) -> ModelCapabilities: - """Get capabilities for a model. - - Args: - model_name: Name of the model (or alias) - - Returns: - ModelCapabilities from registry or generic defaults - """ - # Try to get from registry first - capabilities = self._registry.get_capabilities(model_name) + def _lookup_capabilities( + self, + canonical_name: str, + requested_name: Optional[str] = None, + ) -> Optional[ModelCapabilities]: + """Fetch OpenRouter capabilities from the registry or build a generic fallback.""" + capabilities = self._registry.get_capabilities(canonical_name) if capabilities: return capabilities - else: - # Resolve any potential aliases and create generic capabilities - resolved_name = self._resolve_model_name(model_name) - logging.debug( - f"Using generic capabilities for '{resolved_name}' via OpenRouter. " - "Consider adding to custom_models.json for specific capabilities." - ) + logging.debug( + f"Using generic capabilities for '{canonical_name}' via OpenRouter. " + "Consider adding to custom_models.json for specific capabilities." + ) - # Create generic capabilities with conservative defaults - capabilities = ModelCapabilities( - provider=ProviderType.OPENROUTER, - model_name=resolved_name, - friendly_name=self.FRIENDLY_NAME, - context_window=32_768, # Conservative default context window - max_output_tokens=32_768, - supports_extended_thinking=False, - supports_system_prompts=True, - supports_streaming=True, - supports_function_calling=False, - temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0), - ) + generic = ModelCapabilities( + provider=ProviderType.OPENROUTER, + model_name=canonical_name, + friendly_name=self.FRIENDLY_NAME, + context_window=32_768, + max_output_tokens=32_768, + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, + temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0), + ) + generic._is_generic = True + return generic - # Mark as generic for validation purposes - capabilities._is_generic = True - - return capabilities + # ------------------------------------------------------------------ + # Provider identity + # ------------------------------------------------------------------ def get_provider_type(self) -> ProviderType: - """Get the provider type.""" + """Identify this provider for restrictions and logging.""" return ProviderType.OPENROUTER - def validate_model_name(self, model_name: str) -> bool: - """Validate if the model name is allowed. - - As the catch-all provider, OpenRouter accepts any model name that wasn't - handled by higher-priority providers. OpenRouter will validate based on - the API key's permissions and local restrictions. 
- - Args: - model_name: Model name to validate - - Returns: - True if model is allowed, False if restricted - """ - # Check model restrictions if configured - from utils.model_restrictions import get_restriction_service - - restriction_service = get_restriction_service() - if restriction_service: - # Check if model name itself is allowed - if restriction_service.is_allowed(self.get_provider_type(), model_name): - return True - - # Also check aliases - model_name might be an alias - model_config = self._registry.resolve(model_name) - if model_config and model_config.aliases: - for alias in model_config.aliases: - if restriction_service.is_allowed(self.get_provider_type(), alias): - return True - - # If restrictions are configured and model/alias not in allowed list, reject - return False - - # No restrictions configured - accept any model name as the fallback provider - return True + # ------------------------------------------------------------------ + # Request execution + # ------------------------------------------------------------------ def generate_content( self, @@ -204,6 +148,10 @@ **kwargs, ) + # ------------------------------------------------------------------ + # Model listing + # ------------------------------------------------------------------ + def list_models( self, *, @@ -227,6 +175,12 @@ if not config: continue + # Custom models belong to CustomProvider; skip them here so the two + # providers don't race over the same registrations (important for tests + # that stub the registry with minimal objects lacking attrs). + if hasattr(config, "is_custom") and config.is_custom is True: + continue + if restriction_service: allowed = restriction_service.is_allowed(self.get_provider_type(), model_name) @@ -255,24 +209,37 @@ unique=unique, ) - def get_model_configurations(self) -> dict[str, ModelCapabilities]: - """Get model configurations from the registry. + # ------------------------------------------------------------------ + # Registry helpers + # ------------------------------------------------------------------ - For OpenRouter, we convert registry configurations to ModelCapabilities objects.
+ def _resolve_model_name(self, model_name: str) -> str: + """Resolve aliases defined in the OpenRouter registry.""" - Returns: - Dictionary mapping model names to their ModelCapabilities objects - """ - configs = {} + config = self._registry.resolve(model_name) + if config: + if config.model_name != model_name: + logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'") + return config.model_name - if self._registry: - # Get all models from registry - for model_name in self._registry.list_models(): - # Only include models that this provider validates - if self.validate_model_name(model_name): - config = self._registry.resolve(model_name) - if config and not config.is_custom: # Only OpenRouter models, not custom ones - # Use ModelCapabilities directly from registry - configs[model_name] = config + logging.debug(f"Model '{model_name}' not found in registry, using as-is") + return model_name - return configs + def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]: + """Expose registry-backed OpenRouter capabilities.""" + + if not self._registry: + return {} + + capabilities: dict[str, ModelCapabilities] = {} + for model_name in self._registry.list_models(): + config = self._registry.resolve(model_name) + if not config: + continue + + # See note in list_models: respect the CustomProvider boundary. + if hasattr(config, "is_custom") and config.is_custom is True: + continue + + capabilities[model_name] = config + return capabilities diff --git a/providers/registry.py b/providers/registry.py index 917fd37..6f412ff 100644 --- a/providers/registry.py +++ b/providers/registry.py @@ -64,6 +64,8 @@ class ModelProviderRegistry: """ instance = cls() instance._providers[provider_type] = provider_class + # Invalidate any cached instance so subsequent lookups use the new registration + instance._initialized_providers.pop(provider_type, None) @classmethod def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]: diff --git a/providers/xai.py b/providers/xai.py index a9b2387..c03bc57 100644 --- a/providers/xai.py +++ b/providers/xai.py @@ -85,46 +85,10 @@ class XAIModelProvider(OpenAICompatibleProvider): kwargs.setdefault("base_url", "https://api.x.ai/v1") super().__init__(api_key, **kwargs) - def get_capabilities(self, model_name: str) -> ModelCapabilities: - """Get capabilities for a specific X.AI model.""" - # Resolve shorthand - resolved_name = self._resolve_model_name(model_name) - - if resolved_name not in self.MODEL_CAPABILITIES: - raise ValueError(f"Unsupported X.AI model: {model_name}") - - # Check if model is allowed by restrictions - from utils.model_restrictions import get_restriction_service - - restriction_service = get_restriction_service() - if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name): - raise ValueError(f"X.AI model '{model_name}' is not allowed by restriction policy.") - - # Return the ModelCapabilities object directly from MODEL_CAPABILITIES - return self.MODEL_CAPABILITIES[resolved_name] - def get_provider_type(self) -> ProviderType: """Get the provider type.""" return ProviderType.XAI - def validate_model_name(self, model_name: str) -> bool: - """Validate if the model name is supported and allowed.""" - resolved_name = self._resolve_model_name(model_name) - - # First check if model is supported - if resolved_name not in self.MODEL_CAPABILITIES: - return False - - # Then check if model is allowed by restrictions - from utils.model_restrictions import 
get_restriction_service - - restriction_service = get_restriction_service() - if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name): - logger.debug(f"X.AI model '{model_name}' -> '{resolved_name}' blocked by restrictions") - return False - - return True - def generate_content( self, prompt: str, diff --git a/tests/test_custom_provider.py b/tests/test_custom_provider.py index 4c300f9..4f7ca30 100644 --- a/tests/test_custom_provider.py +++ b/tests/test_custom_provider.py @@ -54,11 +54,9 @@ class TestCustomProvider: provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1") - # Test with a model that should be in the registry (OpenRouter model) - capabilities = provider.get_capabilities("o3") # o3 is an OpenRouter model - - assert capabilities.provider == ProviderType.OPENROUTER # o3 is an OpenRouter model (is_custom=false) - assert capabilities.context_window > 0 + # OpenRouter-backed models should be handled by the OpenRouter provider + with pytest.raises(ValueError): + provider.get_capabilities("o3") # Test with a custom model (is_custom=true) capabilities = provider.get_capabilities("local-llama") @@ -168,7 +166,13 @@ class TestCustomProviderRegistration: return CustomProvider(api_key="", base_url="http://localhost:11434/v1") with patch.dict( - os.environ, {"OPENROUTER_API_KEY": "test-openrouter-key", "CUSTOM_API_PLACEHOLDER": "configured"} + os.environ, + { + "OPENROUTER_API_KEY": "test-openrouter-key", + "CUSTOM_API_PLACEHOLDER": "configured", + "OPENROUTER_ALLOWED_MODELS": "llama,anthropic/claude-opus-4.1", + }, + clear=True, ): # Register both providers ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider) @@ -195,18 +199,22 @@ class TestCustomProviderRegistration: return CustomProvider(api_key="", base_url="http://localhost:11434/v1") with patch.dict( - os.environ, {"OPENROUTER_API_KEY": "test-openrouter-key", "CUSTOM_API_PLACEHOLDER": "configured"} + os.environ, + { + "OPENROUTER_API_KEY": "test-openrouter-key", + "CUSTOM_API_PLACEHOLDER": "configured", + "OPENROUTER_ALLOWED_MODELS": "", + }, + clear=True, ): - # Register OpenRouter first (higher priority) - ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider) - ModelProviderRegistry.register_provider(ProviderType.CUSTOM, custom_provider_factory) + import utils.model_restrictions - # Test model resolution - OpenRouter should win for shared aliases - provider_for_model = ModelProviderRegistry.get_provider_for_model("llama") + utils.model_restrictions._restriction_service = None + custom_provider = custom_provider_factory() + openrouter_provider = OpenRouterProvider(api_key="test-openrouter-key") - # OpenRouter should be selected first due to registration order - assert provider_for_model is not None - # The exact provider type depends on which validates the model first + assert not custom_provider.validate_model_name("llama") + assert openrouter_provider.validate_model_name("llama") class TestConfigureProvidersFunction: diff --git a/tests/test_dial_provider.py b/tests/test_dial_provider.py index 9062a18..6c6f6a7 100644 --- a/tests/test_dial_provider.py +++ b/tests/test_dial_provider.py @@ -121,7 +121,7 @@ class TestDIALProvider: """Test that get_capabilities raises for invalid models.""" provider = DIALModelProvider("test-key") - with pytest.raises(ValueError, match="Unsupported DIAL model"): + with pytest.raises(ValueError, match="Unsupported model 'invalid-model' for provider dial"): 
provider.get_capabilities("invalid-model") @patch("utils.model_restrictions.get_restriction_service") diff --git a/tests/test_model_restrictions.py b/tests/test_model_restrictions.py index 8aaf620..4277463 100644 --- a/tests/test_model_restrictions.py +++ b/tests/test_model_restrictions.py @@ -356,15 +356,13 @@ class TestCustomProviderOpenRouterRestrictions: provider = CustomProvider(base_url="http://test.com/v1") - # For OpenRouter models, get_capabilities should still work but mark them as OPENROUTER - # This tests the capabilities lookup, not validation - capabilities = provider.get_capabilities("opus") - assert capabilities.provider == ProviderType.OPENROUTER + # For OpenRouter models, CustomProvider should defer by raising + with pytest.raises(ValueError): + provider.get_capabilities("opus") - # Should raise for disallowed OpenRouter model - with pytest.raises(ValueError) as exc_info: + # Should raise for disallowed OpenRouter model (still defers) + with pytest.raises(ValueError): provider.get_capabilities("haiku") - assert "not allowed by restriction policy" in str(exc_info.value) # Should still work for custom models (is_custom=true) capabilities = provider.get_capabilities("local-llama") diff --git a/tests/test_xai_provider.py b/tests/test_xai_provider.py index f3e3a76..392be5b 100644 --- a/tests/test_xai_provider.py +++ b/tests/test_xai_provider.py @@ -141,7 +141,7 @@ class TestXAIProvider: """Test error handling for unsupported models.""" provider = XAIModelProvider("test-key") - with pytest.raises(ValueError, match="Unsupported X.AI model"): + with pytest.raises(ValueError, match="Unsupported model 'invalid-model' for provider xai"): provider.get_capabilities("invalid-model") def test_extended_thinking_flags(self): diff --git a/tools/listmodels.py b/tools/listmodels.py index 7bde7f2..18e94aa 100644 --- a/tools/listmodels.py +++ b/tools/listmodels.py @@ -105,13 +105,11 @@ class ListModelsTool(BaseTool): output_lines.append("**Status**: Configured and available") output_lines.append("\n**Models**:") - # Get models from the provider's model configurations - for model_name, capabilities in provider.get_model_configurations().items(): - # Get description and context from the ModelCapabilities object + aliases = [] + for model_name, capabilities in provider.get_all_model_capabilities().items(): description = capabilities.description or "No description available" context_window = capabilities.context_window - # Format context window if context_window >= 1_000_000: context_str = f"{context_window // 1_000_000}M context" elif context_window >= 1_000: @@ -120,31 +118,15 @@ class ListModelsTool(BaseTool): context_str = f"{context_window} context" if context_window > 0 else "unknown context" output_lines.append(f"- `{model_name}` - {context_str}") + output_lines.append(f" - {description}") - # Extract key capability from description - if "Ultra-fast" in description: - output_lines.append(" - Fast processing, quick iterations") - elif "Deep reasoning" in description: - output_lines.append(" - Extended reasoning with thinking mode") - elif "Strong reasoning" in description: - output_lines.append(" - Logical problems, systematic analysis") - elif "EXTREMELY EXPENSIVE" in description: - output_lines.append(" - ⚠️ Professional grade (very expensive)") - elif "Advanced reasoning" in description: - output_lines.append(" - Advanced reasoning and complex analysis") - - # Show aliases for this provider - aliases = [] - for model_name, capabilities in provider.get_model_configurations().items(): - if 
capabilities.aliases: - for alias in capabilities.aliases: - # Skip aliases that are the same as the model name to avoid duplicates - if alias != model_name: - aliases.append(f"- `{alias}` → `{model_name}`") + for alias in capabilities.aliases or []: + if alias != model_name: + aliases.append(f"- `{alias}` → `{model_name}`") if aliases: output_lines.append("\n**Aliases**:") - output_lines.extend(sorted(aliases)) # Sort for consistent output + output_lines.extend(sorted(aliases)) else: output_lines.append(f"**Status**: Not configured (set {info['env_key']})")