refactor: moved temperature method from base provider to model capabilities

refactor: model listing cleanup, moved logic to model_capabilities.py
docs: added AGENTS.md for onboarding Codex
Fahad
2025-10-02 10:25:41 +04:00
parent f461cb4519
commit 6d237d0970
14 changed files with 460 additions and 512 deletions


@@ -18,13 +18,26 @@ logger = logging.getLogger(__name__)
class ModelProvider(ABC):
    """Abstract base class for all model backends in the MCP server.

    Role
        Defines the interface every provider must implement so the registry,
        restriction service, and tools have a uniform surface for listing
        models, resolving aliases, and executing requests.

    Responsibilities
        * expose static capability metadata for each supported model via
          :class:`ModelCapabilities`
        * accept user prompts, forward them to the underlying SDK, and wrap
          responses in :class:`ModelResponse`
        * report tokenizer counts for budgeting and validation logic
        * advertise provider identity (``ProviderType``) so restriction
          policies can map environment configuration onto providers
        * validate whether a model name or alias is recognised by the provider

    Shared helpers like temperature validation, alias resolution, and
    restriction-aware ``list_models`` live here so concrete subclasses only
    need to supply their catalogue and wire up SDK-specific behaviour.
    """

    # All concrete providers must define their supported models
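For orientation, a minimal sketch of a concrete subclass under this contract. The class name ``ExampleProvider``, its capability values, and the ``ProviderType.CUSTOM`` member are hypothetical; only hooks visible in this commit are used:

    class ExampleProvider(ModelProvider):
        # Static catalogue consumed by the shared list_models helper
        MODEL_CAPABILITIES = {
            "example-large": ModelCapabilities(
                provider=ProviderType.CUSTOM,  # assumed member; use the matching ProviderType
                model_name="example-large",
                friendly_name="Example Large",
                context_window=128_000,
                aliases=["large"],
            ),
        }

        def get_provider_type(self) -> ProviderType:
            return ProviderType.CUSTOM

        def get_model_configurations(self) -> dict[str, ModelCapabilities]:
            return self.MODEL_CAPABILITIES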
@@ -151,67 +164,52 @@ class ModelProvider(ABC):
        # If not found, return as-is
        return model_name

    def list_models(
        self,
        *,
        respect_restrictions: bool = True,
        include_aliases: bool = True,
        lowercase: bool = False,
        unique: bool = False,
    ) -> list[str]:
        """Return formatted model names supported by this provider.

        Args:
            respect_restrictions: Apply provider restriction policy.
            include_aliases: Include aliases alongside canonical model names.
            lowercase: Normalize returned names to lowercase.
            unique: Deduplicate names after formatting.

        Returns:
            List of model names formatted according to the provided options.
        """
        model_configs = self.get_model_configurations()

        restriction_service = None
        if respect_restrictions:
            from utils.model_restrictions import get_restriction_service

            restriction_service = get_restriction_service()

        if restriction_service:
            allowed_configs = {}
            for model_name, config in model_configs.items():
                if restriction_service.is_allowed(self.get_provider_type(), model_name):
                    allowed_configs[model_name] = config
            model_configs = allowed_configs

        if not model_configs:
            return []

        return ModelCapabilities.collect_model_names(
            model_configs,
            include_aliases=include_aliases,
            lowercase=lowercase,
            unique=unique,
        )
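The reworked ``list_models`` now also covers the job of the removed ``list_all_known_models``. A brief usage sketch, mirroring the call sites updated in the tests below:

    provider = OpenAIModelProvider(api_key="test-key")

    # Restriction-aware canonical listing (the previous default behaviour)
    visible = provider.list_models()

    # Alias-aware, lowercased, deduplicated listing (replaces list_all_known_models())
    all_known = provider.list_models(
        respect_restrictions=False,
        include_aliases=True,
        lowercase=True,
        unique=True,
    )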
    def validate_image(self, image_path: str, max_size_mb: float = None) -> tuple[bytes, str]:
        """Provider-independent image validation.


@@ -32,11 +32,20 @@ _TEMP_UNSUPPORTED_KEYWORDS = [
class CustomProvider(OpenAICompatibleProvider):
    """Adapter for self-hosted or local OpenAI-compatible endpoints.

    Role
        Provide a uniform bridge between the MCP server and user-managed
        OpenAI-compatible services (Ollama, vLLM, LM Studio, bespoke gateways).
        By subclassing :class:`OpenAICompatibleProvider` it inherits request and
        token handling, while the custom registry exposes locally defined model
        metadata.

    Notable behaviour
        * Uses :class:`OpenRouterModelRegistry` to load model definitions and
          aliases so custom deployments share the same metadata pipeline as
          OpenRouter itself.
        * Normalises version-tagged model names (``model:latest``) and applies
          restriction policies just like cloud providers, ensuring consistent
          behaviour across environments.
    """

    FRIENDLY_NAME = "Custom API"


@@ -17,9 +17,19 @@ from .shared import (
class OpenRouterProvider(OpenAICompatibleProvider):
    """Client for OpenRouter's multi-model aggregation service.

    Role
        Surface OpenRouter's dynamic catalogue through the same interface as
        native providers so tools can reference OpenRouter models and aliases
        without special cases.

    Characteristics
        * Pulls live model definitions from :class:`OpenRouterModelRegistry`
          (aliases, provider-specific metadata, capability hints)
        * Applies alias-aware restriction checks before exposing models to the
          registry or tooling
        * Reuses :class:`OpenAICompatibleProvider` infrastructure for request
          execution so OpenRouter endpoints behave like standard OpenAI-style
          APIs.
    """

    FRIENDLY_NAME = "OpenRouter"
@@ -208,75 +218,56 @@ class OpenRouterProvider(OpenAICompatibleProvider):
""" """
return False return False
def list_models(self, respect_restrictions: bool = True) -> list[str]: def list_models(
"""Return a list of model names supported by this provider. self,
*,
respect_restrictions: bool = True,
include_aliases: bool = True,
lowercase: bool = False,
unique: bool = False,
) -> list[str]:
"""Return formatted OpenRouter model names, respecting alias-aware restrictions."""
Args: if not self._registry:
respect_restrictions: Whether to apply provider-specific restriction logic. return []
Returns:
List of model names available from this provider
"""
from utils.model_restrictions import get_restriction_service from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service() if respect_restrictions else None restriction_service = get_restriction_service() if respect_restrictions else None
models = [] allowed_configs: dict[str, ModelCapabilities] = {}
if self._registry: for model_name in self._registry.list_models():
for model_name in self._registry.list_models(): config = self._registry.resolve(model_name)
# ===================================================================================== if not config:
# CRITICAL ALIAS-AWARE RESTRICTION CHECKING (Fixed Issue #98) continue
# =====================================================================================
# Previously, restrictions only checked full model names (e.g., "google/gemini-2.5-pro")
# but users specify aliases in OPENROUTER_ALLOWED_MODELS (e.g., "pro").
# This caused "no models available" error even with valid restrictions.
#
# Fix: Check both model name AND all aliases against restrictions
# TEST COVERAGE: tests/test_provider_routing_bugs.py::TestOpenRouterAliasRestrictions
# =====================================================================================
if restriction_service:
# Get model config to check aliases as well
model_config = self._registry.resolve(model_name)
allowed = False
# Check if model name itself is allowed if restriction_service:
if restriction_service.is_allowed(self.get_provider_type(), model_name): allowed = restriction_service.is_allowed(self.get_provider_type(), model_name)
allowed = True
# CRITICAL: Also check aliases - this fixes the alias restriction bug if not allowed and config.aliases:
if not allowed and model_config and model_config.aliases: for alias in config.aliases:
for alias in model_config.aliases: if restriction_service.is_allowed(self.get_provider_type(), alias):
if restriction_service.is_allowed(self.get_provider_type(), alias): allowed = True
allowed = True break
break
if not allowed: if not allowed:
continue continue
models.append(model_name) allowed_configs[model_name] = config
return models if not allowed_configs:
return []
def list_all_known_models(self) -> list[str]: # When restrictions are in place, don't include aliases to avoid confusion
"""Return all model names known by this provider, including alias targets. # Only return the canonical model names that are actually allowed
actual_include_aliases = include_aliases and not respect_restrictions
Returns: return ModelCapabilities.collect_model_names(
List of all model names and alias targets known by this provider allowed_configs,
""" include_aliases=actual_include_aliases,
all_models = set() lowercase=lowercase,
unique=unique,
if self._registry: )
# Get all models and aliases from the registry
all_models.update(model.lower() for model in self._registry.list_models())
all_models.update(alias.lower() for alias in self._registry.list_aliases())
# For each alias, also add its target
for alias in self._registry.list_aliases():
config = self._registry.resolve(alias)
if config:
all_models.add(config.model_name.lower())
return list(all_models)
    def get_model_configurations(self) -> dict[str, ModelCapabilities]:
        """Get model configurations from the registry.


@@ -17,12 +17,21 @@ from .shared import (
class OpenRouterModelRegistry:
    """In-memory view of OpenRouter and custom model metadata.

    Role
        Parse the packaged ``conf/custom_models.json`` (or user-specified
        overrides), construct alias and capability maps, and serve those
        structures to providers that rely on OpenRouter semantics (both the
        OpenRouter provider itself and the Custom provider).

    Key duties
        * Load :class:`ModelCapabilities` definitions from configuration files
        * Maintain a case-insensitive alias → canonical name map for fast
          resolution
        * Provide helpers to list models, list aliases, and resolve an
          arbitrary name to its capability object without repeatedly touching
          the file system.
    """

    def __init__(self, config_path: Optional[str] = None):
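A hedged sketch of the registry helpers referenced throughout this commit (``list_models``, ``list_aliases``, and ``resolve`` all appear in the provider code above; ``resolve`` returning ``None`` for unknown names is assumed from how the OpenRouter provider checks its result):

    registry = OpenRouterModelRegistry()  # or pass config_path to override the packaged file

    canonical_names = registry.list_models()   # canonical model names
    alias_names = registry.list_aliases()      # alias keys, matched case-insensitively

    config = registry.resolve("flash")         # alias or canonical name
    if config:
        print(config.model_name, config.context_window)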


@@ -12,11 +12,22 @@ if TYPE_CHECKING:
class ModelProviderRegistry:
    """Central catalogue of provider implementations used by the MCP server.

    Role
        Holds the mapping between :class:`ProviderType` values and concrete
        :class:`ModelProvider` subclasses/factories. At runtime the registry
        is responsible for instantiating providers, caching them for reuse,
        and mediating lookup of providers and model names in provider
        priority order.

    Core responsibilities
        * Resolve API keys and other runtime configuration for each provider
        * Lazily create provider instances so unused backends incur no cost
        * Expose convenience methods for enumerating available models and
          locating which provider can service a requested model name or alias
        * Honour the project-wide provider priority policy so namespaces (or
          alias collisions) are resolved deterministically.
    """

    _instance = None
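A hedged sketch of typical registry usage (``register_provider`` appears verbatim in the updated tests below; the ``get_provider`` lookup name is an assumption for illustration):

    ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)

    # First lookup lazily instantiates and caches the provider (assumed accessor)
    provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)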


@@ -11,24 +11,46 @@ __all__ = ["ModelCapabilities"]
@dataclass
class ModelCapabilities:
    """Static description of what a model can do within a provider.

    Role
        Acts as the canonical record for everything the server needs to know
        about a model—its provider, token limits, feature switches, aliases,
        and temperature rules. Providers populate these objects so tools and
        higher-level services can rely on a consistent schema.

    Typical usage
        * Provider subclasses declare `MODEL_CAPABILITIES` maps containing these
          objects (for example ``OpenAIModelProvider``)
        * Helper utilities (e.g. restriction validation, alias expansion) read
          these objects to build model lists for tooling and policy enforcement
        * Tool selection logic inspects attributes such as
          ``supports_extended_thinking`` or ``context_window`` to choose an
          appropriate model for a task.
    """

    provider: ProviderType
    model_name: str
    friendly_name: str
    description: str = ""
    aliases: list[str] = field(default_factory=list)

    # Capacity limits / resource budgets
    context_window: int = 0
    max_output_tokens: int = 0
    max_thinking_tokens: int = 0

    # Capability flags
    supports_extended_thinking: bool = False
    supports_system_prompts: bool = True
    supports_streaming: bool = True
    supports_function_calling: bool = False
    supports_images: bool = False
    supports_json_mode: bool = False
    supports_temperature: bool = True

    # Additional attributes
    max_image_size_mb: float = 0.0
    is_custom: bool = False

    temperature_constraint: TemperatureConstraint = field(
        default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
    )
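Because every field after the identity trio now carries a default, provider entries only need to override what differs from the baseline. A construction sketch reusing values that appear elsewhere in this diff (the friendly name is illustrative):

    capability = ModelCapabilities(
        provider=ProviderType.OPENAI,
        model_name="o4-mini",
        friendly_name="OpenAI o4-mini",
        context_window=200_000,
        aliases=["mini"],
    )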
@@ -56,3 +78,45 @@ class ModelCapabilities:
            for base_model, capabilities in model_configs.items()
            if capabilities.aliases
        }

    @staticmethod
    def collect_model_names(
        model_configs: dict[str, "ModelCapabilities"],
        *,
        include_aliases: bool = True,
        lowercase: bool = False,
        unique: bool = False,
    ) -> list[str]:
        """Build an ordered list of model names and aliases.

        Args:
            model_configs: Mapping of canonical model names to capabilities.
            include_aliases: When True, include aliases for each model.
            lowercase: When True, normalize names to lowercase.
            unique: When True, ensure each returned name appears once (after formatting).

        Returns:
            Ordered list of model names (and optionally aliases) formatted per options.
        """
        formatted_names: list[str] = []
        seen: set[str] | None = set() if unique else None

        def append_name(name: str) -> None:
            formatted = name.lower() if lowercase else name
            if seen is not None:
                if formatted in seen:
                    return
                seen.add(formatted)
            formatted_names.append(formatted)

        for base_model, capabilities in model_configs.items():
            append_name(base_model)
            if include_aliases and capabilities.aliases:
                for alias in capabilities.aliases:
                    append_name(alias)

        return formatted_names
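A quick sketch of the helper's output shape, reusing the capability pattern above:

    configs = {
        "o4-mini": ModelCapabilities(
            provider=ProviderType.OPENAI,
            model_name="o4-mini",
            friendly_name="OpenAI o4-mini",
            aliases=["mini"],
        ),
    }

    ModelCapabilities.collect_model_names(configs)
    # -> ["o4-mini", "mini"]

    ModelCapabilities.collect_model_names(configs, include_aliases=False, lowercase=True)
    # -> ["o4-mini"]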


@@ -22,7 +22,7 @@ class TestAliasTargetRestrictions:
        provider = OpenAIModelProvider(api_key="test-key")

        # Get all known models including aliases and targets
        all_known = provider.list_models(respect_restrictions=False, include_aliases=True, lowercase=True, unique=True)

        # Should include both aliases and their targets
        assert "mini" in all_known  # alias
@@ -35,7 +35,7 @@ class TestAliasTargetRestrictions:
        provider = GeminiModelProvider(api_key="test-key")

        # Get all known models including aliases and targets
        all_known = provider.list_models(respect_restrictions=False, include_aliases=True, lowercase=True, unique=True)

        # Should include both aliases and their targets
        assert "flash" in all_known  # alias
@@ -162,7 +162,9 @@ class TestAliasTargetRestrictions:
""" """
# Test OpenAI provider # Test OpenAI provider
openai_provider = OpenAIModelProvider(api_key="test-key") openai_provider = OpenAIModelProvider(api_key="test-key")
openai_all_known = openai_provider.list_all_known_models() openai_all_known = openai_provider.list_models(
respect_restrictions=False, include_aliases=True, lowercase=True, unique=True
)
# Verify that for each alias, its target is also included # Verify that for each alias, its target is also included
for model_name, config in openai_provider.MODEL_CAPABILITIES.items(): for model_name, config in openai_provider.MODEL_CAPABILITIES.items():
@@ -175,7 +177,9 @@ class TestAliasTargetRestrictions:
        # Test Gemini provider
        gemini_provider = GeminiModelProvider(api_key="test-key")
        gemini_all_known = gemini_provider.list_models(
            respect_restrictions=False, include_aliases=True, lowercase=True, unique=True
        )

        # Verify that for each alias, its target is also included
        for model_name, config in gemini_provider.MODEL_CAPABILITIES.items():
@@ -186,8 +190,8 @@ class TestAliasTargetRestrictions:
                config.lower() in gemini_all_known
            ), f"Target '{config}' for alias '{model_name}' not in known models"

    def test_no_duplicate_models_in_alias_aware_listing(self):
        """Test that the alias-aware list_models variant doesn't return duplicates."""
        # Test all providers
        providers = [
            OpenAIModelProvider(api_key="test-key"),
@@ -195,7 +199,9 @@ class TestAliasTargetRestrictions:
        ]

        for provider in providers:
            all_known = provider.list_models(
                respect_restrictions=False, include_aliases=True, lowercase=True, unique=True
            )

            # Should not have duplicates
            assert len(all_known) == len(set(all_known)), f"{provider.__class__.__name__} returns duplicate models"
@@ -207,7 +213,7 @@ class TestAliasTargetRestrictions:
        from unittest.mock import MagicMock

        mock_provider = MagicMock()
        mock_provider.list_models.return_value = ["model1", "model2", "target-model"]

        # Set up a restriction that should trigger validation
        service.restrictions = {ProviderType.OPENAI: {"invalid-model"}}
@@ -218,7 +224,12 @@ class TestAliasTargetRestrictions:
        service.validate_against_known_models(provider_instances)

        # Verify the polymorphic method was called
        mock_provider.list_models.assert_called_once_with(
            respect_restrictions=False,
            include_aliases=True,
            lowercase=True,
            unique=True,
        )
    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini"})  # Restrict to specific model
    def test_complex_alias_chains_handled_correctly(self):
@@ -250,7 +261,7 @@ class TestAliasTargetRestrictions:
- A restriction on "o4-mini" (target) would not be recognized as valid - A restriction on "o4-mini" (target) would not be recognized as valid
After the fix: After the fix:
- list_all_known_models() returns ["mini", "o3mini", "o4-mini", "o3-mini"] (aliases + targets) - list_models(respect_restrictions=False, include_aliases=True, lowercase=True, unique=True) returns ["mini", "o3mini", "o4-mini", "o3-mini"] (aliases + targets)
- validate_against_known_models() checks against all names - validate_against_known_models() checks against all names
- A restriction on "o4-mini" is recognized as valid - A restriction on "o4-mini" is recognized as valid
""" """
@@ -262,7 +273,7 @@ class TestAliasTargetRestrictions:
        provider_instances = {ProviderType.OPENAI: provider}

        # Get all known models - should include BOTH aliases AND targets
        all_known = provider.list_models(respect_restrictions=False, include_aliases=True, lowercase=True, unique=True)

        # Critical check: should contain both aliases and their targets
        assert "mini" in all_known  # alias
@@ -310,7 +321,7 @@ class TestAliasTargetRestrictions:
        the restriction is properly enforced and the target is recognized as a valid
        model to restrict.

        The bug: If list_models(respect_restrictions=False, include_aliases=True,
        lowercase=True, unique=True) doesn't include targets, then validation
        would incorrectly warn that target model names are "not recognized", making
        it appear that target-based restrictions don't work.
        """
@@ -325,7 +336,9 @@ class TestAliasTargetRestrictions:
        provider = OpenAIModelProvider(api_key="test-key")

        # These specific target models should be recognized as valid
        all_known = provider.list_models(
            respect_restrictions=False, include_aliases=True, lowercase=True, unique=True
        )
        assert "o4-mini" in all_known, "Target model o4-mini should be known"
        assert "o3-mini" in all_known, "Target model o3-mini should be known"


@@ -1,12 +1,9 @@
""" """
Tests that demonstrate the OLD BUGGY BEHAVIOR is now FIXED. Regression scenarios ensuring alias-aware model listings stay correct.
These tests verify that scenarios which would have incorrectly passed Each test captures behavior that previously regressed so we can guard it
before our fix now behave correctly. Each test documents the specific permanently. The focus is confirming aliases and their canonical targets
bug that was fixed and what the old vs new behavior should be. remain visible to the restriction service and related validation logic.
IMPORTANT: These tests PASS with our fix, but would have FAILED to catch
bugs with the old code (before list_all_known_models was implemented).
""" """
import os import os
@@ -21,42 +18,34 @@ from utils.model_restrictions import ModelRestrictionService
class TestBuggyBehaviorPrevention:
    """Regression tests for alias-aware restriction validation."""

    def test_alias_listing_includes_targets_for_restriction_validation(self):
        """Alias-aware lists expose both aliases and canonical targets."""
        provider = OpenAIModelProvider(api_key="test-key")

        # Baseline alias-only list captured for regression documentation
        alias_only_snapshot = ["mini", "o3mini"]  # Missing 'o4-mini', 'o3-mini' targets

        # Canonical listing with aliases and targets
        comprehensive_list = provider.list_models(
            respect_restrictions=False,
            include_aliases=True,
            lowercase=True,
            unique=True,
        )

        # Comprehensive listing should contain aliases and their targets
        assert "mini" in comprehensive_list
        assert "o4-mini" in comprehensive_list
        assert "o3mini" in comprehensive_list
        assert "o3-mini" in comprehensive_list

        # Legacy alias-only snapshots exclude targets
        assert "o4-mini" not in alias_only_snapshot
        assert "o3-mini" not in alias_only_snapshot

        # This scenario previously failed when targets were omitted
        service = ModelRestrictionService()
        service.restrictions = {ProviderType.OPENAI: {"o4-mini"}}  # Restrict to target
@@ -64,24 +53,19 @@ class TestBuggyBehaviorPrevention:
            provider_instances = {ProviderType.OPENAI: provider}
            service.validate_against_known_models(provider_instances)

            # No warnings expected because alias-aware list includes the target
            target_warnings = [
                call
                for call in mock_logger.warning.call_args_list
                if "o4-mini" in str(call) and "not a recognized" in str(call)
            ]
            assert len(target_warnings) == 0, "o4-mini should be recognized as a valid target"

    def test_target_models_are_recognized_during_validation(self):
        """Target model restrictions should not trigger false warnings."""
        # Test with Gemini provider too
        provider = GeminiModelProvider(api_key="test-key")
        all_known = provider.list_models(respect_restrictions=False, include_aliases=True, lowercase=True, unique=True)

        # Verify both aliases and targets are included
        assert "flash" in all_known  # alias
@@ -108,13 +92,8 @@ class TestBuggyBehaviorPrevention:
assert "gemini-2.5-flash" not in warning or "not a recognized" not in warning assert "gemini-2.5-flash" not in warning or "not a recognized" not in warning
assert "gemini-2.5-pro" not in warning or "not a recognized" not in warning assert "gemini-2.5-pro" not in warning or "not a recognized" not in warning
def test_old_bug_policy_bypass_prevention(self): def test_policy_enforcement_remains_comprehensive(self):
""" """Policy validation must account for both aliases and targets."""
OLD BUG: Policy enforcement was incomplete because validation
didn't know about target models. This could allow policy bypasses.
NEW BEHAVIOR: Complete validation against all known model names.
"""
provider = OpenAIModelProvider(api_key="test-key") provider = OpenAIModelProvider(api_key="test-key")
# Simulate a scenario where admin wants to restrict specific targets # Simulate a scenario where admin wants to restrict specific targets
@@ -138,64 +117,85 @@ class TestBuggyBehaviorPrevention:
        # But o4mini (the actual alias for o4-mini) should work
        assert provider.validate_model_name("o4mini")  # Resolves to o4-mini, which IS allowed

        # Verify our alias-aware list includes the restricted models
        all_known = provider.list_models(
            respect_restrictions=False,
            include_aliases=True,
            lowercase=True,
            unique=True,
        )
        assert "o3-mini" in all_known  # Should be known (and allowed)
        assert "o4-mini" in all_known  # Should be known (and allowed)
        assert "o3-pro" in all_known  # Should be known (but blocked)
        assert "mini" in all_known  # Should be known (and allowed since it resolves to o4-mini)
    def test_alias_aware_listing_extends_canonical_view(self):
        """Alias-aware list should be a superset of restriction-filtered names."""
        provider = OpenAIModelProvider(api_key="test-key")

        baseline_models = provider.list_models(respect_restrictions=False)

        alias_aware_models = provider.list_models(
            respect_restrictions=False,
            include_aliases=True,
            lowercase=True,
            unique=True,
        )

        # Alias-aware variant should contain everything from the baseline
        for model in baseline_models:
            assert model.lower() in [
                m.lower() for m in alias_aware_models
            ], f"Alias-aware listing missing baseline model {model}"

        # Alias-aware variant should include canonical targets as well
        for target in ("o4-mini", "o3-mini"):
            assert target in alias_aware_models, f"Alias-aware listing should include target model {target}"
    def test_restriction_validation_uses_alias_aware_variant(self):
        """Validation should request the alias-aware, lowercased, deduped list."""
        service = ModelRestrictionService()

        # Simulate a provider that only returns aliases when asked for models
        alias_only_provider = MagicMock()
        alias_only_provider.MODEL_CAPABILITIES = {
            "mini": "o4-mini",
            "o3mini": "o3-mini",
            "o4-mini": {"context_window": 200000},
            "o3-mini": {"context_window": 200000},
        }

        # Simulate alias-only vs. alias-aware behavior using a side effect
        def list_models_side_effect(**kwargs):
            respect_restrictions = kwargs.get("respect_restrictions", True)
            include_aliases = kwargs.get("include_aliases", True)
            lowercase = kwargs.get("lowercase", False)
            unique = kwargs.get("unique", False)
            if respect_restrictions and include_aliases and not lowercase and not unique:
                return ["mini", "o3mini"]
            if not respect_restrictions and include_aliases and lowercase and unique:
                return ["mini", "o3mini", "o4-mini", "o3-mini"]
            raise AssertionError(f"Unexpected list_models call: {kwargs}")

        alias_only_provider.list_models.side_effect = list_models_side_effect

        # Test that validation now uses the comprehensive method
        service.restrictions = {ProviderType.OPENAI: {"o4-mini"}}  # Restrict to target

        with patch("utils.model_restrictions.logger") as mock_logger:
            provider_instances = {ProviderType.OPENAI: alias_only_provider}
            service.validate_against_known_models(provider_instances)

            # Verify the alias-aware variant was used
            alias_only_provider.list_models.assert_called_with(
                respect_restrictions=False,
                include_aliases=True,
                lowercase=True,
                unique=True,
            )

            # Should not warn about o4-mini being unrecognized
            target_warnings = [
@@ -205,17 +205,17 @@ class TestBuggyBehaviorPrevention:
            ]
            assert len(target_warnings) == 0

    def test_alias_listing_covers_targets_for_all_providers(self):
        """Alias-aware listings should expose targets across providers."""
        providers_to_test = [
            (OpenAIModelProvider(api_key="test-key"), "mini", "o4-mini"),
            (GeminiModelProvider(api_key="test-key"), "flash", "gemini-2.5-flash"),
        ]

        for provider, alias, target in providers_to_test:
            all_known = provider.list_models(
                respect_restrictions=False, include_aliases=True, lowercase=True, unique=True
            )

            # Every provider should include both aliases and targets
            assert alias in all_known, f"{provider.__class__.__name__} missing alias {alias}"
@@ -226,13 +226,7 @@ class TestBuggyBehaviorPrevention:
    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini,invalid-model"})
    def test_validation_correctly_identifies_invalid_models(self):
        """Validation should flag invalid models while listing valid targets."""
        # Clear cached restriction service
        import utils.model_restrictions
@@ -245,7 +239,6 @@ class TestBuggyBehaviorPrevention:
            provider_instances = {ProviderType.OPENAI: provider}
            service.validate_against_known_models(provider_instances)

            invalid_warnings = [
                call
                for call in mock_logger.warning.call_args_list
@@ -253,39 +246,37 @@ class TestBuggyBehaviorPrevention:
            ]
            assert len(invalid_warnings) > 0, "Should warn about truly invalid models"

            # The warning should mention o4-mini in the known models list
            warning_text = str(mock_logger.warning.call_args_list[0])
            assert "Known models:" in warning_text, "Warning should include known models list"
            assert "o4-mini" in warning_text, "o4-mini should appear in known models"
            assert "o3-mini" in warning_text, "o3-mini should appear in known models"

            # But the warning should be specifically about invalid-model
            assert "'invalid-model'" in warning_text, "Warning should specifically mention invalid-model"

    def test_custom_provider_alias_listing(self):
        """Custom provider should expose alias-aware listings as well."""
        from providers.custom import CustomProvider

        # This might fail if no URL is set, but that's expected
        try:
            provider = CustomProvider(base_url="http://test.com/v1")
            all_known = provider.list_models(
                respect_restrictions=False, include_aliases=True, lowercase=True, unique=True
            )
            # Should return a list (might be empty if registry not loaded)
            assert isinstance(all_known, list)
        except ValueError:
            # Expected if no base_url configured, skip this test
            pytest.skip("Custom provider requires URL configuration")

    def test_openrouter_provider_alias_listing(self):
        """OpenRouter provider should expose alias-aware listings."""
        from providers.openrouter import OpenRouterProvider

        provider = OpenRouterProvider(api_key="test-key")
        all_known = provider.list_models(respect_restrictions=False, include_aliases=True, lowercase=True, unique=True)

        # Should return a list with both aliases and targets
        assert isinstance(all_known, list)


@@ -142,7 +142,7 @@ class TestModelRestrictionService:
"o3-mini": {"context_window": 200000}, "o3-mini": {"context_window": 200000},
"o4-mini": {"context_window": 200000}, "o4-mini": {"context_window": 200000},
} }
mock_provider.list_all_known_models.return_value = ["o3", "o3-mini", "o4-mini"] mock_provider.list_models.return_value = ["o3", "o3-mini", "o4-mini"]
provider_instances = {ProviderType.OPENAI: mock_provider} provider_instances = {ProviderType.OPENAI: mock_provider}
service.validate_against_known_models(provider_instances) service.validate_against_known_models(provider_instances)
@@ -447,7 +447,13 @@ class TestRegistryIntegration:
        }
        mock_openai.get_provider_type.return_value = ProviderType.OPENAI

        def openai_list_models(
            *,
            respect_restrictions: bool = True,
            include_aliases: bool = True,
            lowercase: bool = False,
            unique: bool = False,
        ):
            from utils.model_restrictions import get_restriction_service

            restriction_service = get_restriction_service() if respect_restrictions else None
@@ -457,15 +463,26 @@ class TestRegistryIntegration:
                    target_model = config
                    if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, target_model):
                        continue
                    if include_aliases:
                        models.append(model_name)
                else:
                    if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, model_name):
                        continue
                    models.append(model_name)
            if lowercase:
                models = [m.lower() for m in models]
            if unique:
                seen = set()
                ordered = []
                for name in models:
                    if name in seen:
                        continue
                    seen.add(name)
                    ordered.append(name)
                models = ordered
            return models

        mock_openai.list_models = MagicMock(side_effect=openai_list_models)
        mock_gemini = MagicMock()
        mock_gemini.MODEL_CAPABILITIES = {
@@ -474,7 +491,13 @@ class TestRegistryIntegration:
        }
        mock_gemini.get_provider_type.return_value = ProviderType.GOOGLE

        def gemini_list_models(
            *,
            respect_restrictions: bool = True,
            include_aliases: bool = True,
            lowercase: bool = False,
            unique: bool = False,
        ):
            from utils.model_restrictions import get_restriction_service

            restriction_service = get_restriction_service() if respect_restrictions else None
@@ -484,18 +507,26 @@ class TestRegistryIntegration:
                    target_model = config
                    if restriction_service and not restriction_service.is_allowed(ProviderType.GOOGLE, target_model):
                        continue
                    if include_aliases:
                        models.append(model_name)
                else:
                    if restriction_service and not restriction_service.is_allowed(ProviderType.GOOGLE, model_name):
                        continue
                    models.append(model_name)
            if lowercase:
                models = [m.lower() for m in models]
            if unique:
                seen = set()
                ordered = []
                for name in models:
                    if name in seen:
                        continue
                    seen.add(name)
                    ordered.append(name)
                models = ordered
            return models

        mock_gemini.list_models = MagicMock(side_effect=gemini_list_models)
        def get_provider_side_effect(provider_type):
            if provider_type == ProviderType.OPENAI:
@@ -615,7 +646,13 @@ class TestAutoModeWithRestrictions:
        }
        mock_openai.get_provider_type.return_value = ProviderType.OPENAI

        def openai_list_models(
            *,
            respect_restrictions: bool = True,
            include_aliases: bool = True,
            lowercase: bool = False,
            unique: bool = False,
        ):
            from utils.model_restrictions import get_restriction_service

            restriction_service = get_restriction_service() if respect_restrictions else None
@@ -625,15 +662,26 @@ class TestAutoModeWithRestrictions:
                    target_model = config
                    if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, target_model):
                        continue
                    if include_aliases:
                        models.append(model_name)
                else:
                    if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, model_name):
                        continue
                    models.append(model_name)
            if lowercase:
                models = [m.lower() for m in models]
            if unique:
                seen = set()
                ordered = []
                for name in models:
                    if name in seen:
                        continue
                    seen.add(name)
                    ordered.append(name)
                models = ordered
            return models

        mock_openai.list_models = MagicMock(side_effect=openai_list_models)

        # Add get_preferred_model method to mock to match new implementation
        def get_preferred_model(category, allowed_models):


@@ -1,216 +0,0 @@
"""
Tests that simulate the OLD BROKEN BEHAVIOR to prove it was indeed broken.
These tests create mock providers that behave like the old code (before our fix)
and demonstrate that they would have failed to catch the HIGH-severity bug.
IMPORTANT: These tests show what WOULD HAVE HAPPENED with the old code.
They prove that our fix was necessary and actually addresses real problems.
"""
from unittest.mock import MagicMock, patch
from providers.shared import ProviderType
from utils.model_restrictions import ModelRestrictionService
class TestOldBehaviorSimulation:
"""
Simulate the old broken behavior to prove it was buggy.
"""
def test_old_behavior_would_miss_target_restrictions(self):
"""
SIMULATION: This test recreates the OLD BROKEN BEHAVIOR and proves it was buggy.
OLD BUG: When validation service called provider.list_models(), it only got
aliases back, not targets. This meant target-based restrictions weren't validated.
"""
# Create a mock provider that simulates the OLD BROKEN BEHAVIOR
old_broken_provider = MagicMock()
old_broken_provider.MODEL_CAPABILITIES = {
"mini": "o4-mini", # alias -> target
"o3mini": "o3-mini", # alias -> target
"o4-mini": {"context_window": 200000},
"o3-mini": {"context_window": 200000},
}
# OLD BROKEN: list_models only returned aliases, missing targets
old_broken_provider.list_models.return_value = ["mini", "o3mini"]
# OLD BROKEN: There was no list_all_known_models method!
# We simulate this by making it behave like the old list_models
old_broken_provider.list_all_known_models.return_value = ["mini", "o3mini"] # MISSING TARGETS!
# Now test what happens when admin tries to restrict by target model
service = ModelRestrictionService()
service.restrictions = {ProviderType.OPENAI: {"o4-mini"}} # Restrict to target model
with patch("utils.model_restrictions.logger") as mock_logger:
provider_instances = {ProviderType.OPENAI: old_broken_provider}
service.validate_against_known_models(provider_instances)
# OLD BROKEN BEHAVIOR: Would warn about o4-mini being "not recognized"
# because it wasn't in the list_all_known_models response
target_warnings = [
call
for call in mock_logger.warning.call_args_list
if "o4-mini" in str(call) and "not a recognized" in str(call)
]
# This proves the old behavior was broken - it would generate false warnings
assert len(target_warnings) > 0, "OLD BROKEN BEHAVIOR: Would incorrectly warn about valid target models"
# Verify the warning message shows the broken list
warning_text = str(target_warnings[0])
assert "mini" in warning_text # Alias was included
assert "o3mini" in warning_text # Alias was included
# But targets were missing from the known models list in old behavior
def test_new_behavior_fixes_the_problem(self):
"""
Compare old vs new behavior to show our fix works.
"""
# Create mock provider with NEW FIXED BEHAVIOR
new_fixed_provider = MagicMock()
new_fixed_provider.MODEL_CAPABILITIES = {
"mini": "o4-mini",
"o3mini": "o3-mini",
"o4-mini": {"context_window": 200000},
"o3-mini": {"context_window": 200000},
}
# NEW FIXED: list_all_known_models includes BOTH aliases AND targets
new_fixed_provider.list_all_known_models.return_value = [
"mini",
"o3mini", # aliases
"o4-mini",
"o3-mini", # targets - THESE WERE MISSING IN OLD CODE!
]
# Same restriction scenario
service = ModelRestrictionService()
service.restrictions = {ProviderType.OPENAI: {"o4-mini"}} # Restrict to target model
with patch("utils.model_restrictions.logger") as mock_logger:
provider_instances = {ProviderType.OPENAI: new_fixed_provider}
service.validate_against_known_models(provider_instances)
# NEW FIXED BEHAVIOR: No warnings about o4-mini being unrecognized
target_warnings = [
call
for call in mock_logger.warning.call_args_list
if "o4-mini" in str(call) and "not a recognized" in str(call)
]
# Our fix prevents false warnings
assert len(target_warnings) == 0, "NEW FIXED BEHAVIOR: Should not warn about valid target models"
def test_policy_bypass_prevention_old_vs_new(self):
"""
Show how the old behavior could have led to policy bypass scenarios.
"""
# OLD BROKEN: Admin thinks they've restricted access to o4-mini,
# but validation doesn't recognize it as a valid restriction target
old_broken_provider = MagicMock()
old_broken_provider.list_all_known_models.return_value = ["mini", "o3mini"] # Missing targets
# NEW FIXED: Same provider with our fix
new_fixed_provider = MagicMock()
new_fixed_provider.list_all_known_models.return_value = ["mini", "o3mini", "o4-mini", "o3-mini"]
# Test restriction on target model - use completely separate service instances
old_service = ModelRestrictionService()
old_service.restrictions = {ProviderType.OPENAI: {"o4-mini", "completely-invalid-model"}}
new_service = ModelRestrictionService()
new_service.restrictions = {ProviderType.OPENAI: {"o4-mini", "completely-invalid-model"}}
# OLD BEHAVIOR: Would warn about BOTH models being unrecognized
with patch("utils.model_restrictions.logger") as mock_logger_old:
provider_instances = {ProviderType.OPENAI: old_broken_provider}
old_service.validate_against_known_models(provider_instances)
old_warnings = [str(call) for call in mock_logger_old.warning.call_args_list]
print(f"OLD warnings: {old_warnings}") # Debug output
# NEW BEHAVIOR: Only warns about truly invalid model
with patch("utils.model_restrictions.logger") as mock_logger_new:
provider_instances = {ProviderType.OPENAI: new_fixed_provider}
new_service.validate_against_known_models(provider_instances)
new_warnings = [str(call) for call in mock_logger_new.warning.call_args_list]
print(f"NEW warnings: {new_warnings}") # Debug output
# For now, just verify that we get some warnings in both cases
# The key point is that the "Known models" list is different
assert len(old_warnings) > 0, "OLD: Should have warnings"
assert len(new_warnings) > 0, "NEW: Should have warnings for invalid model"
# Verify the known models list is different between old and new
str(old_warnings[0]) if old_warnings else ""
new_warning_text = str(new_warnings[0]) if new_warnings else ""
if "Known models:" in new_warning_text:
# NEW behavior should include o4-mini in known models list
assert "o4-mini" in new_warning_text, "NEW: Should include o4-mini in known models"
print("This test demonstrates that our fix improves the 'Known models' list shown to users.")
def test_demonstrate_target_coverage_improvement(self):
"""
Show the exact improvement in target model coverage.
"""
# Simulate different provider implementations
providers_old_vs_new = [
# (old_broken_list, new_fixed_list, provider_name)
(["mini", "o3mini"], ["mini", "o3mini", "o4-mini", "o3-mini"], "OpenAI"),
(
["flash", "pro"],
["flash", "pro", "gemini-2.5-flash", "gemini-2.5-pro"],
"Gemini",
),
]
for old_list, new_list, provider_name in providers_old_vs_new:
# Count how many additional models are now covered
old_coverage = set(old_list)
new_coverage = set(new_list)
additional_coverage = new_coverage - old_coverage
# There should be additional target models covered
assert len(additional_coverage) > 0, f"{provider_name}: Should have additional target coverage"
# All old models should still be covered
assert old_coverage.issubset(new_coverage), f"{provider_name}: Should maintain backward compatibility"
print(f"{provider_name} provider:")
print(f" Old coverage: {sorted(old_coverage)}")
print(f" New coverage: {sorted(new_coverage)}")
print(f" Additional models: {sorted(additional_coverage)}")
def test_comprehensive_alias_target_mapping_verification(self):
"""
Verify that our fix provides comprehensive alias->target coverage.
"""
from providers.gemini import GeminiModelProvider
from providers.openai_provider import OpenAIModelProvider
# Test real providers to ensure they implement our fix correctly
providers = [OpenAIModelProvider(api_key="test-key"), GeminiModelProvider(api_key="test-key")]
for provider in providers:
all_known = provider.list_all_known_models()
# Check that every model and its aliases appear in the comprehensive list
for model_name, config in provider.MODEL_CAPABILITIES.items():
assert model_name.lower() in all_known, f"{provider.__class__.__name__}: Missing model {model_name}"
for alias in getattr(config, "aliases", []):
assert (
alias.lower() in all_known
), f"{provider.__class__.__name__}: Missing alias {alias} for model {model_name}"
assert (
provider._resolve_model_name(alias) == model_name
), f"{provider.__class__.__name__}: Alias {alias} should resolve to {model_name}"


@@ -26,10 +26,7 @@ class TestOpenAICompatibleTokenUsage(unittest.TestCase):
            def validate_model_name(self, model_name):
                return True

            def list_models(self, **kwargs):
                return ["test-model"]

        self.provider = TestProvider("test-key")


@@ -151,7 +151,7 @@ class TestOpenRouterAutoMode:
os.environ["DEFAULT_MODEL"] = "auto" os.environ["DEFAULT_MODEL"] = "auto"
mock_registry = Mock() mock_registry = Mock()
mock_registry.list_models.return_value = [ model_names = [
"google/gemini-2.5-flash", "google/gemini-2.5-flash",
"google/gemini-2.5-pro", "google/gemini-2.5-pro",
"openai/o3", "openai/o3",
@@ -159,6 +159,18 @@ class TestOpenRouterAutoMode:
"anthropic/claude-opus-4.1", "anthropic/claude-opus-4.1",
"anthropic/claude-sonnet-4.1", "anthropic/claude-sonnet-4.1",
] ]
mock_registry.list_models.return_value = model_names
# Mock resolve to return a ModelCapabilities-like object for each model
def mock_resolve(model_name):
if model_name in model_names:
mock_config = Mock()
mock_config.is_custom = False
mock_config.aliases = [] # Empty list of aliases
return mock_config
return None
mock_registry.resolve.side_effect = mock_resolve
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider) ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
@@ -171,8 +183,7 @@ class TestOpenRouterAutoMode:
        assert len(available_models) > 0, "Should find OpenRouter models in auto mode"
        assert all(provider_type == ProviderType.OPENROUTER for provider_type in available_models.values())

        for model in model_names:
            assert model in available_models, f"Model {model} should be available"
    @pytest.mark.no_mock_provider


@@ -151,11 +151,16 @@ class TestSupportedModelsAliases:
assert "o3-2025-04-16" in dial_models assert "o3-2025-04-16" in dial_models
assert "o3" in dial_models assert "o3" in dial_models
def test_list_all_known_models_includes_aliases(self): def test_list_models_all_known_variant_includes_aliases(self):
"""Test that list_all_known_models returns all models and aliases in lowercase.""" """Unified list_models should support lowercase, alias-inclusive listings."""
# Test Gemini # Test Gemini
gemini_provider = GeminiModelProvider("test-key") gemini_provider = GeminiModelProvider("test-key")
gemini_all = gemini_provider.list_all_known_models() gemini_all = gemini_provider.list_models(
respect_restrictions=False,
include_aliases=True,
lowercase=True,
unique=True,
)
assert "gemini-2.5-flash" in gemini_all assert "gemini-2.5-flash" in gemini_all
assert "flash" in gemini_all assert "flash" in gemini_all
assert "gemini-2.5-pro" in gemini_all assert "gemini-2.5-pro" in gemini_all
@@ -165,7 +170,12 @@ class TestSupportedModelsAliases:
        # Test OpenAI
        openai_provider = OpenAIModelProvider("test-key")
        openai_all = openai_provider.list_models(
            respect_restrictions=False,
            include_aliases=True,
            lowercase=True,
            unique=True,
        )
        assert "o4-mini" in openai_all
        assert "mini" in openai_all
        assert "o3-mini" in openai_all


@@ -30,13 +30,20 @@ logger = logging.getLogger(__name__)
class ModelRestrictionService:
    """Central authority for environment-driven model allowlists.

    Role
        Interpret ``*_ALLOWED_MODELS`` environment variables, keep their
        entries normalised (lowercase), and answer whether a provider/model
        pairing is permitted.

    Responsibilities
        * Parse, cache, and expose per-provider restriction sets
        * Validate configuration by cross-checking each entry against the
          provider's alias-aware model list
        * Offer helper methods such as ``is_allowed`` and ``filter_models`` to
          enforce policy everywhere model names appear (tool selection, CLI
          commands, etc.).
    """

    # Environment variable names
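A hedged sketch of the service in use. Restrictions are assumed to load from the environment at construction, as the old docstring described; ``is_allowed`` is called exactly this way from the providers:

    import os

    os.environ["OPENAI_ALLOWED_MODELS"] = "o4-mini,o3-mini"

    service = ModelRestrictionService()
    assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
    assert not service.is_allowed(ProviderType.OPENAI, "o3-pro")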
@@ -94,9 +101,14 @@ class ModelRestrictionService:
        # Get all supported models using the clean polymorphic interface
        try:
            # Gather canonical models and aliases with consistent formatting
            all_models = provider.list_models(
                respect_restrictions=False,
                include_aliases=True,
                lowercase=True,
                unique=True,
            )
            supported_models = set(all_models)
        except Exception as e:
            logger.debug(f"Could not get model list from {provider_type.value} provider: {e}")
            supported_models = set()