From 6d237d09709f757a042baf655f47eb4ddfc078ad Mon Sep 17 00:00:00 2001 From: Fahad Date: Thu, 2 Oct 2025 10:25:41 +0400 Subject: [PATCH] refactor: moved temperature method from base provider to model capabilities refactor: model listing cleanup, moved logic to model_capabilities.py docs: added AGENTS.md for onboarding Codex --- providers/base.py | 108 +++++----- providers/custom.py | 19 +- providers/openrouter.py | 109 +++++----- providers/openrouter_registry.py | 19 +- providers/registry.py | 19 +- providers/shared/model_capabilities.py | 80 +++++++- tests/test_alias_target_restrictions.py | 39 ++-- tests/test_buggy_behavior_prevention.py | 213 +++++++++---------- tests/test_model_restrictions.py | 80 ++++++-- tests/test_old_behavior_simulation.py | 216 -------------------- tests/test_openai_compatible_token_usage.py | 5 +- tests/test_openrouter_provider.py | 17 +- tests/test_supported_models_aliases.py | 18 +- utils/model_restrictions.py | 30 ++- 14 files changed, 460 insertions(+), 512 deletions(-) delete mode 100644 tests/test_old_behavior_simulation.py diff --git a/providers/base.py b/providers/base.py index 0f3a7f0..93bbe62 100644 --- a/providers/base.py +++ b/providers/base.py @@ -18,13 +18,26 @@ logger = logging.getLogger(__name__) class ModelProvider(ABC): - """Defines the contract implemented by every model provider backend. + """Abstract base class for all model backends in the MCP server. - Subclasses adapt third-party SDKs into the MCP server by exposing - capability metadata, request execution, and token counting through a - consistent interface. Shared helper methods (temperature validation, - alias resolution, image handling, etc.) live here so individual providers - only need to focus on provider-specific details. + Role + Defines the interface every provider must implement so the registry, + restriction service, and tools have a uniform surface for listing + models, resolving aliases, and executing requests. + + Responsibilities + * expose static capability metadata for each supported model via + :class:`ModelCapabilities` + * accept user prompts, forward them to the underlying SDK, and wrap + responses in :class:`ModelResponse` + * report tokenizer counts for budgeting and validation logic + * advertise provider identity (``ProviderType``) so restriction + policies can map environment configuration onto providers + * validate whether a model name or alias is recognised by the provider + + Shared helpers like temperature validation, alias resolution, and + restriction-aware ``list_models`` live here so concrete subclasses only + need to supply their catalogue and wire up SDK-specific behaviour. """ # All concrete providers must define their supported models @@ -151,67 +164,52 @@ class ModelProvider(ABC): # If not found, return as-is return model_name - def list_models(self, respect_restrictions: bool = True) -> list[str]: - """Return a list of model names supported by this provider. - - This implementation uses the get_model_configurations() hook - to support different model configuration sources. + def list_models( + self, + *, + respect_restrictions: bool = True, + include_aliases: bool = True, + lowercase: bool = False, + unique: bool = False, + ) -> list[str]: + """Return formatted model names supported by this provider. Args: - respect_restrictions: Whether to apply provider-specific restriction logic. + respect_restrictions: Apply provider restriction policy. + include_aliases: Include aliases alongside canonical model names. + lowercase: Normalize returned names to lowercase. + unique: Deduplicate names after formatting. Returns: - List of model names available from this provider + List of model names formatted according to the provided options. """ - from utils.model_restrictions import get_restriction_service - restriction_service = get_restriction_service() if respect_restrictions else None - models = [] - - # Get model configurations from the hook method model_configs = self.get_model_configurations() + if not model_configs: + return [] - for model_name in model_configs: - # Check restrictions if enabled - if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name): - continue + restriction_service = None + if respect_restrictions: + from utils.model_restrictions import get_restriction_service - # Add the base model - models.append(model_name) + restriction_service = get_restriction_service() - # Add aliases derived from the model configurations - alias_map = ModelCapabilities.collect_aliases(model_configs) - for model_name, aliases in alias_map.items(): - # Only add aliases for models that passed restriction check - if model_name in models: - models.extend(aliases) + if restriction_service: + allowed_configs = {} + for model_name, config in model_configs.items(): + if restriction_service.is_allowed(self.get_provider_type(), model_name): + allowed_configs[model_name] = config + model_configs = allowed_configs - return models + if not model_configs: + return [] - def list_all_known_models(self) -> list[str]: - """Return all model names known by this provider, including alias targets. - - This is used for validation purposes to ensure restriction policies - can validate against both aliases and their target model names. - - Returns: - List of all model names and alias targets known by this provider - """ - all_models = set() - - # Get model configurations from the hook method - model_configs = self.get_model_configurations() - - # Add all base model names - for model_name in model_configs: - all_models.add(model_name.lower()) - - # Add aliases derived from the model configurations - for aliases in ModelCapabilities.collect_aliases(model_configs).values(): - for alias in aliases: - all_models.add(alias.lower()) - - return list(all_models) + return ModelCapabilities.collect_model_names( + model_configs, + include_aliases=include_aliases, + lowercase=lowercase, + unique=unique, + ) def validate_image(self, image_path: str, max_size_mb: float = None) -> tuple[bytes, str]: """Provider-independent image validation. diff --git a/providers/custom.py b/providers/custom.py index 3f6f813..a4bad33 100644 --- a/providers/custom.py +++ b/providers/custom.py @@ -32,11 +32,20 @@ _TEMP_UNSUPPORTED_KEYWORDS = [ class CustomProvider(OpenAICompatibleProvider): """Adapter for self-hosted or local OpenAI-compatible endpoints. - The provider reuses the :mod:`providers.shared` registry to surface - user-defined aliases and capability metadata. It also normalises - Ollama-style version tags (``model:latest``) and enforces the same - restriction policies used by cloud providers, ensuring consistent - behaviour regardless of where the model is hosted. + Role + Provide a uniform bridge between the MCP server and user-managed + OpenAI-compatible services (Ollama, vLLM, LM Studio, bespoke gateways). + By subclassing :class:`OpenAICompatibleProvider` it inherits request and + token handling, while the custom registry exposes locally defined model + metadata. + + Notable behaviour + * Uses :class:`OpenRouterModelRegistry` to load model definitions and + aliases so custom deployments share the same metadata pipeline as + OpenRouter itself. + * Normalises version-tagged model names (``model:latest``) and applies + restriction policies just like cloud providers, ensuring consistent + behaviour across environments. """ FRIENDLY_NAME = "Custom API" diff --git a/providers/openrouter.py b/providers/openrouter.py index fdbbc62..67f5990 100644 --- a/providers/openrouter.py +++ b/providers/openrouter.py @@ -17,9 +17,19 @@ from .shared import ( class OpenRouterProvider(OpenAICompatibleProvider): """Client for OpenRouter's multi-model aggregation service. - OpenRouter surfaces dozens of upstream vendors. This provider layers alias - resolution, restriction-aware filtering, and sensible capability defaults - on top of the generic OpenAI-compatible plumbing. + Role + Surface OpenRouter’s dynamic catalogue through the same interface as + native providers so tools can reference OpenRouter models and aliases + without special cases. + + Characteristics + * Pulls live model definitions from :class:`OpenRouterModelRegistry` + (aliases, provider-specific metadata, capability hints) + * Applies alias-aware restriction checks before exposing models to the + registry or tooling + * Reuses :class:`OpenAICompatibleProvider` infrastructure for request + execution so OpenRouter endpoints behave like standard OpenAI-style + APIs. """ FRIENDLY_NAME = "OpenRouter" @@ -208,75 +218,56 @@ class OpenRouterProvider(OpenAICompatibleProvider): """ return False - def list_models(self, respect_restrictions: bool = True) -> list[str]: - """Return a list of model names supported by this provider. + def list_models( + self, + *, + respect_restrictions: bool = True, + include_aliases: bool = True, + lowercase: bool = False, + unique: bool = False, + ) -> list[str]: + """Return formatted OpenRouter model names, respecting alias-aware restrictions.""" - Args: - respect_restrictions: Whether to apply provider-specific restriction logic. + if not self._registry: + return [] - Returns: - List of model names available from this provider - """ from utils.model_restrictions import get_restriction_service restriction_service = get_restriction_service() if respect_restrictions else None - models = [] + allowed_configs: dict[str, ModelCapabilities] = {} - if self._registry: - for model_name in self._registry.list_models(): - # ===================================================================================== - # CRITICAL ALIAS-AWARE RESTRICTION CHECKING (Fixed Issue #98) - # ===================================================================================== - # Previously, restrictions only checked full model names (e.g., "google/gemini-2.5-pro") - # but users specify aliases in OPENROUTER_ALLOWED_MODELS (e.g., "pro"). - # This caused "no models available" error even with valid restrictions. - # - # Fix: Check both model name AND all aliases against restrictions - # TEST COVERAGE: tests/test_provider_routing_bugs.py::TestOpenRouterAliasRestrictions - # ===================================================================================== - if restriction_service: - # Get model config to check aliases as well - model_config = self._registry.resolve(model_name) - allowed = False + for model_name in self._registry.list_models(): + config = self._registry.resolve(model_name) + if not config: + continue - # Check if model name itself is allowed - if restriction_service.is_allowed(self.get_provider_type(), model_name): - allowed = True + if restriction_service: + allowed = restriction_service.is_allowed(self.get_provider_type(), model_name) - # CRITICAL: Also check aliases - this fixes the alias restriction bug - if not allowed and model_config and model_config.aliases: - for alias in model_config.aliases: - if restriction_service.is_allowed(self.get_provider_type(), alias): - allowed = True - break + if not allowed and config.aliases: + for alias in config.aliases: + if restriction_service.is_allowed(self.get_provider_type(), alias): + allowed = True + break - if not allowed: - continue + if not allowed: + continue - models.append(model_name) + allowed_configs[model_name] = config - return models + if not allowed_configs: + return [] - def list_all_known_models(self) -> list[str]: - """Return all model names known by this provider, including alias targets. + # When restrictions are in place, don't include aliases to avoid confusion + # Only return the canonical model names that are actually allowed + actual_include_aliases = include_aliases and not respect_restrictions - Returns: - List of all model names and alias targets known by this provider - """ - all_models = set() - - if self._registry: - # Get all models and aliases from the registry - all_models.update(model.lower() for model in self._registry.list_models()) - all_models.update(alias.lower() for alias in self._registry.list_aliases()) - - # For each alias, also add its target - for alias in self._registry.list_aliases(): - config = self._registry.resolve(alias) - if config: - all_models.add(config.model_name.lower()) - - return list(all_models) + return ModelCapabilities.collect_model_names( + allowed_configs, + include_aliases=actual_include_aliases, + lowercase=lowercase, + unique=unique, + ) def get_model_configurations(self) -> dict[str, ModelCapabilities]: """Get model configurations from the registry. diff --git a/providers/openrouter_registry.py b/providers/openrouter_registry.py index 9e1dbf1..a7cbfb2 100644 --- a/providers/openrouter_registry.py +++ b/providers/openrouter_registry.py @@ -17,12 +17,21 @@ from .shared import ( class OpenRouterModelRegistry: - """Loads and validates the OpenRouter/custom model catalogue. + """In-memory view of OpenRouter and custom model metadata. - The registry parses ``conf/custom_models.json`` (or an override supplied via - environment variable), builds case-insensitive alias maps, and exposes - :class:`~providers.shared.ModelCapabilities` objects used by several - providers. + Role + Parse the packaged ``conf/custom_models.json`` (or user-specified + overrides), construct alias and capability maps, and serve those + structures to providers that rely on OpenRouter semantics (both the + OpenRouter provider itself and the Custom provider). + + Key duties + * Load :class:`ModelCapabilities` definitions from configuration files + * Maintain a case-insensitive alias → canonical name map for fast + resolution + * Provide helpers to list models, list aliases, and resolve an arbitrary + name to its capability object without repeatedly touching the file + system. """ def __init__(self, config_path: Optional[str] = None): diff --git a/providers/registry.py b/providers/registry.py index 0783f8f..917fd37 100644 --- a/providers/registry.py +++ b/providers/registry.py @@ -12,11 +12,22 @@ if TYPE_CHECKING: class ModelProviderRegistry: - """Singleton that caches provider instances and coordinates priority order. + """Central catalogue of provider implementations used by the MCP server. - Responsibilities include resolving API keys from the environment, lazily - instantiating providers, and choosing the best provider for a model based - on restriction policies and provider priority. + Role + Holds the mapping between :class:`ProviderType` values and concrete + :class:`ModelProvider` subclasses/factories. At runtime the registry + is responsible for instantiating providers, caching them for reuse, and + mediating lookup of providers and model names in provider priority + order. + + Core responsibilities + * Resolve API keys and other runtime configuration for each provider + * Lazily create provider instances so unused backends incur no cost + * Expose convenience methods for enumerating available models and + locating which provider can service a requested model name or alias + * Honour the project-wide provider priority policy so namespaces (or + alias collisions) are resolved deterministically. """ _instance = None diff --git a/providers/shared/model_capabilities.py b/providers/shared/model_capabilities.py index 33f1af7..02c2d1c 100644 --- a/providers/shared/model_capabilities.py +++ b/providers/shared/model_capabilities.py @@ -11,24 +11,46 @@ __all__ = ["ModelCapabilities"] @dataclass class ModelCapabilities: - """Static capabilities and constraints for a provider-managed model.""" + """Static description of what a model can do within a provider. + + Role + Acts as the canonical record for everything the server needs to know + about a model—its provider, token limits, feature switches, aliases, + and temperature rules. Providers populate these objects so tools and + higher-level services can rely on a consistent schema. + + Typical usage + * Provider subclasses declare `MODEL_CAPABILITIES` maps containing these + objects (for example ``OpenAIModelProvider``) + * Helper utilities (e.g. restriction validation, alias expansion) read + these objects to build model lists for tooling and policy enforcement + * Tool selection logic inspects attributes such as + ``supports_extended_thinking`` or ``context_window`` to choose an + appropriate model for a task. + """ provider: ProviderType model_name: str friendly_name: str - context_window: int - max_output_tokens: int + description: str = "" + aliases: list[str] = field(default_factory=list) + + # Capacity limits / resource budgets + context_window: int = 0 + max_output_tokens: int = 0 + max_thinking_tokens: int = 0 + + # Capability flags supports_extended_thinking: bool = False supports_system_prompts: bool = True supports_streaming: bool = True supports_function_calling: bool = False supports_images: bool = False - max_image_size_mb: float = 0.0 - supports_temperature: bool = True - description: str = "" - aliases: list[str] = field(default_factory=list) supports_json_mode: bool = False - max_thinking_tokens: int = 0 + supports_temperature: bool = True + + # Additional attributes + max_image_size_mb: float = 0.0 is_custom: bool = False temperature_constraint: TemperatureConstraint = field( default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3) @@ -56,3 +78,45 @@ class ModelCapabilities: for base_model, capabilities in model_configs.items() if capabilities.aliases } + + @staticmethod + def collect_model_names( + model_configs: dict[str, "ModelCapabilities"], + *, + include_aliases: bool = True, + lowercase: bool = False, + unique: bool = False, + ) -> list[str]: + """Build an ordered list of model names and aliases. + + Args: + model_configs: Mapping of canonical model names to capabilities. + include_aliases: When True, include aliases for each model. + lowercase: When True, normalize names to lowercase. + unique: When True, ensure each returned name appears once (after formatting). + + Returns: + Ordered list of model names (and optionally aliases) formatted per options. + """ + + formatted_names: list[str] = [] + seen: set[str] | None = set() if unique else None + + def append_name(name: str) -> None: + formatted = name.lower() if lowercase else name + + if seen is not None: + if formatted in seen: + return + seen.add(formatted) + + formatted_names.append(formatted) + + for base_model, capabilities in model_configs.items(): + append_name(base_model) + + if include_aliases and capabilities.aliases: + for alias in capabilities.aliases: + append_name(alias) + + return formatted_names diff --git a/tests/test_alias_target_restrictions.py b/tests/test_alias_target_restrictions.py index f3dbd82..e779acd 100644 --- a/tests/test_alias_target_restrictions.py +++ b/tests/test_alias_target_restrictions.py @@ -22,7 +22,7 @@ class TestAliasTargetRestrictions: provider = OpenAIModelProvider(api_key="test-key") # Get all known models including aliases and targets - all_known = provider.list_all_known_models() + all_known = provider.list_models(respect_restrictions=False, include_aliases=True, lowercase=True, unique=True) # Should include both aliases and their targets assert "mini" in all_known # alias @@ -35,7 +35,7 @@ class TestAliasTargetRestrictions: provider = GeminiModelProvider(api_key="test-key") # Get all known models including aliases and targets - all_known = provider.list_all_known_models() + all_known = provider.list_models(respect_restrictions=False, include_aliases=True, lowercase=True, unique=True) # Should include both aliases and their targets assert "flash" in all_known # alias @@ -162,7 +162,9 @@ class TestAliasTargetRestrictions: """ # Test OpenAI provider openai_provider = OpenAIModelProvider(api_key="test-key") - openai_all_known = openai_provider.list_all_known_models() + openai_all_known = openai_provider.list_models( + respect_restrictions=False, include_aliases=True, lowercase=True, unique=True + ) # Verify that for each alias, its target is also included for model_name, config in openai_provider.MODEL_CAPABILITIES.items(): @@ -175,7 +177,9 @@ class TestAliasTargetRestrictions: # Test Gemini provider gemini_provider = GeminiModelProvider(api_key="test-key") - gemini_all_known = gemini_provider.list_all_known_models() + gemini_all_known = gemini_provider.list_models( + respect_restrictions=False, include_aliases=True, lowercase=True, unique=True + ) # Verify that for each alias, its target is also included for model_name, config in gemini_provider.MODEL_CAPABILITIES.items(): @@ -186,8 +190,8 @@ class TestAliasTargetRestrictions: config.lower() in gemini_all_known ), f"Target '{config}' for alias '{model_name}' not in known models" - def test_no_duplicate_models_in_list_all_known_models(self): - """Test that list_all_known_models doesn't return duplicates.""" + def test_no_duplicate_models_in_alias_aware_listing(self): + """Test that alias-aware list_models variant doesn't return duplicates.""" # Test all providers providers = [ OpenAIModelProvider(api_key="test-key"), @@ -195,7 +199,9 @@ class TestAliasTargetRestrictions: ] for provider in providers: - all_known = provider.list_all_known_models() + all_known = provider.list_models( + respect_restrictions=False, include_aliases=True, lowercase=True, unique=True + ) # Should not have duplicates assert len(all_known) == len(set(all_known)), f"{provider.__class__.__name__} returns duplicate models" @@ -207,7 +213,7 @@ class TestAliasTargetRestrictions: from unittest.mock import MagicMock mock_provider = MagicMock() - mock_provider.list_all_known_models.return_value = ["model1", "model2", "target-model"] + mock_provider.list_models.return_value = ["model1", "model2", "target-model"] # Set up a restriction that should trigger validation service.restrictions = {ProviderType.OPENAI: {"invalid-model"}} @@ -218,7 +224,12 @@ class TestAliasTargetRestrictions: service.validate_against_known_models(provider_instances) # Verify the polymorphic method was called - mock_provider.list_all_known_models.assert_called_once() + mock_provider.list_models.assert_called_once_with( + respect_restrictions=False, + include_aliases=True, + lowercase=True, + unique=True, + ) @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini"}) # Restrict to specific model def test_complex_alias_chains_handled_correctly(self): @@ -250,7 +261,7 @@ class TestAliasTargetRestrictions: - A restriction on "o4-mini" (target) would not be recognized as valid After the fix: - - list_all_known_models() returns ["mini", "o3mini", "o4-mini", "o3-mini"] (aliases + targets) + - list_models(respect_restrictions=False, include_aliases=True, lowercase=True, unique=True) returns ["mini", "o3mini", "o4-mini", "o3-mini"] (aliases + targets) - validate_against_known_models() checks against all names - A restriction on "o4-mini" is recognized as valid """ @@ -262,7 +273,7 @@ class TestAliasTargetRestrictions: provider_instances = {ProviderType.OPENAI: provider} # Get all known models - should include BOTH aliases AND targets - all_known = provider.list_all_known_models() + all_known = provider.list_models(respect_restrictions=False, include_aliases=True, lowercase=True, unique=True) # Critical check: should contain both aliases and their targets assert "mini" in all_known # alias @@ -310,7 +321,7 @@ class TestAliasTargetRestrictions: the restriction is properly enforced and the target is recognized as a valid model to restrict. - The bug: If list_all_known_models() doesn't include targets, then validation + The bug: If list_models(respect_restrictions=False, include_aliases=True, lowercase=True, unique=True) doesn't include targets, then validation would incorrectly warn that target model names are "not recognized", making it appear that target-based restrictions don't work. """ @@ -325,7 +336,9 @@ class TestAliasTargetRestrictions: provider = OpenAIModelProvider(api_key="test-key") # These specific target models should be recognized as valid - all_known = provider.list_all_known_models() + all_known = provider.list_models( + respect_restrictions=False, include_aliases=True, lowercase=True, unique=True + ) assert "o4-mini" in all_known, "Target model o4-mini should be known" assert "o3-mini" in all_known, "Target model o3-mini should be known" diff --git a/tests/test_buggy_behavior_prevention.py b/tests/test_buggy_behavior_prevention.py index bfc26a0..c8c3cdd 100644 --- a/tests/test_buggy_behavior_prevention.py +++ b/tests/test_buggy_behavior_prevention.py @@ -1,12 +1,9 @@ """ -Tests that demonstrate the OLD BUGGY BEHAVIOR is now FIXED. +Regression scenarios ensuring alias-aware model listings stay correct. -These tests verify that scenarios which would have incorrectly passed -before our fix now behave correctly. Each test documents the specific -bug that was fixed and what the old vs new behavior should be. - -IMPORTANT: These tests PASS with our fix, but would have FAILED to catch -bugs with the old code (before list_all_known_models was implemented). +Each test captures behavior that previously regressed so we can guard it +permanently. The focus is confirming aliases and their canonical targets +remain visible to the restriction service and related validation logic. """ import os @@ -21,42 +18,34 @@ from utils.model_restrictions import ModelRestrictionService class TestBuggyBehaviorPrevention: - """ - These tests prove that our fix prevents the HIGH-severity regression - that was identified by the O3 precommit analysis. + """Regression tests for alias-aware restriction validation.""" - OLD BUG: list_models() only returned alias keys, not targets - FIX: list_all_known_models() returns both aliases AND targets - """ - - def test_old_bug_would_miss_target_restrictions(self): - """ - OLD BUG: If restriction was set on target model (e.g., 'o4-mini'), - validation would incorrectly warn it's not recognized because - list_models() only returned aliases ['mini', 'o3mini']. - - NEW BEHAVIOR: list_all_known_models() includes targets, so 'o4-mini' - is recognized as valid and no warning is generated. - """ + def test_alias_listing_includes_targets_for_restriction_validation(self): + """Alias-aware lists expose both aliases and canonical targets.""" provider = OpenAIModelProvider(api_key="test-key") - # This is what the old broken list_models() would return - aliases only - old_broken_list = ["mini", "o3mini"] # Missing 'o4-mini', 'o3-mini' targets + # Baseline alias-only list captured for regression documentation + alias_only_snapshot = ["mini", "o3mini"] # Missing 'o4-mini', 'o3-mini' targets - # This is what our fixed list_all_known_models() returns - new_fixed_list = provider.list_all_known_models() + # Canonical listing with aliases and targets + comprehensive_list = provider.list_models( + respect_restrictions=False, + include_aliases=True, + lowercase=True, + unique=True, + ) - # Verify the fix: new method includes both aliases AND targets - assert "mini" in new_fixed_list # alias - assert "o4-mini" in new_fixed_list # target - THIS WAS MISSING IN OLD CODE - assert "o3mini" in new_fixed_list # alias - assert "o3-mini" in new_fixed_list # target - THIS WAS MISSING IN OLD CODE + # Comprehensive listing should contain aliases and their targets + assert "mini" in comprehensive_list + assert "o4-mini" in comprehensive_list + assert "o3mini" in comprehensive_list + assert "o3-mini" in comprehensive_list - # Prove the old behavior was broken - assert "o4-mini" not in old_broken_list # Old code didn't include targets - assert "o3-mini" not in old_broken_list # Old code didn't include targets + # Legacy alias-only snapshots exclude targets + assert "o4-mini" not in alias_only_snapshot + assert "o3-mini" not in alias_only_snapshot - # This target validation would have FAILED with old code + # This scenario previously failed when targets were omitted service = ModelRestrictionService() service.restrictions = {ProviderType.OPENAI: {"o4-mini"}} # Restrict to target @@ -64,24 +53,19 @@ class TestBuggyBehaviorPrevention: provider_instances = {ProviderType.OPENAI: provider} service.validate_against_known_models(provider_instances) - # NEW BEHAVIOR: No warnings because o4-mini is now in list_all_known_models + # No warnings expected because alias-aware list includes the target target_warnings = [ call for call in mock_logger.warning.call_args_list if "o4-mini" in str(call) and "not a recognized" in str(call) ] - assert len(target_warnings) == 0, "o4-mini should be recognized with our fix" + assert len(target_warnings) == 0, "o4-mini should be recognized as a valid target" - def test_old_bug_would_incorrectly_warn_about_valid_targets(self): - """ - OLD BUG: Admins setting restrictions on target models would get - false warnings that their restriction models are "not recognized". - - NEW BEHAVIOR: Target models are properly recognized. - """ + def test_target_models_are_recognized_during_validation(self): + """Target model restrictions should not trigger false warnings.""" # Test with Gemini provider too provider = GeminiModelProvider(api_key="test-key") - all_known = provider.list_all_known_models() + all_known = provider.list_models(respect_restrictions=False, include_aliases=True, lowercase=True, unique=True) # Verify both aliases and targets are included assert "flash" in all_known # alias @@ -108,13 +92,8 @@ class TestBuggyBehaviorPrevention: assert "gemini-2.5-flash" not in warning or "not a recognized" not in warning assert "gemini-2.5-pro" not in warning or "not a recognized" not in warning - def test_old_bug_policy_bypass_prevention(self): - """ - OLD BUG: Policy enforcement was incomplete because validation - didn't know about target models. This could allow policy bypasses. - - NEW BEHAVIOR: Complete validation against all known model names. - """ + def test_policy_enforcement_remains_comprehensive(self): + """Policy validation must account for both aliases and targets.""" provider = OpenAIModelProvider(api_key="test-key") # Simulate a scenario where admin wants to restrict specific targets @@ -138,64 +117,85 @@ class TestBuggyBehaviorPrevention: # But o4mini (the actual alias for o4-mini) should work assert provider.validate_model_name("o4mini") # Resolves to o4-mini, which IS allowed - # Verify our list_all_known_models includes the restricted models - all_known = provider.list_all_known_models() + # Verify our alias-aware list includes the restricted models + all_known = provider.list_models( + respect_restrictions=False, + include_aliases=True, + lowercase=True, + unique=True, + ) assert "o3-mini" in all_known # Should be known (and allowed) assert "o4-mini" in all_known # Should be known (and allowed) assert "o3-pro" in all_known # Should be known (but blocked) assert "mini" in all_known # Should be known (and allowed since it resolves to o4-mini) - def test_demonstration_of_old_vs_new_interface(self): - """ - Direct comparison of old vs new interface to document the fix. - """ + def test_alias_aware_listing_extends_canonical_view(self): + """Alias-aware list should be a superset of restriction-filtered names.""" provider = OpenAIModelProvider(api_key="test-key") - # OLD interface (still exists for backward compatibility) - old_style_models = provider.list_models(respect_restrictions=False) + baseline_models = provider.list_models(respect_restrictions=False) - # NEW interface (our fix) - new_comprehensive_models = provider.list_all_known_models() + alias_aware_models = provider.list_models( + respect_restrictions=False, + include_aliases=True, + lowercase=True, + unique=True, + ) - # The new interface should be a superset of the old one - for model in old_style_models: + # Alias-aware variant should contain everything from the baseline + for model in baseline_models: assert model.lower() in [ - m.lower() for m in new_comprehensive_models - ], f"New interface missing model {model} from old interface" + m.lower() for m in alias_aware_models + ], f"Alias-aware listing missing baseline model {model}" - # The new interface should include target models that old one might miss - targets_that_should_exist = ["o4-mini", "o3-mini"] - for target in targets_that_should_exist: - assert target in new_comprehensive_models, f"New interface should include target model {target}" + # Alias-aware variant should include canonical targets as well + for target in ("o4-mini", "o3-mini"): + assert target in alias_aware_models, f"Alias-aware listing should include target model {target}" - def test_old_validation_interface_still_works(self): - """ - Verify our fix doesn't break existing validation workflows. - """ + def test_restriction_validation_uses_alias_aware_variant(self): + """Validation should request the alias-aware lowercased, deduped list.""" service = ModelRestrictionService() - # Create a mock provider that simulates the old behavior - old_style_provider = MagicMock() - old_style_provider.MODEL_CAPABILITIES = { + # Simulate a provider that only returns aliases when asked for models + alias_only_provider = MagicMock() + alias_only_provider.MODEL_CAPABILITIES = { "mini": "o4-mini", "o3mini": "o3-mini", "o4-mini": {"context_window": 200000}, "o3-mini": {"context_window": 200000}, } - # OLD BROKEN: This would only return aliases - old_style_provider.list_models.return_value = ["mini", "o3mini"] - # NEW FIXED: This includes both aliases and targets - old_style_provider.list_all_known_models.return_value = ["mini", "o3mini", "o4-mini", "o3-mini"] + + # Simulate alias-only vs. alias-aware behavior using a side effect + def list_models_side_effect(**kwargs): + respect_restrictions = kwargs.get("respect_restrictions", True) + include_aliases = kwargs.get("include_aliases", True) + lowercase = kwargs.get("lowercase", False) + unique = kwargs.get("unique", False) + + if respect_restrictions and include_aliases and not lowercase and not unique: + return ["mini", "o3mini"] + + if not respect_restrictions and include_aliases and lowercase and unique: + return ["mini", "o3mini", "o4-mini", "o3-mini"] + + raise AssertionError(f"Unexpected list_models call: {kwargs}") + + alias_only_provider.list_models.side_effect = list_models_side_effect # Test that validation now uses the comprehensive method service.restrictions = {ProviderType.OPENAI: {"o4-mini"}} # Restrict to target with patch("utils.model_restrictions.logger") as mock_logger: - provider_instances = {ProviderType.OPENAI: old_style_provider} + provider_instances = {ProviderType.OPENAI: alias_only_provider} service.validate_against_known_models(provider_instances) - # Verify the new method was called, not the old one - old_style_provider.list_all_known_models.assert_called_once() + # Verify the alias-aware variant was used + alias_only_provider.list_models.assert_called_with( + respect_restrictions=False, + include_aliases=True, + lowercase=True, + unique=True, + ) # Should not warn about o4-mini being unrecognized target_warnings = [ @@ -205,17 +205,17 @@ class TestBuggyBehaviorPrevention: ] assert len(target_warnings) == 0 - def test_regression_proof_comprehensive_coverage(self): - """ - Comprehensive test to prove our fix covers all provider types. - """ + def test_alias_listing_covers_targets_for_all_providers(self): + """Alias-aware listings should expose targets across providers.""" providers_to_test = [ (OpenAIModelProvider(api_key="test-key"), "mini", "o4-mini"), (GeminiModelProvider(api_key="test-key"), "flash", "gemini-2.5-flash"), ] for provider, alias, target in providers_to_test: - all_known = provider.list_all_known_models() + all_known = provider.list_models( + respect_restrictions=False, include_aliases=True, lowercase=True, unique=True + ) # Every provider should include both aliases and targets assert alias in all_known, f"{provider.__class__.__name__} missing alias {alias}" @@ -226,13 +226,7 @@ class TestBuggyBehaviorPrevention: @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini,invalid-model"}) def test_validation_correctly_identifies_invalid_models(self): - """ - Test that validation still catches truly invalid models while - properly recognizing valid target models. - - This proves our fix works: o4-mini appears in the "Known models" list - because list_all_known_models() now includes target models. - """ + """Validation should flag invalid models while listing valid targets.""" # Clear cached restriction service import utils.model_restrictions @@ -245,7 +239,6 @@ class TestBuggyBehaviorPrevention: provider_instances = {ProviderType.OPENAI: provider} service.validate_against_known_models(provider_instances) - # Should warn about 'invalid-model' (truly invalid) invalid_warnings = [ call for call in mock_logger.warning.call_args_list @@ -253,39 +246,37 @@ class TestBuggyBehaviorPrevention: ] assert len(invalid_warnings) > 0, "Should warn about truly invalid models" - # The warning should mention o4-mini in the "Known models" list (proving our fix works) + # The warning should mention o4-mini in the known models list warning_text = str(mock_logger.warning.call_args_list[0]) assert "Known models:" in warning_text, "Warning should include known models list" - assert "o4-mini" in warning_text, "o4-mini should appear in known models (proves our fix works)" - assert "o3-mini" in warning_text, "o3-mini should appear in known models (proves our fix works)" + assert "o4-mini" in warning_text, "o4-mini should appear in known models" + assert "o3-mini" in warning_text, "o3-mini should appear in known models" # But the warning should be specifically about invalid-model assert "'invalid-model'" in warning_text, "Warning should specifically mention invalid-model" - def test_custom_provider_also_implements_fix(self): - """ - Verify that custom provider also implements the comprehensive interface. - """ + def test_custom_provider_alias_listing(self): + """Custom provider should expose alias-aware listings as well.""" from providers.custom import CustomProvider # This might fail if no URL is set, but that's expected try: provider = CustomProvider(base_url="http://test.com/v1") - all_known = provider.list_all_known_models() + all_known = provider.list_models( + respect_restrictions=False, include_aliases=True, lowercase=True, unique=True + ) # Should return a list (might be empty if registry not loaded) assert isinstance(all_known, list) except ValueError: # Expected if no base_url configured, skip this test pytest.skip("Custom provider requires URL configuration") - def test_openrouter_provider_also_implements_fix(self): - """ - Verify that OpenRouter provider also implements the comprehensive interface. - """ + def test_openrouter_provider_alias_listing(self): + """OpenRouter provider should expose alias-aware listings.""" from providers.openrouter import OpenRouterProvider provider = OpenRouterProvider(api_key="test-key") - all_known = provider.list_all_known_models() + all_known = provider.list_models(respect_restrictions=False, include_aliases=True, lowercase=True, unique=True) # Should return a list with both aliases and targets assert isinstance(all_known, list) diff --git a/tests/test_model_restrictions.py b/tests/test_model_restrictions.py index f2eb430..8aaf620 100644 --- a/tests/test_model_restrictions.py +++ b/tests/test_model_restrictions.py @@ -142,7 +142,7 @@ class TestModelRestrictionService: "o3-mini": {"context_window": 200000}, "o4-mini": {"context_window": 200000}, } - mock_provider.list_all_known_models.return_value = ["o3", "o3-mini", "o4-mini"] + mock_provider.list_models.return_value = ["o3", "o3-mini", "o4-mini"] provider_instances = {ProviderType.OPENAI: mock_provider} service.validate_against_known_models(provider_instances) @@ -447,7 +447,13 @@ class TestRegistryIntegration: } mock_openai.get_provider_type.return_value = ProviderType.OPENAI - def openai_list_models(respect_restrictions=True): + def openai_list_models( + *, + respect_restrictions: bool = True, + include_aliases: bool = True, + lowercase: bool = False, + unique: bool = False, + ): from utils.model_restrictions import get_restriction_service restriction_service = get_restriction_service() if respect_restrictions else None @@ -457,15 +463,26 @@ class TestRegistryIntegration: target_model = config if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, target_model): continue - models.append(model_name) + if include_aliases: + models.append(model_name) else: if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, model_name): continue models.append(model_name) + if lowercase: + models = [m.lower() for m in models] + if unique: + seen = set() + ordered = [] + for name in models: + if name in seen: + continue + seen.add(name) + ordered.append(name) + models = ordered return models - mock_openai.list_models = openai_list_models - mock_openai.list_all_known_models.return_value = ["o3", "o3-mini"] + mock_openai.list_models = MagicMock(side_effect=openai_list_models) mock_gemini = MagicMock() mock_gemini.MODEL_CAPABILITIES = { @@ -474,7 +491,13 @@ class TestRegistryIntegration: } mock_gemini.get_provider_type.return_value = ProviderType.GOOGLE - def gemini_list_models(respect_restrictions=True): + def gemini_list_models( + *, + respect_restrictions: bool = True, + include_aliases: bool = True, + lowercase: bool = False, + unique: bool = False, + ): from utils.model_restrictions import get_restriction_service restriction_service = get_restriction_service() if respect_restrictions else None @@ -484,18 +507,26 @@ class TestRegistryIntegration: target_model = config if restriction_service and not restriction_service.is_allowed(ProviderType.GOOGLE, target_model): continue - models.append(model_name) + if include_aliases: + models.append(model_name) else: if restriction_service and not restriction_service.is_allowed(ProviderType.GOOGLE, model_name): continue models.append(model_name) + if lowercase: + models = [m.lower() for m in models] + if unique: + seen = set() + ordered = [] + for name in models: + if name in seen: + continue + seen.add(name) + ordered.append(name) + models = ordered return models - mock_gemini.list_models = gemini_list_models - mock_gemini.list_all_known_models.return_value = [ - "gemini-2.5-pro", - "gemini-2.5-flash", - ] + mock_gemini.list_models = MagicMock(side_effect=gemini_list_models) def get_provider_side_effect(provider_type): if provider_type == ProviderType.OPENAI: @@ -615,7 +646,13 @@ class TestAutoModeWithRestrictions: } mock_openai.get_provider_type.return_value = ProviderType.OPENAI - def openai_list_models(respect_restrictions=True): + def openai_list_models( + *, + respect_restrictions: bool = True, + include_aliases: bool = True, + lowercase: bool = False, + unique: bool = False, + ): from utils.model_restrictions import get_restriction_service restriction_service = get_restriction_service() if respect_restrictions else None @@ -625,15 +662,26 @@ class TestAutoModeWithRestrictions: target_model = config if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, target_model): continue - models.append(model_name) + if include_aliases: + models.append(model_name) else: if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, model_name): continue models.append(model_name) + if lowercase: + models = [m.lower() for m in models] + if unique: + seen = set() + ordered = [] + for name in models: + if name in seen: + continue + seen.add(name) + ordered.append(name) + models = ordered return models - mock_openai.list_models = openai_list_models - mock_openai.list_all_known_models.return_value = ["o3", "o3-mini", "o4-mini"] + mock_openai.list_models = MagicMock(side_effect=openai_list_models) # Add get_preferred_model method to mock to match new implementation def get_preferred_model(category, allowed_models): diff --git a/tests/test_old_behavior_simulation.py b/tests/test_old_behavior_simulation.py deleted file mode 100644 index dc4719a..0000000 --- a/tests/test_old_behavior_simulation.py +++ /dev/null @@ -1,216 +0,0 @@ -""" -Tests that simulate the OLD BROKEN BEHAVIOR to prove it was indeed broken. - -These tests create mock providers that behave like the old code (before our fix) -and demonstrate that they would have failed to catch the HIGH-severity bug. - -IMPORTANT: These tests show what WOULD HAVE HAPPENED with the old code. -They prove that our fix was necessary and actually addresses real problems. -""" - -from unittest.mock import MagicMock, patch - -from providers.shared import ProviderType -from utils.model_restrictions import ModelRestrictionService - - -class TestOldBehaviorSimulation: - """ - Simulate the old broken behavior to prove it was buggy. - """ - - def test_old_behavior_would_miss_target_restrictions(self): - """ - SIMULATION: This test recreates the OLD BROKEN BEHAVIOR and proves it was buggy. - - OLD BUG: When validation service called provider.list_models(), it only got - aliases back, not targets. This meant target-based restrictions weren't validated. - """ - # Create a mock provider that simulates the OLD BROKEN BEHAVIOR - old_broken_provider = MagicMock() - old_broken_provider.MODEL_CAPABILITIES = { - "mini": "o4-mini", # alias -> target - "o3mini": "o3-mini", # alias -> target - "o4-mini": {"context_window": 200000}, - "o3-mini": {"context_window": 200000}, - } - - # OLD BROKEN: list_models only returned aliases, missing targets - old_broken_provider.list_models.return_value = ["mini", "o3mini"] - - # OLD BROKEN: There was no list_all_known_models method! - # We simulate this by making it behave like the old list_models - old_broken_provider.list_all_known_models.return_value = ["mini", "o3mini"] # MISSING TARGETS! - - # Now test what happens when admin tries to restrict by target model - service = ModelRestrictionService() - service.restrictions = {ProviderType.OPENAI: {"o4-mini"}} # Restrict to target model - - with patch("utils.model_restrictions.logger") as mock_logger: - provider_instances = {ProviderType.OPENAI: old_broken_provider} - service.validate_against_known_models(provider_instances) - - # OLD BROKEN BEHAVIOR: Would warn about o4-mini being "not recognized" - # because it wasn't in the list_all_known_models response - target_warnings = [ - call - for call in mock_logger.warning.call_args_list - if "o4-mini" in str(call) and "not a recognized" in str(call) - ] - - # This proves the old behavior was broken - it would generate false warnings - assert len(target_warnings) > 0, "OLD BROKEN BEHAVIOR: Would incorrectly warn about valid target models" - - # Verify the warning message shows the broken list - warning_text = str(target_warnings[0]) - assert "mini" in warning_text # Alias was included - assert "o3mini" in warning_text # Alias was included - # But targets were missing from the known models list in old behavior - - def test_new_behavior_fixes_the_problem(self): - """ - Compare old vs new behavior to show our fix works. - """ - # Create mock provider with NEW FIXED BEHAVIOR - new_fixed_provider = MagicMock() - new_fixed_provider.MODEL_CAPABILITIES = { - "mini": "o4-mini", - "o3mini": "o3-mini", - "o4-mini": {"context_window": 200000}, - "o3-mini": {"context_window": 200000}, - } - - # NEW FIXED: list_all_known_models includes BOTH aliases AND targets - new_fixed_provider.list_all_known_models.return_value = [ - "mini", - "o3mini", # aliases - "o4-mini", - "o3-mini", # targets - THESE WERE MISSING IN OLD CODE! - ] - - # Same restriction scenario - service = ModelRestrictionService() - service.restrictions = {ProviderType.OPENAI: {"o4-mini"}} # Restrict to target model - - with patch("utils.model_restrictions.logger") as mock_logger: - provider_instances = {ProviderType.OPENAI: new_fixed_provider} - service.validate_against_known_models(provider_instances) - - # NEW FIXED BEHAVIOR: No warnings about o4-mini being unrecognized - target_warnings = [ - call - for call in mock_logger.warning.call_args_list - if "o4-mini" in str(call) and "not a recognized" in str(call) - ] - - # Our fix prevents false warnings - assert len(target_warnings) == 0, "NEW FIXED BEHAVIOR: Should not warn about valid target models" - - def test_policy_bypass_prevention_old_vs_new(self): - """ - Show how the old behavior could have led to policy bypass scenarios. - """ - # OLD BROKEN: Admin thinks they've restricted access to o4-mini, - # but validation doesn't recognize it as a valid restriction target - old_broken_provider = MagicMock() - old_broken_provider.list_all_known_models.return_value = ["mini", "o3mini"] # Missing targets - - # NEW FIXED: Same provider with our fix - new_fixed_provider = MagicMock() - new_fixed_provider.list_all_known_models.return_value = ["mini", "o3mini", "o4-mini", "o3-mini"] - - # Test restriction on target model - use completely separate service instances - old_service = ModelRestrictionService() - old_service.restrictions = {ProviderType.OPENAI: {"o4-mini", "completely-invalid-model"}} - - new_service = ModelRestrictionService() - new_service.restrictions = {ProviderType.OPENAI: {"o4-mini", "completely-invalid-model"}} - - # OLD BEHAVIOR: Would warn about BOTH models being unrecognized - with patch("utils.model_restrictions.logger") as mock_logger_old: - provider_instances = {ProviderType.OPENAI: old_broken_provider} - old_service.validate_against_known_models(provider_instances) - - old_warnings = [str(call) for call in mock_logger_old.warning.call_args_list] - print(f"OLD warnings: {old_warnings}") # Debug output - - # NEW BEHAVIOR: Only warns about truly invalid model - with patch("utils.model_restrictions.logger") as mock_logger_new: - provider_instances = {ProviderType.OPENAI: new_fixed_provider} - new_service.validate_against_known_models(provider_instances) - - new_warnings = [str(call) for call in mock_logger_new.warning.call_args_list] - print(f"NEW warnings: {new_warnings}") # Debug output - - # For now, just verify that we get some warnings in both cases - # The key point is that the "Known models" list is different - assert len(old_warnings) > 0, "OLD: Should have warnings" - assert len(new_warnings) > 0, "NEW: Should have warnings for invalid model" - - # Verify the known models list is different between old and new - str(old_warnings[0]) if old_warnings else "" - new_warning_text = str(new_warnings[0]) if new_warnings else "" - - if "Known models:" in new_warning_text: - # NEW behavior should include o4-mini in known models list - assert "o4-mini" in new_warning_text, "NEW: Should include o4-mini in known models" - - print("This test demonstrates that our fix improves the 'Known models' list shown to users.") - - def test_demonstrate_target_coverage_improvement(self): - """ - Show the exact improvement in target model coverage. - """ - # Simulate different provider implementations - providers_old_vs_new = [ - # (old_broken_list, new_fixed_list, provider_name) - (["mini", "o3mini"], ["mini", "o3mini", "o4-mini", "o3-mini"], "OpenAI"), - ( - ["flash", "pro"], - ["flash", "pro", "gemini-2.5-flash", "gemini-2.5-pro"], - "Gemini", - ), - ] - - for old_list, new_list, provider_name in providers_old_vs_new: - # Count how many additional models are now covered - old_coverage = set(old_list) - new_coverage = set(new_list) - - additional_coverage = new_coverage - old_coverage - - # There should be additional target models covered - assert len(additional_coverage) > 0, f"{provider_name}: Should have additional target coverage" - - # All old models should still be covered - assert old_coverage.issubset(new_coverage), f"{provider_name}: Should maintain backward compatibility" - - print(f"{provider_name} provider:") - print(f" Old coverage: {sorted(old_coverage)}") - print(f" New coverage: {sorted(new_coverage)}") - print(f" Additional models: {sorted(additional_coverage)}") - - def test_comprehensive_alias_target_mapping_verification(self): - """ - Verify that our fix provides comprehensive alias->target coverage. - """ - from providers.gemini import GeminiModelProvider - from providers.openai_provider import OpenAIModelProvider - - # Test real providers to ensure they implement our fix correctly - providers = [OpenAIModelProvider(api_key="test-key"), GeminiModelProvider(api_key="test-key")] - - for provider in providers: - all_known = provider.list_all_known_models() - - # Check that every model and its aliases appear in the comprehensive list - for model_name, config in provider.MODEL_CAPABILITIES.items(): - assert model_name.lower() in all_known, f"{provider.__class__.__name__}: Missing model {model_name}" - - for alias in getattr(config, "aliases", []): - assert ( - alias.lower() in all_known - ), f"{provider.__class__.__name__}: Missing alias {alias} for model {model_name}" - assert ( - provider._resolve_model_name(alias) == model_name - ), f"{provider.__class__.__name__}: Alias {alias} should resolve to {model_name}" diff --git a/tests/test_openai_compatible_token_usage.py b/tests/test_openai_compatible_token_usage.py index 276ee55..37a288f 100644 --- a/tests/test_openai_compatible_token_usage.py +++ b/tests/test_openai_compatible_token_usage.py @@ -26,10 +26,7 @@ class TestOpenAICompatibleTokenUsage(unittest.TestCase): def validate_model_name(self, model_name): return True - def list_models(self, respect_restrictions=True): - return ["test-model"] - - def list_all_known_models(self): + def list_models(self, **kwargs): return ["test-model"] self.provider = TestProvider("test-key") diff --git a/tests/test_openrouter_provider.py b/tests/test_openrouter_provider.py index 057bfde..0731646 100644 --- a/tests/test_openrouter_provider.py +++ b/tests/test_openrouter_provider.py @@ -151,7 +151,7 @@ class TestOpenRouterAutoMode: os.environ["DEFAULT_MODEL"] = "auto" mock_registry = Mock() - mock_registry.list_models.return_value = [ + model_names = [ "google/gemini-2.5-flash", "google/gemini-2.5-pro", "openai/o3", @@ -159,6 +159,18 @@ class TestOpenRouterAutoMode: "anthropic/claude-opus-4.1", "anthropic/claude-sonnet-4.1", ] + mock_registry.list_models.return_value = model_names + + # Mock resolve to return a ModelCapabilities-like object for each model + def mock_resolve(model_name): + if model_name in model_names: + mock_config = Mock() + mock_config.is_custom = False + mock_config.aliases = [] # Empty list of aliases + return mock_config + return None + + mock_registry.resolve.side_effect = mock_resolve ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider) @@ -171,8 +183,7 @@ class TestOpenRouterAutoMode: assert len(available_models) > 0, "Should find OpenRouter models in auto mode" assert all(provider_type == ProviderType.OPENROUTER for provider_type in available_models.values()) - expected_models = mock_registry.list_models.return_value - for model in expected_models: + for model in model_names: assert model in available_models, f"Model {model} should be available" @pytest.mark.no_mock_provider diff --git a/tests/test_supported_models_aliases.py b/tests/test_supported_models_aliases.py index efc1716..e3b55d5 100644 --- a/tests/test_supported_models_aliases.py +++ b/tests/test_supported_models_aliases.py @@ -151,11 +151,16 @@ class TestSupportedModelsAliases: assert "o3-2025-04-16" in dial_models assert "o3" in dial_models - def test_list_all_known_models_includes_aliases(self): - """Test that list_all_known_models returns all models and aliases in lowercase.""" + def test_list_models_all_known_variant_includes_aliases(self): + """Unified list_models should support lowercase, alias-inclusive listings.""" # Test Gemini gemini_provider = GeminiModelProvider("test-key") - gemini_all = gemini_provider.list_all_known_models() + gemini_all = gemini_provider.list_models( + respect_restrictions=False, + include_aliases=True, + lowercase=True, + unique=True, + ) assert "gemini-2.5-flash" in gemini_all assert "flash" in gemini_all assert "gemini-2.5-pro" in gemini_all @@ -165,7 +170,12 @@ class TestSupportedModelsAliases: # Test OpenAI openai_provider = OpenAIModelProvider("test-key") - openai_all = openai_provider.list_all_known_models() + openai_all = openai_provider.list_models( + respect_restrictions=False, + include_aliases=True, + lowercase=True, + unique=True, + ) assert "o4-mini" in openai_all assert "mini" in openai_all assert "o3-mini" in openai_all diff --git a/utils/model_restrictions.py b/utils/model_restrictions.py index 2e3a7f3..8b0984e 100644 --- a/utils/model_restrictions.py +++ b/utils/model_restrictions.py @@ -30,13 +30,20 @@ logger = logging.getLogger(__name__) class ModelRestrictionService: - """ - Centralized service for managing model usage restrictions. + """Central authority for environment-driven model allowlists. - This service: - 1. Loads restrictions from environment variables at startup - 2. Validates restrictions against known models - 3. Provides a simple interface to check if a model is allowed + Role + Interpret ``*_ALLOWED_MODELS`` environment variables, keep their + entries normalised (lowercase), and answer whether a provider/model + pairing is permitted. + + Responsibilities + * Parse, cache, and expose per-provider restriction sets + * Validate configuration by cross-checking each entry against the + provider’s alias-aware model list + * Offer helper methods such as ``is_allowed`` and ``filter_models`` to + enforce policy everywhere model names appear (tool selection, CLI + commands, etc.). """ # Environment variable names @@ -94,9 +101,14 @@ class ModelRestrictionService: # Get all supported models using the clean polymorphic interface try: - # Use list_all_known_models to get both aliases and their targets - all_models = provider.list_all_known_models() - supported_models = {model.lower() for model in all_models} + # Gather canonical models and aliases with consistent formatting + all_models = provider.list_models( + respect_restrictions=False, + include_aliases=True, + lowercase=True, + unique=True, + ) + supported_models = set(all_models) except Exception as e: logger.debug(f"Could not get model list from {provider_type.value} provider: {e}") supported_models = set()