refactor: moved temperature method from base provider to model capabilities
refactor: model listing cleanup, moved logic to model_capabilities.py docs: added AGENTS.md for onboarding Codex
This commit is contained in:
@@ -18,13 +18,26 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class ModelProvider(ABC):
|
class ModelProvider(ABC):
|
||||||
"""Defines the contract implemented by every model provider backend.
|
"""Abstract base class for all model backends in the MCP server.
|
||||||
|
|
||||||
Subclasses adapt third-party SDKs into the MCP server by exposing
|
Role
|
||||||
capability metadata, request execution, and token counting through a
|
Defines the interface every provider must implement so the registry,
|
||||||
consistent interface. Shared helper methods (temperature validation,
|
restriction service, and tools have a uniform surface for listing
|
||||||
alias resolution, image handling, etc.) live here so individual providers
|
models, resolving aliases, and executing requests.
|
||||||
only need to focus on provider-specific details.
|
|
||||||
|
Responsibilities
|
||||||
|
* expose static capability metadata for each supported model via
|
||||||
|
:class:`ModelCapabilities`
|
||||||
|
* accept user prompts, forward them to the underlying SDK, and wrap
|
||||||
|
responses in :class:`ModelResponse`
|
||||||
|
* report tokenizer counts for budgeting and validation logic
|
||||||
|
* advertise provider identity (``ProviderType``) so restriction
|
||||||
|
policies can map environment configuration onto providers
|
||||||
|
* validate whether a model name or alias is recognised by the provider
|
||||||
|
|
||||||
|
Shared helpers like temperature validation, alias resolution, and
|
||||||
|
restriction-aware ``list_models`` live here so concrete subclasses only
|
||||||
|
need to supply their catalogue and wire up SDK-specific behaviour.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# All concrete providers must define their supported models
|
# All concrete providers must define their supported models
|
||||||
@@ -151,67 +164,52 @@ class ModelProvider(ABC):
|
|||||||
# If not found, return as-is
|
# If not found, return as-is
|
||||||
return model_name
|
return model_name
|
||||||
|
|
||||||
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
def list_models(
|
||||||
"""Return a list of model names supported by this provider.
|
self,
|
||||||
|
*,
|
||||||
This implementation uses the get_model_configurations() hook
|
respect_restrictions: bool = True,
|
||||||
to support different model configuration sources.
|
include_aliases: bool = True,
|
||||||
|
lowercase: bool = False,
|
||||||
|
unique: bool = False,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Return formatted model names supported by this provider.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
respect_restrictions: Whether to apply provider-specific restriction logic.
|
respect_restrictions: Apply provider restriction policy.
|
||||||
|
include_aliases: Include aliases alongside canonical model names.
|
||||||
|
lowercase: Normalize returned names to lowercase.
|
||||||
|
unique: Deduplicate names after formatting.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of model names available from this provider
|
List of model names formatted according to the provided options.
|
||||||
"""
|
"""
|
||||||
from utils.model_restrictions import get_restriction_service
|
|
||||||
|
|
||||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
|
||||||
models = []
|
|
||||||
|
|
||||||
# Get model configurations from the hook method
|
|
||||||
model_configs = self.get_model_configurations()
|
model_configs = self.get_model_configurations()
|
||||||
|
if not model_configs:
|
||||||
|
return []
|
||||||
|
|
||||||
for model_name in model_configs:
|
restriction_service = None
|
||||||
# Check restrictions if enabled
|
if respect_restrictions:
|
||||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
|
from utils.model_restrictions import get_restriction_service
|
||||||
continue
|
|
||||||
|
|
||||||
# Add the base model
|
restriction_service = get_restriction_service()
|
||||||
models.append(model_name)
|
|
||||||
|
|
||||||
# Add aliases derived from the model configurations
|
if restriction_service:
|
||||||
alias_map = ModelCapabilities.collect_aliases(model_configs)
|
allowed_configs = {}
|
||||||
for model_name, aliases in alias_map.items():
|
for model_name, config in model_configs.items():
|
||||||
# Only add aliases for models that passed restriction check
|
if restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||||
if model_name in models:
|
allowed_configs[model_name] = config
|
||||||
models.extend(aliases)
|
model_configs = allowed_configs
|
||||||
|
|
||||||
return models
|
if not model_configs:
|
||||||
|
return []
|
||||||
|
|
||||||
def list_all_known_models(self) -> list[str]:
|
return ModelCapabilities.collect_model_names(
|
||||||
"""Return all model names known by this provider, including alias targets.
|
model_configs,
|
||||||
|
include_aliases=include_aliases,
|
||||||
This is used for validation purposes to ensure restriction policies
|
lowercase=lowercase,
|
||||||
can validate against both aliases and their target model names.
|
unique=unique,
|
||||||
|
)
|
||||||
Returns:
|
|
||||||
List of all model names and alias targets known by this provider
|
|
||||||
"""
|
|
||||||
all_models = set()
|
|
||||||
|
|
||||||
# Get model configurations from the hook method
|
|
||||||
model_configs = self.get_model_configurations()
|
|
||||||
|
|
||||||
# Add all base model names
|
|
||||||
for model_name in model_configs:
|
|
||||||
all_models.add(model_name.lower())
|
|
||||||
|
|
||||||
# Add aliases derived from the model configurations
|
|
||||||
for aliases in ModelCapabilities.collect_aliases(model_configs).values():
|
|
||||||
for alias in aliases:
|
|
||||||
all_models.add(alias.lower())
|
|
||||||
|
|
||||||
return list(all_models)
|
|
||||||
|
|
||||||
def validate_image(self, image_path: str, max_size_mb: float = None) -> tuple[bytes, str]:
|
def validate_image(self, image_path: str, max_size_mb: float = None) -> tuple[bytes, str]:
|
||||||
"""Provider-independent image validation.
|
"""Provider-independent image validation.
|
||||||
|
|||||||
@@ -32,11 +32,20 @@ _TEMP_UNSUPPORTED_KEYWORDS = [
|
|||||||
class CustomProvider(OpenAICompatibleProvider):
|
class CustomProvider(OpenAICompatibleProvider):
|
||||||
"""Adapter for self-hosted or local OpenAI-compatible endpoints.
|
"""Adapter for self-hosted or local OpenAI-compatible endpoints.
|
||||||
|
|
||||||
The provider reuses the :mod:`providers.shared` registry to surface
|
Role
|
||||||
user-defined aliases and capability metadata. It also normalises
|
Provide a uniform bridge between the MCP server and user-managed
|
||||||
Ollama-style version tags (``model:latest``) and enforces the same
|
OpenAI-compatible services (Ollama, vLLM, LM Studio, bespoke gateways).
|
||||||
restriction policies used by cloud providers, ensuring consistent
|
By subclassing :class:`OpenAICompatibleProvider` it inherits request and
|
||||||
behaviour regardless of where the model is hosted.
|
token handling, while the custom registry exposes locally defined model
|
||||||
|
metadata.
|
||||||
|
|
||||||
|
Notable behaviour
|
||||||
|
* Uses :class:`OpenRouterModelRegistry` to load model definitions and
|
||||||
|
aliases so custom deployments share the same metadata pipeline as
|
||||||
|
OpenRouter itself.
|
||||||
|
* Normalises version-tagged model names (``model:latest``) and applies
|
||||||
|
restriction policies just like cloud providers, ensuring consistent
|
||||||
|
behaviour across environments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
FRIENDLY_NAME = "Custom API"
|
FRIENDLY_NAME = "Custom API"
|
||||||
|
|||||||
@@ -17,9 +17,19 @@ from .shared import (
|
|||||||
class OpenRouterProvider(OpenAICompatibleProvider):
|
class OpenRouterProvider(OpenAICompatibleProvider):
|
||||||
"""Client for OpenRouter's multi-model aggregation service.
|
"""Client for OpenRouter's multi-model aggregation service.
|
||||||
|
|
||||||
OpenRouter surfaces dozens of upstream vendors. This provider layers alias
|
Role
|
||||||
resolution, restriction-aware filtering, and sensible capability defaults
|
Surface OpenRouter’s dynamic catalogue through the same interface as
|
||||||
on top of the generic OpenAI-compatible plumbing.
|
native providers so tools can reference OpenRouter models and aliases
|
||||||
|
without special cases.
|
||||||
|
|
||||||
|
Characteristics
|
||||||
|
* Pulls live model definitions from :class:`OpenRouterModelRegistry`
|
||||||
|
(aliases, provider-specific metadata, capability hints)
|
||||||
|
* Applies alias-aware restriction checks before exposing models to the
|
||||||
|
registry or tooling
|
||||||
|
* Reuses :class:`OpenAICompatibleProvider` infrastructure for request
|
||||||
|
execution so OpenRouter endpoints behave like standard OpenAI-style
|
||||||
|
APIs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
FRIENDLY_NAME = "OpenRouter"
|
FRIENDLY_NAME = "OpenRouter"
|
||||||
@@ -208,75 +218,56 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
|||||||
"""
|
"""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
def list_models(
|
||||||
"""Return a list of model names supported by this provider.
|
self,
|
||||||
|
*,
|
||||||
|
respect_restrictions: bool = True,
|
||||||
|
include_aliases: bool = True,
|
||||||
|
lowercase: bool = False,
|
||||||
|
unique: bool = False,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Return formatted OpenRouter model names, respecting alias-aware restrictions."""
|
||||||
|
|
||||||
Args:
|
if not self._registry:
|
||||||
respect_restrictions: Whether to apply provider-specific restriction logic.
|
return []
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of model names available from this provider
|
|
||||||
"""
|
|
||||||
from utils.model_restrictions import get_restriction_service
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||||
models = []
|
allowed_configs: dict[str, ModelCapabilities] = {}
|
||||||
|
|
||||||
if self._registry:
|
for model_name in self._registry.list_models():
|
||||||
for model_name in self._registry.list_models():
|
config = self._registry.resolve(model_name)
|
||||||
# =====================================================================================
|
if not config:
|
||||||
# CRITICAL ALIAS-AWARE RESTRICTION CHECKING (Fixed Issue #98)
|
continue
|
||||||
# =====================================================================================
|
|
||||||
# Previously, restrictions only checked full model names (e.g., "google/gemini-2.5-pro")
|
|
||||||
# but users specify aliases in OPENROUTER_ALLOWED_MODELS (e.g., "pro").
|
|
||||||
# This caused "no models available" error even with valid restrictions.
|
|
||||||
#
|
|
||||||
# Fix: Check both model name AND all aliases against restrictions
|
|
||||||
# TEST COVERAGE: tests/test_provider_routing_bugs.py::TestOpenRouterAliasRestrictions
|
|
||||||
# =====================================================================================
|
|
||||||
if restriction_service:
|
|
||||||
# Get model config to check aliases as well
|
|
||||||
model_config = self._registry.resolve(model_name)
|
|
||||||
allowed = False
|
|
||||||
|
|
||||||
# Check if model name itself is allowed
|
if restriction_service:
|
||||||
if restriction_service.is_allowed(self.get_provider_type(), model_name):
|
allowed = restriction_service.is_allowed(self.get_provider_type(), model_name)
|
||||||
allowed = True
|
|
||||||
|
|
||||||
# CRITICAL: Also check aliases - this fixes the alias restriction bug
|
if not allowed and config.aliases:
|
||||||
if not allowed and model_config and model_config.aliases:
|
for alias in config.aliases:
|
||||||
for alias in model_config.aliases:
|
if restriction_service.is_allowed(self.get_provider_type(), alias):
|
||||||
if restriction_service.is_allowed(self.get_provider_type(), alias):
|
allowed = True
|
||||||
allowed = True
|
break
|
||||||
break
|
|
||||||
|
|
||||||
if not allowed:
|
if not allowed:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
models.append(model_name)
|
allowed_configs[model_name] = config
|
||||||
|
|
||||||
return models
|
if not allowed_configs:
|
||||||
|
return []
|
||||||
|
|
||||||
def list_all_known_models(self) -> list[str]:
|
# When restrictions are in place, don't include aliases to avoid confusion
|
||||||
"""Return all model names known by this provider, including alias targets.
|
# Only return the canonical model names that are actually allowed
|
||||||
|
actual_include_aliases = include_aliases and not respect_restrictions
|
||||||
|
|
||||||
Returns:
|
return ModelCapabilities.collect_model_names(
|
||||||
List of all model names and alias targets known by this provider
|
allowed_configs,
|
||||||
"""
|
include_aliases=actual_include_aliases,
|
||||||
all_models = set()
|
lowercase=lowercase,
|
||||||
|
unique=unique,
|
||||||
if self._registry:
|
)
|
||||||
# Get all models and aliases from the registry
|
|
||||||
all_models.update(model.lower() for model in self._registry.list_models())
|
|
||||||
all_models.update(alias.lower() for alias in self._registry.list_aliases())
|
|
||||||
|
|
||||||
# For each alias, also add its target
|
|
||||||
for alias in self._registry.list_aliases():
|
|
||||||
config = self._registry.resolve(alias)
|
|
||||||
if config:
|
|
||||||
all_models.add(config.model_name.lower())
|
|
||||||
|
|
||||||
return list(all_models)
|
|
||||||
|
|
||||||
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
|
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
|
||||||
"""Get model configurations from the registry.
|
"""Get model configurations from the registry.
|
||||||
|
|||||||
@@ -17,12 +17,21 @@ from .shared import (
|
|||||||
|
|
||||||
|
|
||||||
class OpenRouterModelRegistry:
|
class OpenRouterModelRegistry:
|
||||||
"""Loads and validates the OpenRouter/custom model catalogue.
|
"""In-memory view of OpenRouter and custom model metadata.
|
||||||
|
|
||||||
The registry parses ``conf/custom_models.json`` (or an override supplied via
|
Role
|
||||||
environment variable), builds case-insensitive alias maps, and exposes
|
Parse the packaged ``conf/custom_models.json`` (or user-specified
|
||||||
:class:`~providers.shared.ModelCapabilities` objects used by several
|
overrides), construct alias and capability maps, and serve those
|
||||||
providers.
|
structures to providers that rely on OpenRouter semantics (both the
|
||||||
|
OpenRouter provider itself and the Custom provider).
|
||||||
|
|
||||||
|
Key duties
|
||||||
|
* Load :class:`ModelCapabilities` definitions from configuration files
|
||||||
|
* Maintain a case-insensitive alias → canonical name map for fast
|
||||||
|
resolution
|
||||||
|
* Provide helpers to list models, list aliases, and resolve an arbitrary
|
||||||
|
name to its capability object without repeatedly touching the file
|
||||||
|
system.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config_path: Optional[str] = None):
|
def __init__(self, config_path: Optional[str] = None):
|
||||||
|
|||||||
@@ -12,11 +12,22 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class ModelProviderRegistry:
|
class ModelProviderRegistry:
|
||||||
"""Singleton that caches provider instances and coordinates priority order.
|
"""Central catalogue of provider implementations used by the MCP server.
|
||||||
|
|
||||||
Responsibilities include resolving API keys from the environment, lazily
|
Role
|
||||||
instantiating providers, and choosing the best provider for a model based
|
Holds the mapping between :class:`ProviderType` values and concrete
|
||||||
on restriction policies and provider priority.
|
:class:`ModelProvider` subclasses/factories. At runtime the registry
|
||||||
|
is responsible for instantiating providers, caching them for reuse, and
|
||||||
|
mediating lookup of providers and model names in provider priority
|
||||||
|
order.
|
||||||
|
|
||||||
|
Core responsibilities
|
||||||
|
* Resolve API keys and other runtime configuration for each provider
|
||||||
|
* Lazily create provider instances so unused backends incur no cost
|
||||||
|
* Expose convenience methods for enumerating available models and
|
||||||
|
locating which provider can service a requested model name or alias
|
||||||
|
* Honour the project-wide provider priority policy so namespaces (or
|
||||||
|
alias collisions) are resolved deterministically.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_instance = None
|
_instance = None
|
||||||
|
|||||||
@@ -11,24 +11,46 @@ __all__ = ["ModelCapabilities"]
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelCapabilities:
|
class ModelCapabilities:
|
||||||
"""Static capabilities and constraints for a provider-managed model."""
|
"""Static description of what a model can do within a provider.
|
||||||
|
|
||||||
|
Role
|
||||||
|
Acts as the canonical record for everything the server needs to know
|
||||||
|
about a model—its provider, token limits, feature switches, aliases,
|
||||||
|
and temperature rules. Providers populate these objects so tools and
|
||||||
|
higher-level services can rely on a consistent schema.
|
||||||
|
|
||||||
|
Typical usage
|
||||||
|
* Provider subclasses declare `MODEL_CAPABILITIES` maps containing these
|
||||||
|
objects (for example ``OpenAIModelProvider``)
|
||||||
|
* Helper utilities (e.g. restriction validation, alias expansion) read
|
||||||
|
these objects to build model lists for tooling and policy enforcement
|
||||||
|
* Tool selection logic inspects attributes such as
|
||||||
|
``supports_extended_thinking`` or ``context_window`` to choose an
|
||||||
|
appropriate model for a task.
|
||||||
|
"""
|
||||||
|
|
||||||
provider: ProviderType
|
provider: ProviderType
|
||||||
model_name: str
|
model_name: str
|
||||||
friendly_name: str
|
friendly_name: str
|
||||||
context_window: int
|
description: str = ""
|
||||||
max_output_tokens: int
|
aliases: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
# Capacity limits / resource budgets
|
||||||
|
context_window: int = 0
|
||||||
|
max_output_tokens: int = 0
|
||||||
|
max_thinking_tokens: int = 0
|
||||||
|
|
||||||
|
# Capability flags
|
||||||
supports_extended_thinking: bool = False
|
supports_extended_thinking: bool = False
|
||||||
supports_system_prompts: bool = True
|
supports_system_prompts: bool = True
|
||||||
supports_streaming: bool = True
|
supports_streaming: bool = True
|
||||||
supports_function_calling: bool = False
|
supports_function_calling: bool = False
|
||||||
supports_images: bool = False
|
supports_images: bool = False
|
||||||
max_image_size_mb: float = 0.0
|
|
||||||
supports_temperature: bool = True
|
|
||||||
description: str = ""
|
|
||||||
aliases: list[str] = field(default_factory=list)
|
|
||||||
supports_json_mode: bool = False
|
supports_json_mode: bool = False
|
||||||
max_thinking_tokens: int = 0
|
supports_temperature: bool = True
|
||||||
|
|
||||||
|
# Additional attributes
|
||||||
|
max_image_size_mb: float = 0.0
|
||||||
is_custom: bool = False
|
is_custom: bool = False
|
||||||
temperature_constraint: TemperatureConstraint = field(
|
temperature_constraint: TemperatureConstraint = field(
|
||||||
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
|
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
|
||||||
@@ -56,3 +78,45 @@ class ModelCapabilities:
|
|||||||
for base_model, capabilities in model_configs.items()
|
for base_model, capabilities in model_configs.items()
|
||||||
if capabilities.aliases
|
if capabilities.aliases
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def collect_model_names(
|
||||||
|
model_configs: dict[str, "ModelCapabilities"],
|
||||||
|
*,
|
||||||
|
include_aliases: bool = True,
|
||||||
|
lowercase: bool = False,
|
||||||
|
unique: bool = False,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Build an ordered list of model names and aliases.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_configs: Mapping of canonical model names to capabilities.
|
||||||
|
include_aliases: When True, include aliases for each model.
|
||||||
|
lowercase: When True, normalize names to lowercase.
|
||||||
|
unique: When True, ensure each returned name appears once (after formatting).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Ordered list of model names (and optionally aliases) formatted per options.
|
||||||
|
"""
|
||||||
|
|
||||||
|
formatted_names: list[str] = []
|
||||||
|
seen: set[str] | None = set() if unique else None
|
||||||
|
|
||||||
|
def append_name(name: str) -> None:
|
||||||
|
formatted = name.lower() if lowercase else name
|
||||||
|
|
||||||
|
if seen is not None:
|
||||||
|
if formatted in seen:
|
||||||
|
return
|
||||||
|
seen.add(formatted)
|
||||||
|
|
||||||
|
formatted_names.append(formatted)
|
||||||
|
|
||||||
|
for base_model, capabilities in model_configs.items():
|
||||||
|
append_name(base_model)
|
||||||
|
|
||||||
|
if include_aliases and capabilities.aliases:
|
||||||
|
for alias in capabilities.aliases:
|
||||||
|
append_name(alias)
|
||||||
|
|
||||||
|
return formatted_names
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ class TestAliasTargetRestrictions:
|
|||||||
provider = OpenAIModelProvider(api_key="test-key")
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
# Get all known models including aliases and targets
|
# Get all known models including aliases and targets
|
||||||
all_known = provider.list_all_known_models()
|
all_known = provider.list_models(respect_restrictions=False, include_aliases=True, lowercase=True, unique=True)
|
||||||
|
|
||||||
# Should include both aliases and their targets
|
# Should include both aliases and their targets
|
||||||
assert "mini" in all_known # alias
|
assert "mini" in all_known # alias
|
||||||
@@ -35,7 +35,7 @@ class TestAliasTargetRestrictions:
|
|||||||
provider = GeminiModelProvider(api_key="test-key")
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
|
||||||
# Get all known models including aliases and targets
|
# Get all known models including aliases and targets
|
||||||
all_known = provider.list_all_known_models()
|
all_known = provider.list_models(respect_restrictions=False, include_aliases=True, lowercase=True, unique=True)
|
||||||
|
|
||||||
# Should include both aliases and their targets
|
# Should include both aliases and their targets
|
||||||
assert "flash" in all_known # alias
|
assert "flash" in all_known # alias
|
||||||
@@ -162,7 +162,9 @@ class TestAliasTargetRestrictions:
|
|||||||
"""
|
"""
|
||||||
# Test OpenAI provider
|
# Test OpenAI provider
|
||||||
openai_provider = OpenAIModelProvider(api_key="test-key")
|
openai_provider = OpenAIModelProvider(api_key="test-key")
|
||||||
openai_all_known = openai_provider.list_all_known_models()
|
openai_all_known = openai_provider.list_models(
|
||||||
|
respect_restrictions=False, include_aliases=True, lowercase=True, unique=True
|
||||||
|
)
|
||||||
|
|
||||||
# Verify that for each alias, its target is also included
|
# Verify that for each alias, its target is also included
|
||||||
for model_name, config in openai_provider.MODEL_CAPABILITIES.items():
|
for model_name, config in openai_provider.MODEL_CAPABILITIES.items():
|
||||||
@@ -175,7 +177,9 @@ class TestAliasTargetRestrictions:
|
|||||||
|
|
||||||
# Test Gemini provider
|
# Test Gemini provider
|
||||||
gemini_provider = GeminiModelProvider(api_key="test-key")
|
gemini_provider = GeminiModelProvider(api_key="test-key")
|
||||||
gemini_all_known = gemini_provider.list_all_known_models()
|
gemini_all_known = gemini_provider.list_models(
|
||||||
|
respect_restrictions=False, include_aliases=True, lowercase=True, unique=True
|
||||||
|
)
|
||||||
|
|
||||||
# Verify that for each alias, its target is also included
|
# Verify that for each alias, its target is also included
|
||||||
for model_name, config in gemini_provider.MODEL_CAPABILITIES.items():
|
for model_name, config in gemini_provider.MODEL_CAPABILITIES.items():
|
||||||
@@ -186,8 +190,8 @@ class TestAliasTargetRestrictions:
|
|||||||
config.lower() in gemini_all_known
|
config.lower() in gemini_all_known
|
||||||
), f"Target '{config}' for alias '{model_name}' not in known models"
|
), f"Target '{config}' for alias '{model_name}' not in known models"
|
||||||
|
|
||||||
def test_no_duplicate_models_in_list_all_known_models(self):
|
def test_no_duplicate_models_in_alias_aware_listing(self):
|
||||||
"""Test that list_all_known_models doesn't return duplicates."""
|
"""Test that alias-aware list_models variant doesn't return duplicates."""
|
||||||
# Test all providers
|
# Test all providers
|
||||||
providers = [
|
providers = [
|
||||||
OpenAIModelProvider(api_key="test-key"),
|
OpenAIModelProvider(api_key="test-key"),
|
||||||
@@ -195,7 +199,9 @@ class TestAliasTargetRestrictions:
|
|||||||
]
|
]
|
||||||
|
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
all_known = provider.list_all_known_models()
|
all_known = provider.list_models(
|
||||||
|
respect_restrictions=False, include_aliases=True, lowercase=True, unique=True
|
||||||
|
)
|
||||||
# Should not have duplicates
|
# Should not have duplicates
|
||||||
assert len(all_known) == len(set(all_known)), f"{provider.__class__.__name__} returns duplicate models"
|
assert len(all_known) == len(set(all_known)), f"{provider.__class__.__name__} returns duplicate models"
|
||||||
|
|
||||||
@@ -207,7 +213,7 @@ class TestAliasTargetRestrictions:
|
|||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
mock_provider = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_provider.list_all_known_models.return_value = ["model1", "model2", "target-model"]
|
mock_provider.list_models.return_value = ["model1", "model2", "target-model"]
|
||||||
|
|
||||||
# Set up a restriction that should trigger validation
|
# Set up a restriction that should trigger validation
|
||||||
service.restrictions = {ProviderType.OPENAI: {"invalid-model"}}
|
service.restrictions = {ProviderType.OPENAI: {"invalid-model"}}
|
||||||
@@ -218,7 +224,12 @@ class TestAliasTargetRestrictions:
|
|||||||
service.validate_against_known_models(provider_instances)
|
service.validate_against_known_models(provider_instances)
|
||||||
|
|
||||||
# Verify the polymorphic method was called
|
# Verify the polymorphic method was called
|
||||||
mock_provider.list_all_known_models.assert_called_once()
|
mock_provider.list_models.assert_called_once_with(
|
||||||
|
respect_restrictions=False,
|
||||||
|
include_aliases=True,
|
||||||
|
lowercase=True,
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
|
||||||
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini"}) # Restrict to specific model
|
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini"}) # Restrict to specific model
|
||||||
def test_complex_alias_chains_handled_correctly(self):
|
def test_complex_alias_chains_handled_correctly(self):
|
||||||
@@ -250,7 +261,7 @@ class TestAliasTargetRestrictions:
|
|||||||
- A restriction on "o4-mini" (target) would not be recognized as valid
|
- A restriction on "o4-mini" (target) would not be recognized as valid
|
||||||
|
|
||||||
After the fix:
|
After the fix:
|
||||||
- list_all_known_models() returns ["mini", "o3mini", "o4-mini", "o3-mini"] (aliases + targets)
|
- list_models(respect_restrictions=False, include_aliases=True, lowercase=True, unique=True) returns ["mini", "o3mini", "o4-mini", "o3-mini"] (aliases + targets)
|
||||||
- validate_against_known_models() checks against all names
|
- validate_against_known_models() checks against all names
|
||||||
- A restriction on "o4-mini" is recognized as valid
|
- A restriction on "o4-mini" is recognized as valid
|
||||||
"""
|
"""
|
||||||
@@ -262,7 +273,7 @@ class TestAliasTargetRestrictions:
|
|||||||
provider_instances = {ProviderType.OPENAI: provider}
|
provider_instances = {ProviderType.OPENAI: provider}
|
||||||
|
|
||||||
# Get all known models - should include BOTH aliases AND targets
|
# Get all known models - should include BOTH aliases AND targets
|
||||||
all_known = provider.list_all_known_models()
|
all_known = provider.list_models(respect_restrictions=False, include_aliases=True, lowercase=True, unique=True)
|
||||||
|
|
||||||
# Critical check: should contain both aliases and their targets
|
# Critical check: should contain both aliases and their targets
|
||||||
assert "mini" in all_known # alias
|
assert "mini" in all_known # alias
|
||||||
@@ -310,7 +321,7 @@ class TestAliasTargetRestrictions:
|
|||||||
the restriction is properly enforced and the target is recognized as a valid
|
the restriction is properly enforced and the target is recognized as a valid
|
||||||
model to restrict.
|
model to restrict.
|
||||||
|
|
||||||
The bug: If list_all_known_models() doesn't include targets, then validation
|
The bug: If list_models(respect_restrictions=False, include_aliases=True, lowercase=True, unique=True) doesn't include targets, then validation
|
||||||
would incorrectly warn that target model names are "not recognized", making
|
would incorrectly warn that target model names are "not recognized", making
|
||||||
it appear that target-based restrictions don't work.
|
it appear that target-based restrictions don't work.
|
||||||
"""
|
"""
|
||||||
@@ -325,7 +336,9 @@ class TestAliasTargetRestrictions:
|
|||||||
provider = OpenAIModelProvider(api_key="test-key")
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
# These specific target models should be recognized as valid
|
# These specific target models should be recognized as valid
|
||||||
all_known = provider.list_all_known_models()
|
all_known = provider.list_models(
|
||||||
|
respect_restrictions=False, include_aliases=True, lowercase=True, unique=True
|
||||||
|
)
|
||||||
assert "o4-mini" in all_known, "Target model o4-mini should be known"
|
assert "o4-mini" in all_known, "Target model o4-mini should be known"
|
||||||
assert "o3-mini" in all_known, "Target model o3-mini should be known"
|
assert "o3-mini" in all_known, "Target model o3-mini should be known"
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
Tests that demonstrate the OLD BUGGY BEHAVIOR is now FIXED.
|
Regression scenarios ensuring alias-aware model listings stay correct.
|
||||||
|
|
||||||
These tests verify that scenarios which would have incorrectly passed
|
Each test captures behavior that previously regressed so we can guard it
|
||||||
before our fix now behave correctly. Each test documents the specific
|
permanently. The focus is confirming aliases and their canonical targets
|
||||||
bug that was fixed and what the old vs new behavior should be.
|
remain visible to the restriction service and related validation logic.
|
||||||
|
|
||||||
IMPORTANT: These tests PASS with our fix, but would have FAILED to catch
|
|
||||||
bugs with the old code (before list_all_known_models was implemented).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@@ -21,42 +18,34 @@ from utils.model_restrictions import ModelRestrictionService
|
|||||||
|
|
||||||
|
|
||||||
class TestBuggyBehaviorPrevention:
|
class TestBuggyBehaviorPrevention:
|
||||||
"""
|
"""Regression tests for alias-aware restriction validation."""
|
||||||
These tests prove that our fix prevents the HIGH-severity regression
|
|
||||||
that was identified by the O3 precommit analysis.
|
|
||||||
|
|
||||||
OLD BUG: list_models() only returned alias keys, not targets
|
def test_alias_listing_includes_targets_for_restriction_validation(self):
|
||||||
FIX: list_all_known_models() returns both aliases AND targets
|
"""Alias-aware lists expose both aliases and canonical targets."""
|
||||||
"""
|
|
||||||
|
|
||||||
def test_old_bug_would_miss_target_restrictions(self):
|
|
||||||
"""
|
|
||||||
OLD BUG: If restriction was set on target model (e.g., 'o4-mini'),
|
|
||||||
validation would incorrectly warn it's not recognized because
|
|
||||||
list_models() only returned aliases ['mini', 'o3mini'].
|
|
||||||
|
|
||||||
NEW BEHAVIOR: list_all_known_models() includes targets, so 'o4-mini'
|
|
||||||
is recognized as valid and no warning is generated.
|
|
||||||
"""
|
|
||||||
provider = OpenAIModelProvider(api_key="test-key")
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
# This is what the old broken list_models() would return - aliases only
|
# Baseline alias-only list captured for regression documentation
|
||||||
old_broken_list = ["mini", "o3mini"] # Missing 'o4-mini', 'o3-mini' targets
|
alias_only_snapshot = ["mini", "o3mini"] # Missing 'o4-mini', 'o3-mini' targets
|
||||||
|
|
||||||
# This is what our fixed list_all_known_models() returns
|
# Canonical listing with aliases and targets
|
||||||
new_fixed_list = provider.list_all_known_models()
|
comprehensive_list = provider.list_models(
|
||||||
|
respect_restrictions=False,
|
||||||
|
include_aliases=True,
|
||||||
|
lowercase=True,
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Verify the fix: new method includes both aliases AND targets
|
# Comprehensive listing should contain aliases and their targets
|
||||||
assert "mini" in new_fixed_list # alias
|
assert "mini" in comprehensive_list
|
||||||
assert "o4-mini" in new_fixed_list # target - THIS WAS MISSING IN OLD CODE
|
assert "o4-mini" in comprehensive_list
|
||||||
assert "o3mini" in new_fixed_list # alias
|
assert "o3mini" in comprehensive_list
|
||||||
assert "o3-mini" in new_fixed_list # target - THIS WAS MISSING IN OLD CODE
|
assert "o3-mini" in comprehensive_list
|
||||||
|
|
||||||
# Prove the old behavior was broken
|
# Legacy alias-only snapshots exclude targets
|
||||||
assert "o4-mini" not in old_broken_list # Old code didn't include targets
|
assert "o4-mini" not in alias_only_snapshot
|
||||||
assert "o3-mini" not in old_broken_list # Old code didn't include targets
|
assert "o3-mini" not in alias_only_snapshot
|
||||||
|
|
||||||
# This target validation would have FAILED with old code
|
# This scenario previously failed when targets were omitted
|
||||||
service = ModelRestrictionService()
|
service = ModelRestrictionService()
|
||||||
service.restrictions = {ProviderType.OPENAI: {"o4-mini"}} # Restrict to target
|
service.restrictions = {ProviderType.OPENAI: {"o4-mini"}} # Restrict to target
|
||||||
|
|
||||||
@@ -64,24 +53,19 @@ class TestBuggyBehaviorPrevention:
|
|||||||
provider_instances = {ProviderType.OPENAI: provider}
|
provider_instances = {ProviderType.OPENAI: provider}
|
||||||
service.validate_against_known_models(provider_instances)
|
service.validate_against_known_models(provider_instances)
|
||||||
|
|
||||||
# NEW BEHAVIOR: No warnings because o4-mini is now in list_all_known_models
|
# No warnings expected because alias-aware list includes the target
|
||||||
target_warnings = [
|
target_warnings = [
|
||||||
call
|
call
|
||||||
for call in mock_logger.warning.call_args_list
|
for call in mock_logger.warning.call_args_list
|
||||||
if "o4-mini" in str(call) and "not a recognized" in str(call)
|
if "o4-mini" in str(call) and "not a recognized" in str(call)
|
||||||
]
|
]
|
||||||
assert len(target_warnings) == 0, "o4-mini should be recognized with our fix"
|
assert len(target_warnings) == 0, "o4-mini should be recognized as a valid target"
|
||||||
|
|
||||||
def test_old_bug_would_incorrectly_warn_about_valid_targets(self):
|
def test_target_models_are_recognized_during_validation(self):
|
||||||
"""
|
"""Target model restrictions should not trigger false warnings."""
|
||||||
OLD BUG: Admins setting restrictions on target models would get
|
|
||||||
false warnings that their restriction models are "not recognized".
|
|
||||||
|
|
||||||
NEW BEHAVIOR: Target models are properly recognized.
|
|
||||||
"""
|
|
||||||
# Test with Gemini provider too
|
# Test with Gemini provider too
|
||||||
provider = GeminiModelProvider(api_key="test-key")
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
all_known = provider.list_all_known_models()
|
all_known = provider.list_models(respect_restrictions=False, include_aliases=True, lowercase=True, unique=True)
|
||||||
|
|
||||||
# Verify both aliases and targets are included
|
# Verify both aliases and targets are included
|
||||||
assert "flash" in all_known # alias
|
assert "flash" in all_known # alias
|
||||||
@@ -108,13 +92,8 @@ class TestBuggyBehaviorPrevention:
|
|||||||
assert "gemini-2.5-flash" not in warning or "not a recognized" not in warning
|
assert "gemini-2.5-flash" not in warning or "not a recognized" not in warning
|
||||||
assert "gemini-2.5-pro" not in warning or "not a recognized" not in warning
|
assert "gemini-2.5-pro" not in warning or "not a recognized" not in warning
|
||||||
|
|
||||||
def test_old_bug_policy_bypass_prevention(self):
|
def test_policy_enforcement_remains_comprehensive(self):
|
||||||
"""
|
"""Policy validation must account for both aliases and targets."""
|
||||||
OLD BUG: Policy enforcement was incomplete because validation
|
|
||||||
didn't know about target models. This could allow policy bypasses.
|
|
||||||
|
|
||||||
NEW BEHAVIOR: Complete validation against all known model names.
|
|
||||||
"""
|
|
||||||
provider = OpenAIModelProvider(api_key="test-key")
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
# Simulate a scenario where admin wants to restrict specific targets
|
# Simulate a scenario where admin wants to restrict specific targets
|
||||||
@@ -138,64 +117,85 @@ class TestBuggyBehaviorPrevention:
|
|||||||
# But o4mini (the actual alias for o4-mini) should work
|
# But o4mini (the actual alias for o4-mini) should work
|
||||||
assert provider.validate_model_name("o4mini") # Resolves to o4-mini, which IS allowed
|
assert provider.validate_model_name("o4mini") # Resolves to o4-mini, which IS allowed
|
||||||
|
|
||||||
# Verify our list_all_known_models includes the restricted models
|
# Verify our alias-aware list includes the restricted models
|
||||||
all_known = provider.list_all_known_models()
|
all_known = provider.list_models(
|
||||||
|
respect_restrictions=False,
|
||||||
|
include_aliases=True,
|
||||||
|
lowercase=True,
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
assert "o3-mini" in all_known # Should be known (and allowed)
|
assert "o3-mini" in all_known # Should be known (and allowed)
|
||||||
assert "o4-mini" in all_known # Should be known (and allowed)
|
assert "o4-mini" in all_known # Should be known (and allowed)
|
||||||
assert "o3-pro" in all_known # Should be known (but blocked)
|
assert "o3-pro" in all_known # Should be known (but blocked)
|
||||||
assert "mini" in all_known # Should be known (and allowed since it resolves to o4-mini)
|
assert "mini" in all_known # Should be known (and allowed since it resolves to o4-mini)
|
||||||
|
|
||||||
def test_demonstration_of_old_vs_new_interface(self):
|
def test_alias_aware_listing_extends_canonical_view(self):
|
||||||
"""
|
"""Alias-aware list should be a superset of restriction-filtered names."""
|
||||||
Direct comparison of old vs new interface to document the fix.
|
|
||||||
"""
|
|
||||||
provider = OpenAIModelProvider(api_key="test-key")
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
# OLD interface (still exists for backward compatibility)
|
baseline_models = provider.list_models(respect_restrictions=False)
|
||||||
old_style_models = provider.list_models(respect_restrictions=False)
|
|
||||||
|
|
||||||
# NEW interface (our fix)
|
alias_aware_models = provider.list_models(
|
||||||
new_comprehensive_models = provider.list_all_known_models()
|
respect_restrictions=False,
|
||||||
|
include_aliases=True,
|
||||||
|
lowercase=True,
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
|
||||||
# The new interface should be a superset of the old one
|
# Alias-aware variant should contain everything from the baseline
|
||||||
for model in old_style_models:
|
for model in baseline_models:
|
||||||
assert model.lower() in [
|
assert model.lower() in [
|
||||||
m.lower() for m in new_comprehensive_models
|
m.lower() for m in alias_aware_models
|
||||||
], f"New interface missing model {model} from old interface"
|
], f"Alias-aware listing missing baseline model {model}"
|
||||||
|
|
||||||
# The new interface should include target models that old one might miss
|
# Alias-aware variant should include canonical targets as well
|
||||||
targets_that_should_exist = ["o4-mini", "o3-mini"]
|
for target in ("o4-mini", "o3-mini"):
|
||||||
for target in targets_that_should_exist:
|
assert target in alias_aware_models, f"Alias-aware listing should include target model {target}"
|
||||||
assert target in new_comprehensive_models, f"New interface should include target model {target}"
|
|
||||||
|
|
||||||
def test_old_validation_interface_still_works(self):
|
def test_restriction_validation_uses_alias_aware_variant(self):
|
||||||
"""
|
"""Validation should request the alias-aware lowercased, deduped list."""
|
||||||
Verify our fix doesn't break existing validation workflows.
|
|
||||||
"""
|
|
||||||
service = ModelRestrictionService()
|
service = ModelRestrictionService()
|
||||||
|
|
||||||
# Create a mock provider that simulates the old behavior
|
# Simulate a provider that only returns aliases when asked for models
|
||||||
old_style_provider = MagicMock()
|
alias_only_provider = MagicMock()
|
||||||
old_style_provider.MODEL_CAPABILITIES = {
|
alias_only_provider.MODEL_CAPABILITIES = {
|
||||||
"mini": "o4-mini",
|
"mini": "o4-mini",
|
||||||
"o3mini": "o3-mini",
|
"o3mini": "o3-mini",
|
||||||
"o4-mini": {"context_window": 200000},
|
"o4-mini": {"context_window": 200000},
|
||||||
"o3-mini": {"context_window": 200000},
|
"o3-mini": {"context_window": 200000},
|
||||||
}
|
}
|
||||||
# OLD BROKEN: This would only return aliases
|
|
||||||
old_style_provider.list_models.return_value = ["mini", "o3mini"]
|
# Simulate alias-only vs. alias-aware behavior using a side effect
|
||||||
# NEW FIXED: This includes both aliases and targets
|
def list_models_side_effect(**kwargs):
|
||||||
old_style_provider.list_all_known_models.return_value = ["mini", "o3mini", "o4-mini", "o3-mini"]
|
respect_restrictions = kwargs.get("respect_restrictions", True)
|
||||||
|
include_aliases = kwargs.get("include_aliases", True)
|
||||||
|
lowercase = kwargs.get("lowercase", False)
|
||||||
|
unique = kwargs.get("unique", False)
|
||||||
|
|
||||||
|
if respect_restrictions and include_aliases and not lowercase and not unique:
|
||||||
|
return ["mini", "o3mini"]
|
||||||
|
|
||||||
|
if not respect_restrictions and include_aliases and lowercase and unique:
|
||||||
|
return ["mini", "o3mini", "o4-mini", "o3-mini"]
|
||||||
|
|
||||||
|
raise AssertionError(f"Unexpected list_models call: {kwargs}")
|
||||||
|
|
||||||
|
alias_only_provider.list_models.side_effect = list_models_side_effect
|
||||||
|
|
||||||
# Test that validation now uses the comprehensive method
|
# Test that validation now uses the comprehensive method
|
||||||
service.restrictions = {ProviderType.OPENAI: {"o4-mini"}} # Restrict to target
|
service.restrictions = {ProviderType.OPENAI: {"o4-mini"}} # Restrict to target
|
||||||
|
|
||||||
with patch("utils.model_restrictions.logger") as mock_logger:
|
with patch("utils.model_restrictions.logger") as mock_logger:
|
||||||
provider_instances = {ProviderType.OPENAI: old_style_provider}
|
provider_instances = {ProviderType.OPENAI: alias_only_provider}
|
||||||
service.validate_against_known_models(provider_instances)
|
service.validate_against_known_models(provider_instances)
|
||||||
|
|
||||||
# Verify the new method was called, not the old one
|
# Verify the alias-aware variant was used
|
||||||
old_style_provider.list_all_known_models.assert_called_once()
|
alias_only_provider.list_models.assert_called_with(
|
||||||
|
respect_restrictions=False,
|
||||||
|
include_aliases=True,
|
||||||
|
lowercase=True,
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Should not warn about o4-mini being unrecognized
|
# Should not warn about o4-mini being unrecognized
|
||||||
target_warnings = [
|
target_warnings = [
|
||||||
@@ -205,17 +205,17 @@ class TestBuggyBehaviorPrevention:
|
|||||||
]
|
]
|
||||||
assert len(target_warnings) == 0
|
assert len(target_warnings) == 0
|
||||||
|
|
||||||
def test_regression_proof_comprehensive_coverage(self):
|
def test_alias_listing_covers_targets_for_all_providers(self):
|
||||||
"""
|
"""Alias-aware listings should expose targets across providers."""
|
||||||
Comprehensive test to prove our fix covers all provider types.
|
|
||||||
"""
|
|
||||||
providers_to_test = [
|
providers_to_test = [
|
||||||
(OpenAIModelProvider(api_key="test-key"), "mini", "o4-mini"),
|
(OpenAIModelProvider(api_key="test-key"), "mini", "o4-mini"),
|
||||||
(GeminiModelProvider(api_key="test-key"), "flash", "gemini-2.5-flash"),
|
(GeminiModelProvider(api_key="test-key"), "flash", "gemini-2.5-flash"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for provider, alias, target in providers_to_test:
|
for provider, alias, target in providers_to_test:
|
||||||
all_known = provider.list_all_known_models()
|
all_known = provider.list_models(
|
||||||
|
respect_restrictions=False, include_aliases=True, lowercase=True, unique=True
|
||||||
|
)
|
||||||
|
|
||||||
# Every provider should include both aliases and targets
|
# Every provider should include both aliases and targets
|
||||||
assert alias in all_known, f"{provider.__class__.__name__} missing alias {alias}"
|
assert alias in all_known, f"{provider.__class__.__name__} missing alias {alias}"
|
||||||
@@ -226,13 +226,7 @@ class TestBuggyBehaviorPrevention:
|
|||||||
|
|
||||||
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini,invalid-model"})
|
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini,invalid-model"})
|
||||||
def test_validation_correctly_identifies_invalid_models(self):
|
def test_validation_correctly_identifies_invalid_models(self):
|
||||||
"""
|
"""Validation should flag invalid models while listing valid targets."""
|
||||||
Test that validation still catches truly invalid models while
|
|
||||||
properly recognizing valid target models.
|
|
||||||
|
|
||||||
This proves our fix works: o4-mini appears in the "Known models" list
|
|
||||||
because list_all_known_models() now includes target models.
|
|
||||||
"""
|
|
||||||
# Clear cached restriction service
|
# Clear cached restriction service
|
||||||
import utils.model_restrictions
|
import utils.model_restrictions
|
||||||
|
|
||||||
@@ -245,7 +239,6 @@ class TestBuggyBehaviorPrevention:
|
|||||||
provider_instances = {ProviderType.OPENAI: provider}
|
provider_instances = {ProviderType.OPENAI: provider}
|
||||||
service.validate_against_known_models(provider_instances)
|
service.validate_against_known_models(provider_instances)
|
||||||
|
|
||||||
# Should warn about 'invalid-model' (truly invalid)
|
|
||||||
invalid_warnings = [
|
invalid_warnings = [
|
||||||
call
|
call
|
||||||
for call in mock_logger.warning.call_args_list
|
for call in mock_logger.warning.call_args_list
|
||||||
@@ -253,39 +246,37 @@ class TestBuggyBehaviorPrevention:
|
|||||||
]
|
]
|
||||||
assert len(invalid_warnings) > 0, "Should warn about truly invalid models"
|
assert len(invalid_warnings) > 0, "Should warn about truly invalid models"
|
||||||
|
|
||||||
# The warning should mention o4-mini in the "Known models" list (proving our fix works)
|
# The warning should mention o4-mini in the known models list
|
||||||
warning_text = str(mock_logger.warning.call_args_list[0])
|
warning_text = str(mock_logger.warning.call_args_list[0])
|
||||||
assert "Known models:" in warning_text, "Warning should include known models list"
|
assert "Known models:" in warning_text, "Warning should include known models list"
|
||||||
assert "o4-mini" in warning_text, "o4-mini should appear in known models (proves our fix works)"
|
assert "o4-mini" in warning_text, "o4-mini should appear in known models"
|
||||||
assert "o3-mini" in warning_text, "o3-mini should appear in known models (proves our fix works)"
|
assert "o3-mini" in warning_text, "o3-mini should appear in known models"
|
||||||
|
|
||||||
# But the warning should be specifically about invalid-model
|
# But the warning should be specifically about invalid-model
|
||||||
assert "'invalid-model'" in warning_text, "Warning should specifically mention invalid-model"
|
assert "'invalid-model'" in warning_text, "Warning should specifically mention invalid-model"
|
||||||
|
|
||||||
def test_custom_provider_also_implements_fix(self):
|
def test_custom_provider_alias_listing(self):
|
||||||
"""
|
"""Custom provider should expose alias-aware listings as well."""
|
||||||
Verify that custom provider also implements the comprehensive interface.
|
|
||||||
"""
|
|
||||||
from providers.custom import CustomProvider
|
from providers.custom import CustomProvider
|
||||||
|
|
||||||
# This might fail if no URL is set, but that's expected
|
# This might fail if no URL is set, but that's expected
|
||||||
try:
|
try:
|
||||||
provider = CustomProvider(base_url="http://test.com/v1")
|
provider = CustomProvider(base_url="http://test.com/v1")
|
||||||
all_known = provider.list_all_known_models()
|
all_known = provider.list_models(
|
||||||
|
respect_restrictions=False, include_aliases=True, lowercase=True, unique=True
|
||||||
|
)
|
||||||
# Should return a list (might be empty if registry not loaded)
|
# Should return a list (might be empty if registry not loaded)
|
||||||
assert isinstance(all_known, list)
|
assert isinstance(all_known, list)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# Expected if no base_url configured, skip this test
|
# Expected if no base_url configured, skip this test
|
||||||
pytest.skip("Custom provider requires URL configuration")
|
pytest.skip("Custom provider requires URL configuration")
|
||||||
|
|
||||||
def test_openrouter_provider_also_implements_fix(self):
|
def test_openrouter_provider_alias_listing(self):
|
||||||
"""
|
"""OpenRouter provider should expose alias-aware listings."""
|
||||||
Verify that OpenRouter provider also implements the comprehensive interface.
|
|
||||||
"""
|
|
||||||
from providers.openrouter import OpenRouterProvider
|
from providers.openrouter import OpenRouterProvider
|
||||||
|
|
||||||
provider = OpenRouterProvider(api_key="test-key")
|
provider = OpenRouterProvider(api_key="test-key")
|
||||||
all_known = provider.list_all_known_models()
|
all_known = provider.list_models(respect_restrictions=False, include_aliases=True, lowercase=True, unique=True)
|
||||||
|
|
||||||
# Should return a list with both aliases and targets
|
# Should return a list with both aliases and targets
|
||||||
assert isinstance(all_known, list)
|
assert isinstance(all_known, list)
|
||||||
|
|||||||
@@ -142,7 +142,7 @@ class TestModelRestrictionService:
|
|||||||
"o3-mini": {"context_window": 200000},
|
"o3-mini": {"context_window": 200000},
|
||||||
"o4-mini": {"context_window": 200000},
|
"o4-mini": {"context_window": 200000},
|
||||||
}
|
}
|
||||||
mock_provider.list_all_known_models.return_value = ["o3", "o3-mini", "o4-mini"]
|
mock_provider.list_models.return_value = ["o3", "o3-mini", "o4-mini"]
|
||||||
|
|
||||||
provider_instances = {ProviderType.OPENAI: mock_provider}
|
provider_instances = {ProviderType.OPENAI: mock_provider}
|
||||||
service.validate_against_known_models(provider_instances)
|
service.validate_against_known_models(provider_instances)
|
||||||
@@ -447,7 +447,13 @@ class TestRegistryIntegration:
|
|||||||
}
|
}
|
||||||
mock_openai.get_provider_type.return_value = ProviderType.OPENAI
|
mock_openai.get_provider_type.return_value = ProviderType.OPENAI
|
||||||
|
|
||||||
def openai_list_models(respect_restrictions=True):
|
def openai_list_models(
|
||||||
|
*,
|
||||||
|
respect_restrictions: bool = True,
|
||||||
|
include_aliases: bool = True,
|
||||||
|
lowercase: bool = False,
|
||||||
|
unique: bool = False,
|
||||||
|
):
|
||||||
from utils.model_restrictions import get_restriction_service
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||||
@@ -457,15 +463,26 @@ class TestRegistryIntegration:
|
|||||||
target_model = config
|
target_model = config
|
||||||
if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, target_model):
|
if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, target_model):
|
||||||
continue
|
continue
|
||||||
models.append(model_name)
|
if include_aliases:
|
||||||
|
models.append(model_name)
|
||||||
else:
|
else:
|
||||||
if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, model_name):
|
if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, model_name):
|
||||||
continue
|
continue
|
||||||
models.append(model_name)
|
models.append(model_name)
|
||||||
|
if lowercase:
|
||||||
|
models = [m.lower() for m in models]
|
||||||
|
if unique:
|
||||||
|
seen = set()
|
||||||
|
ordered = []
|
||||||
|
for name in models:
|
||||||
|
if name in seen:
|
||||||
|
continue
|
||||||
|
seen.add(name)
|
||||||
|
ordered.append(name)
|
||||||
|
models = ordered
|
||||||
return models
|
return models
|
||||||
|
|
||||||
mock_openai.list_models = openai_list_models
|
mock_openai.list_models = MagicMock(side_effect=openai_list_models)
|
||||||
mock_openai.list_all_known_models.return_value = ["o3", "o3-mini"]
|
|
||||||
|
|
||||||
mock_gemini = MagicMock()
|
mock_gemini = MagicMock()
|
||||||
mock_gemini.MODEL_CAPABILITIES = {
|
mock_gemini.MODEL_CAPABILITIES = {
|
||||||
@@ -474,7 +491,13 @@ class TestRegistryIntegration:
|
|||||||
}
|
}
|
||||||
mock_gemini.get_provider_type.return_value = ProviderType.GOOGLE
|
mock_gemini.get_provider_type.return_value = ProviderType.GOOGLE
|
||||||
|
|
||||||
def gemini_list_models(respect_restrictions=True):
|
def gemini_list_models(
|
||||||
|
*,
|
||||||
|
respect_restrictions: bool = True,
|
||||||
|
include_aliases: bool = True,
|
||||||
|
lowercase: bool = False,
|
||||||
|
unique: bool = False,
|
||||||
|
):
|
||||||
from utils.model_restrictions import get_restriction_service
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||||
@@ -484,18 +507,26 @@ class TestRegistryIntegration:
|
|||||||
target_model = config
|
target_model = config
|
||||||
if restriction_service and not restriction_service.is_allowed(ProviderType.GOOGLE, target_model):
|
if restriction_service and not restriction_service.is_allowed(ProviderType.GOOGLE, target_model):
|
||||||
continue
|
continue
|
||||||
models.append(model_name)
|
if include_aliases:
|
||||||
|
models.append(model_name)
|
||||||
else:
|
else:
|
||||||
if restriction_service and not restriction_service.is_allowed(ProviderType.GOOGLE, model_name):
|
if restriction_service and not restriction_service.is_allowed(ProviderType.GOOGLE, model_name):
|
||||||
continue
|
continue
|
||||||
models.append(model_name)
|
models.append(model_name)
|
||||||
|
if lowercase:
|
||||||
|
models = [m.lower() for m in models]
|
||||||
|
if unique:
|
||||||
|
seen = set()
|
||||||
|
ordered = []
|
||||||
|
for name in models:
|
||||||
|
if name in seen:
|
||||||
|
continue
|
||||||
|
seen.add(name)
|
||||||
|
ordered.append(name)
|
||||||
|
models = ordered
|
||||||
return models
|
return models
|
||||||
|
|
||||||
mock_gemini.list_models = gemini_list_models
|
mock_gemini.list_models = MagicMock(side_effect=gemini_list_models)
|
||||||
mock_gemini.list_all_known_models.return_value = [
|
|
||||||
"gemini-2.5-pro",
|
|
||||||
"gemini-2.5-flash",
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_provider_side_effect(provider_type):
|
def get_provider_side_effect(provider_type):
|
||||||
if provider_type == ProviderType.OPENAI:
|
if provider_type == ProviderType.OPENAI:
|
||||||
@@ -615,7 +646,13 @@ class TestAutoModeWithRestrictions:
|
|||||||
}
|
}
|
||||||
mock_openai.get_provider_type.return_value = ProviderType.OPENAI
|
mock_openai.get_provider_type.return_value = ProviderType.OPENAI
|
||||||
|
|
||||||
def openai_list_models(respect_restrictions=True):
|
def openai_list_models(
|
||||||
|
*,
|
||||||
|
respect_restrictions: bool = True,
|
||||||
|
include_aliases: bool = True,
|
||||||
|
lowercase: bool = False,
|
||||||
|
unique: bool = False,
|
||||||
|
):
|
||||||
from utils.model_restrictions import get_restriction_service
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||||
@@ -625,15 +662,26 @@ class TestAutoModeWithRestrictions:
|
|||||||
target_model = config
|
target_model = config
|
||||||
if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, target_model):
|
if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, target_model):
|
||||||
continue
|
continue
|
||||||
models.append(model_name)
|
if include_aliases:
|
||||||
|
models.append(model_name)
|
||||||
else:
|
else:
|
||||||
if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, model_name):
|
if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, model_name):
|
||||||
continue
|
continue
|
||||||
models.append(model_name)
|
models.append(model_name)
|
||||||
|
if lowercase:
|
||||||
|
models = [m.lower() for m in models]
|
||||||
|
if unique:
|
||||||
|
seen = set()
|
||||||
|
ordered = []
|
||||||
|
for name in models:
|
||||||
|
if name in seen:
|
||||||
|
continue
|
||||||
|
seen.add(name)
|
||||||
|
ordered.append(name)
|
||||||
|
models = ordered
|
||||||
return models
|
return models
|
||||||
|
|
||||||
mock_openai.list_models = openai_list_models
|
mock_openai.list_models = MagicMock(side_effect=openai_list_models)
|
||||||
mock_openai.list_all_known_models.return_value = ["o3", "o3-mini", "o4-mini"]
|
|
||||||
|
|
||||||
# Add get_preferred_model method to mock to match new implementation
|
# Add get_preferred_model method to mock to match new implementation
|
||||||
def get_preferred_model(category, allowed_models):
|
def get_preferred_model(category, allowed_models):
|
||||||
|
|||||||
@@ -1,216 +0,0 @@
|
|||||||
"""
|
|
||||||
Tests that simulate the OLD BROKEN BEHAVIOR to prove it was indeed broken.
|
|
||||||
|
|
||||||
These tests create mock providers that behave like the old code (before our fix)
|
|
||||||
and demonstrate that they would have failed to catch the HIGH-severity bug.
|
|
||||||
|
|
||||||
IMPORTANT: These tests show what WOULD HAVE HAPPENED with the old code.
|
|
||||||
They prove that our fix was necessary and actually addresses real problems.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
from providers.shared import ProviderType
|
|
||||||
from utils.model_restrictions import ModelRestrictionService
|
|
||||||
|
|
||||||
|
|
||||||
class TestOldBehaviorSimulation:
|
|
||||||
"""
|
|
||||||
Simulate the old broken behavior to prove it was buggy.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_old_behavior_would_miss_target_restrictions(self):
|
|
||||||
"""
|
|
||||||
SIMULATION: This test recreates the OLD BROKEN BEHAVIOR and proves it was buggy.
|
|
||||||
|
|
||||||
OLD BUG: When validation service called provider.list_models(), it only got
|
|
||||||
aliases back, not targets. This meant target-based restrictions weren't validated.
|
|
||||||
"""
|
|
||||||
# Create a mock provider that simulates the OLD BROKEN BEHAVIOR
|
|
||||||
old_broken_provider = MagicMock()
|
|
||||||
old_broken_provider.MODEL_CAPABILITIES = {
|
|
||||||
"mini": "o4-mini", # alias -> target
|
|
||||||
"o3mini": "o3-mini", # alias -> target
|
|
||||||
"o4-mini": {"context_window": 200000},
|
|
||||||
"o3-mini": {"context_window": 200000},
|
|
||||||
}
|
|
||||||
|
|
||||||
# OLD BROKEN: list_models only returned aliases, missing targets
|
|
||||||
old_broken_provider.list_models.return_value = ["mini", "o3mini"]
|
|
||||||
|
|
||||||
# OLD BROKEN: There was no list_all_known_models method!
|
|
||||||
# We simulate this by making it behave like the old list_models
|
|
||||||
old_broken_provider.list_all_known_models.return_value = ["mini", "o3mini"] # MISSING TARGETS!
|
|
||||||
|
|
||||||
# Now test what happens when admin tries to restrict by target model
|
|
||||||
service = ModelRestrictionService()
|
|
||||||
service.restrictions = {ProviderType.OPENAI: {"o4-mini"}} # Restrict to target model
|
|
||||||
|
|
||||||
with patch("utils.model_restrictions.logger") as mock_logger:
|
|
||||||
provider_instances = {ProviderType.OPENAI: old_broken_provider}
|
|
||||||
service.validate_against_known_models(provider_instances)
|
|
||||||
|
|
||||||
# OLD BROKEN BEHAVIOR: Would warn about o4-mini being "not recognized"
|
|
||||||
# because it wasn't in the list_all_known_models response
|
|
||||||
target_warnings = [
|
|
||||||
call
|
|
||||||
for call in mock_logger.warning.call_args_list
|
|
||||||
if "o4-mini" in str(call) and "not a recognized" in str(call)
|
|
||||||
]
|
|
||||||
|
|
||||||
# This proves the old behavior was broken - it would generate false warnings
|
|
||||||
assert len(target_warnings) > 0, "OLD BROKEN BEHAVIOR: Would incorrectly warn about valid target models"
|
|
||||||
|
|
||||||
# Verify the warning message shows the broken list
|
|
||||||
warning_text = str(target_warnings[0])
|
|
||||||
assert "mini" in warning_text # Alias was included
|
|
||||||
assert "o3mini" in warning_text # Alias was included
|
|
||||||
# But targets were missing from the known models list in old behavior
|
|
||||||
|
|
||||||
def test_new_behavior_fixes_the_problem(self):
|
|
||||||
"""
|
|
||||||
Compare old vs new behavior to show our fix works.
|
|
||||||
"""
|
|
||||||
# Create mock provider with NEW FIXED BEHAVIOR
|
|
||||||
new_fixed_provider = MagicMock()
|
|
||||||
new_fixed_provider.MODEL_CAPABILITIES = {
|
|
||||||
"mini": "o4-mini",
|
|
||||||
"o3mini": "o3-mini",
|
|
||||||
"o4-mini": {"context_window": 200000},
|
|
||||||
"o3-mini": {"context_window": 200000},
|
|
||||||
}
|
|
||||||
|
|
||||||
# NEW FIXED: list_all_known_models includes BOTH aliases AND targets
|
|
||||||
new_fixed_provider.list_all_known_models.return_value = [
|
|
||||||
"mini",
|
|
||||||
"o3mini", # aliases
|
|
||||||
"o4-mini",
|
|
||||||
"o3-mini", # targets - THESE WERE MISSING IN OLD CODE!
|
|
||||||
]
|
|
||||||
|
|
||||||
# Same restriction scenario
|
|
||||||
service = ModelRestrictionService()
|
|
||||||
service.restrictions = {ProviderType.OPENAI: {"o4-mini"}} # Restrict to target model
|
|
||||||
|
|
||||||
with patch("utils.model_restrictions.logger") as mock_logger:
|
|
||||||
provider_instances = {ProviderType.OPENAI: new_fixed_provider}
|
|
||||||
service.validate_against_known_models(provider_instances)
|
|
||||||
|
|
||||||
# NEW FIXED BEHAVIOR: No warnings about o4-mini being unrecognized
|
|
||||||
target_warnings = [
|
|
||||||
call
|
|
||||||
for call in mock_logger.warning.call_args_list
|
|
||||||
if "o4-mini" in str(call) and "not a recognized" in str(call)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Our fix prevents false warnings
|
|
||||||
assert len(target_warnings) == 0, "NEW FIXED BEHAVIOR: Should not warn about valid target models"
|
|
||||||
|
|
||||||
def test_policy_bypass_prevention_old_vs_new(self):
    """
    Show how the old behavior could have led to policy bypass scenarios.

    An admin restricts OpenAI usage to ``{"o4-mini", "completely-invalid-model"}``.
    With the old provider (aliases only in ``list_all_known_models``) validation
    warned about both entries; with the fixed provider (aliases + targets) only
    the truly invalid entry should be flagged, and the "Known models" hint in
    the warning should now include the target model ``o4-mini``.
    """
    # OLD BROKEN: Admin thinks they've restricted access to o4-mini,
    # but validation doesn't recognize it as a valid restriction target
    old_broken_provider = MagicMock()
    old_broken_provider.list_all_known_models.return_value = ["mini", "o3mini"]  # Missing targets

    # NEW FIXED: Same provider with our fix
    new_fixed_provider = MagicMock()
    new_fixed_provider.list_all_known_models.return_value = ["mini", "o3mini", "o4-mini", "o3-mini"]

    # Test restriction on target model - use completely separate service instances
    old_service = ModelRestrictionService()
    old_service.restrictions = {ProviderType.OPENAI: {"o4-mini", "completely-invalid-model"}}

    new_service = ModelRestrictionService()
    new_service.restrictions = {ProviderType.OPENAI: {"o4-mini", "completely-invalid-model"}}

    # OLD BEHAVIOR: Would warn about BOTH models being unrecognized
    with patch("utils.model_restrictions.logger") as mock_logger_old:
        provider_instances = {ProviderType.OPENAI: old_broken_provider}
        old_service.validate_against_known_models(provider_instances)

        old_warnings = [str(call) for call in mock_logger_old.warning.call_args_list]
        print(f"OLD warnings: {old_warnings}")  # Debug output

    # NEW BEHAVIOR: Only warns about truly invalid model
    with patch("utils.model_restrictions.logger") as mock_logger_new:
        provider_instances = {ProviderType.OPENAI: new_fixed_provider}
        new_service.validate_against_known_models(provider_instances)

        new_warnings = [str(call) for call in mock_logger_new.warning.call_args_list]
        print(f"NEW warnings: {new_warnings}")  # Debug output

    # For now, just verify that we get some warnings in both cases.
    # The key point is that the "Known models" list is different.
    assert len(old_warnings) > 0, "OLD: Should have warnings"
    assert len(new_warnings) > 0, "NEW: Should have warnings for invalid model"

    # Verify the known models list improved between old and new.
    # FIX: the old code computed `str(old_warnings[0]) if old_warnings else ""`
    # here and discarded the result — a dead expression with no effect; removed.
    new_warning_text = str(new_warnings[0]) if new_warnings else ""

    if "Known models:" in new_warning_text:
        # NEW behavior should include o4-mini in known models list
        assert "o4-mini" in new_warning_text, "NEW: Should include o4-mini in known models"

    print("This test demonstrates that our fix improves the 'Known models' list shown to users.")
def test_demonstrate_target_coverage_improvement(self):
|
|
||||||
"""
|
|
||||||
Show the exact improvement in target model coverage.
|
|
||||||
"""
|
|
||||||
# Simulate different provider implementations
|
|
||||||
providers_old_vs_new = [
|
|
||||||
# (old_broken_list, new_fixed_list, provider_name)
|
|
||||||
(["mini", "o3mini"], ["mini", "o3mini", "o4-mini", "o3-mini"], "OpenAI"),
|
|
||||||
(
|
|
||||||
["flash", "pro"],
|
|
||||||
["flash", "pro", "gemini-2.5-flash", "gemini-2.5-pro"],
|
|
||||||
"Gemini",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
for old_list, new_list, provider_name in providers_old_vs_new:
|
|
||||||
# Count how many additional models are now covered
|
|
||||||
old_coverage = set(old_list)
|
|
||||||
new_coverage = set(new_list)
|
|
||||||
|
|
||||||
additional_coverage = new_coverage - old_coverage
|
|
||||||
|
|
||||||
# There should be additional target models covered
|
|
||||||
assert len(additional_coverage) > 0, f"{provider_name}: Should have additional target coverage"
|
|
||||||
|
|
||||||
# All old models should still be covered
|
|
||||||
assert old_coverage.issubset(new_coverage), f"{provider_name}: Should maintain backward compatibility"
|
|
||||||
|
|
||||||
print(f"{provider_name} provider:")
|
|
||||||
print(f" Old coverage: {sorted(old_coverage)}")
|
|
||||||
print(f" New coverage: {sorted(new_coverage)}")
|
|
||||||
print(f" Additional models: {sorted(additional_coverage)}")
|
|
||||||
|
|
||||||
def test_comprehensive_alias_target_mapping_verification(self):
    """Exercise real providers to confirm alias->target coverage is complete.

    Every canonical model name and every declared alias must appear (in
    lowercase) in ``list_all_known_models()``, and each alias must resolve
    back to its canonical model name.
    """
    from providers.gemini import GeminiModelProvider
    from providers.openai_provider import OpenAIModelProvider

    # Use concrete providers rather than mocks so the real catalogue and
    # alias-resolution logic are exercised.
    for candidate in (OpenAIModelProvider(api_key="test-key"), GeminiModelProvider(api_key="test-key")):
        known = candidate.list_all_known_models()
        cls_name = candidate.__class__.__name__

        # Every model and all of its aliases must appear in the comprehensive list.
        for canonical, capabilities in candidate.MODEL_CAPABILITIES.items():
            assert canonical.lower() in known, f"{cls_name}: Missing model {canonical}"

            for alias in getattr(capabilities, "aliases", []):
                assert (
                    alias.lower() in known
                ), f"{cls_name}: Missing alias {alias} for model {canonical}"
                assert (
                    candidate._resolve_model_name(alias) == canonical
                ), f"{cls_name}: Alias {alias} should resolve to {canonical}"
@@ -26,10 +26,7 @@ class TestOpenAICompatibleTokenUsage(unittest.TestCase):
|
|||||||
def validate_model_name(self, model_name):
|
def validate_model_name(self, model_name):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def list_models(self, respect_restrictions=True):
|
def list_models(self, **kwargs):
|
||||||
return ["test-model"]
|
|
||||||
|
|
||||||
def list_all_known_models(self):
|
|
||||||
return ["test-model"]
|
return ["test-model"]
|
||||||
|
|
||||||
self.provider = TestProvider("test-key")
|
self.provider = TestProvider("test-key")
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ class TestOpenRouterAutoMode:
|
|||||||
os.environ["DEFAULT_MODEL"] = "auto"
|
os.environ["DEFAULT_MODEL"] = "auto"
|
||||||
|
|
||||||
mock_registry = Mock()
|
mock_registry = Mock()
|
||||||
mock_registry.list_models.return_value = [
|
model_names = [
|
||||||
"google/gemini-2.5-flash",
|
"google/gemini-2.5-flash",
|
||||||
"google/gemini-2.5-pro",
|
"google/gemini-2.5-pro",
|
||||||
"openai/o3",
|
"openai/o3",
|
||||||
@@ -159,6 +159,18 @@ class TestOpenRouterAutoMode:
|
|||||||
"anthropic/claude-opus-4.1",
|
"anthropic/claude-opus-4.1",
|
||||||
"anthropic/claude-sonnet-4.1",
|
"anthropic/claude-sonnet-4.1",
|
||||||
]
|
]
|
||||||
|
mock_registry.list_models.return_value = model_names
|
||||||
|
|
||||||
|
# Mock resolve to return a ModelCapabilities-like object for each model
|
||||||
|
def mock_resolve(model_name):
|
||||||
|
if model_name in model_names:
|
||||||
|
mock_config = Mock()
|
||||||
|
mock_config.is_custom = False
|
||||||
|
mock_config.aliases = [] # Empty list of aliases
|
||||||
|
return mock_config
|
||||||
|
return None
|
||||||
|
|
||||||
|
mock_registry.resolve.side_effect = mock_resolve
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
||||||
|
|
||||||
@@ -171,8 +183,7 @@ class TestOpenRouterAutoMode:
|
|||||||
assert len(available_models) > 0, "Should find OpenRouter models in auto mode"
|
assert len(available_models) > 0, "Should find OpenRouter models in auto mode"
|
||||||
assert all(provider_type == ProviderType.OPENROUTER for provider_type in available_models.values())
|
assert all(provider_type == ProviderType.OPENROUTER for provider_type in available_models.values())
|
||||||
|
|
||||||
expected_models = mock_registry.list_models.return_value
|
for model in model_names:
|
||||||
for model in expected_models:
|
|
||||||
assert model in available_models, f"Model {model} should be available"
|
assert model in available_models, f"Model {model} should be available"
|
||||||
|
|
||||||
@pytest.mark.no_mock_provider
|
@pytest.mark.no_mock_provider
|
||||||
|
|||||||
@@ -151,11 +151,16 @@ class TestSupportedModelsAliases:
|
|||||||
assert "o3-2025-04-16" in dial_models
|
assert "o3-2025-04-16" in dial_models
|
||||||
assert "o3" in dial_models
|
assert "o3" in dial_models
|
||||||
|
|
||||||
def test_list_all_known_models_includes_aliases(self):
|
def test_list_models_all_known_variant_includes_aliases(self):
|
||||||
"""Test that list_all_known_models returns all models and aliases in lowercase."""
|
"""Unified list_models should support lowercase, alias-inclusive listings."""
|
||||||
# Test Gemini
|
# Test Gemini
|
||||||
gemini_provider = GeminiModelProvider("test-key")
|
gemini_provider = GeminiModelProvider("test-key")
|
||||||
gemini_all = gemini_provider.list_all_known_models()
|
gemini_all = gemini_provider.list_models(
|
||||||
|
respect_restrictions=False,
|
||||||
|
include_aliases=True,
|
||||||
|
lowercase=True,
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
assert "gemini-2.5-flash" in gemini_all
|
assert "gemini-2.5-flash" in gemini_all
|
||||||
assert "flash" in gemini_all
|
assert "flash" in gemini_all
|
||||||
assert "gemini-2.5-pro" in gemini_all
|
assert "gemini-2.5-pro" in gemini_all
|
||||||
@@ -165,7 +170,12 @@ class TestSupportedModelsAliases:
|
|||||||
|
|
||||||
# Test OpenAI
|
# Test OpenAI
|
||||||
openai_provider = OpenAIModelProvider("test-key")
|
openai_provider = OpenAIModelProvider("test-key")
|
||||||
openai_all = openai_provider.list_all_known_models()
|
openai_all = openai_provider.list_models(
|
||||||
|
respect_restrictions=False,
|
||||||
|
include_aliases=True,
|
||||||
|
lowercase=True,
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
assert "o4-mini" in openai_all
|
assert "o4-mini" in openai_all
|
||||||
assert "mini" in openai_all
|
assert "mini" in openai_all
|
||||||
assert "o3-mini" in openai_all
|
assert "o3-mini" in openai_all
|
||||||
|
|||||||
@@ -30,13 +30,20 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class ModelRestrictionService:
|
class ModelRestrictionService:
|
||||||
"""
|
"""Central authority for environment-driven model allowlists.
|
||||||
Centralized service for managing model usage restrictions.
|
|
||||||
|
|
||||||
This service:
|
Role
|
||||||
1. Loads restrictions from environment variables at startup
|
Interpret ``*_ALLOWED_MODELS`` environment variables, keep their
|
||||||
2. Validates restrictions against known models
|
entries normalised (lowercase), and answer whether a provider/model
|
||||||
3. Provides a simple interface to check if a model is allowed
|
pairing is permitted.
|
||||||
|
|
||||||
|
Responsibilities
|
||||||
|
* Parse, cache, and expose per-provider restriction sets
|
||||||
|
* Validate configuration by cross-checking each entry against the
|
||||||
|
provider’s alias-aware model list
|
||||||
|
* Offer helper methods such as ``is_allowed`` and ``filter_models`` to
|
||||||
|
enforce policy everywhere model names appear (tool selection, CLI
|
||||||
|
commands, etc.).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Environment variable names
|
# Environment variable names
|
||||||
@@ -94,9 +101,14 @@ class ModelRestrictionService:
|
|||||||
|
|
||||||
# Get all supported models using the clean polymorphic interface
|
# Get all supported models using the clean polymorphic interface
|
||||||
try:
|
try:
|
||||||
# Use list_all_known_models to get both aliases and their targets
|
# Gather canonical models and aliases with consistent formatting
|
||||||
all_models = provider.list_all_known_models()
|
all_models = provider.list_models(
|
||||||
supported_models = {model.lower() for model in all_models}
|
respect_restrictions=False,
|
||||||
|
include_aliases=True,
|
||||||
|
lowercase=True,
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
supported_models = set(all_models)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Could not get model list from {provider_type.value} provider: {e}")
|
logger.debug(f"Could not get model list from {provider_type.value} provider: {e}")
|
||||||
supported_models = set()
|
supported_models = set()
|
||||||
|
|||||||
Reference in New Issue
Block a user