Proper fix for model discovery per provider
This commit is contained in:
39
README.md
39
README.md
@@ -365,13 +365,50 @@ and there may be more potential vulnerabilities. Find and share related code."
|
|||||||
**[📖 Read More](docs/tools/codereview.md)** - Professional review capabilities and parallel analysis
|
**[📖 Read More](docs/tools/codereview.md)** - Professional review capabilities and parallel analysis
|
||||||
|
|
||||||
### 6. `precommit` - Pre-Commit Validation
|
### 6. `precommit` - Pre-Commit Validation
|
||||||
Comprehensive review of staged/unstaged git changes across multiple repositories. Validates changes against requirements and detects potential regressions.
|
Comprehensive review of staged/unstaged git changes across multiple repositories. Validates changes against requirements
|
||||||
|
and detects potential regressions.
|
||||||
|
|
||||||
```
|
```
|
||||||
Perform a thorough precommit with o3, we want to only highlight critical issues, no blockers, no regressions. I need
|
Perform a thorough precommit with o3, we want to only highlight critical issues, no blockers, no regressions. I need
|
||||||
100% confidence this covers all the edge cases listed in FIX_ISSUE_ABC.md
|
100% confidence this covers all the edge cases listed in FIX_ISSUE_ABC.md
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Real Example:
|
||||||
|
|
||||||
|
* I asked claude (sonnet-4) to refactor some code to fix a known bug; gave it exact instructions
|
||||||
|
* The fix looked correct — and at a glance, so did the changes.
|
||||||
|
* More importantly, all existing unit tests passed (540+ unit tests) without a hitch. Zero regressions. Claude seemed confident and eager to commit. Everything looked perfect.
|
||||||
|
|
||||||
|
I then ran:
|
||||||
|
|
||||||
|
```text
|
||||||
|
Run a precommit with o3 confirm our changes are sound and diffs are valid. Confirm this won't cause breakage or
|
||||||
|
regressions and codesmells are out
|
||||||
|
```
|
||||||
|
|
||||||
|
This passed the tracked and untracked changes / diffs and the related code to o3 for a thorough review via Zen's `precommit` tool.
|
||||||
|
|
||||||
|
Output:
|
||||||
|
|
||||||
|
```text
|
||||||
|
⏺ Precommit Analysis Results - CRITICAL ISSUE FOUND
|
||||||
|
|
||||||
|
The O3 model has identified one HIGH-severity regression and several medium/low issues in our refactor. Here's the summary:
|
||||||
|
|
||||||
|
🚨 HIGH SEVERITY - MUST FIX BEFORE COMMIT
|
||||||
|
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
The reported issue was in fact a _very subtle bug_ that slipped through the quick glance — and a unit test for this exact case apparently
|
||||||
|
was missing (out of 540 existing tests!) - explains the zero reported regressions. The fix was ultimately simple, but the
|
||||||
|
fact Claude (and by extension, I) overlooked this, was a stark reminder: no number of eyeballs is ever enough. Fixed the
|
||||||
|
issue, ran `precommit` with o3 again and got:
|
||||||
|
|
||||||
|
**RECOMMENDATION: PROCEED WITH COMMIT**
|
||||||
|
|
||||||
|
Nice!
|
||||||
|
|
||||||
**[📖 Read More](docs/tools/precommit.md)** - Multi-repository validation and change analysis
|
**[📖 Read More](docs/tools/precommit.md)** - Multi-repository validation and change analysis
|
||||||
|
|
||||||
### 7. `debug` - Expert Debugging Assistant
|
### 7. `debug` - Expert Debugging Assistant
|
||||||
|
|||||||
@@ -221,3 +221,27 @@ class ModelProvider(ABC):
|
|||||||
def supports_thinking_mode(self, model_name: str) -> bool:
|
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||||
"""Check if the model supports extended thinking mode."""
|
"""Check if the model supports extended thinking mode."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||||
|
"""Return a list of model names supported by this provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of model names available from this provider
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_all_known_models(self) -> list[str]:
|
||||||
|
"""Return all model names known by this provider, including alias targets.
|
||||||
|
|
||||||
|
This is used for validation purposes to ensure restriction policies
|
||||||
|
can validate against both aliases and their target model names.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of all model names and alias targets known by this provider
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|||||||
@@ -276,3 +276,56 @@ class CustomProvider(OpenAICompatibleProvider):
|
|||||||
False (custom models generally don't support thinking mode)
|
False (custom models generally don't support thinking mode)
|
||||||
"""
|
"""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||||
|
"""Return a list of model names supported by this provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of model names available from this provider
|
||||||
|
"""
|
||||||
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
|
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||||
|
models = []
|
||||||
|
|
||||||
|
if self._registry:
|
||||||
|
# Get all models from the registry
|
||||||
|
all_models = self._registry.list_models()
|
||||||
|
aliases = self._registry.list_aliases()
|
||||||
|
|
||||||
|
# Add models that are validated by the custom provider
|
||||||
|
for model_name in all_models + aliases:
|
||||||
|
# Use the provider's validation logic to determine if this model
|
||||||
|
# is appropriate for the custom endpoint
|
||||||
|
if self.validate_model_name(model_name):
|
||||||
|
# Check restrictions if enabled
|
||||||
|
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||||
|
continue
|
||||||
|
|
||||||
|
models.append(model_name)
|
||||||
|
|
||||||
|
return models
|
||||||
|
|
||||||
|
def list_all_known_models(self) -> list[str]:
|
||||||
|
"""Return all model names known by this provider, including alias targets.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of all model names and alias targets known by this provider
|
||||||
|
"""
|
||||||
|
all_models = set()
|
||||||
|
|
||||||
|
if self._registry:
|
||||||
|
# Get all models and aliases from the registry
|
||||||
|
all_models.update(model.lower() for model in self._registry.list_models())
|
||||||
|
all_models.update(alias.lower() for alias in self._registry.list_aliases())
|
||||||
|
|
||||||
|
# For each alias, also add its target
|
||||||
|
for alias in self._registry.list_aliases():
|
||||||
|
config = self._registry.resolve(alias)
|
||||||
|
if config:
|
||||||
|
all_models.add(config.model_name.lower())
|
||||||
|
|
||||||
|
return list(all_models)
|
||||||
|
|||||||
@@ -287,6 +287,56 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
|
|
||||||
return int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
|
return int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
|
||||||
|
|
||||||
|
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||||
|
"""Return a list of model names supported by this provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of model names available from this provider
|
||||||
|
"""
|
||||||
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
|
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||||
|
models = []
|
||||||
|
|
||||||
|
for model_name, config in self.SUPPORTED_MODELS.items():
|
||||||
|
# Handle both base models (dict configs) and aliases (string values)
|
||||||
|
if isinstance(config, str):
|
||||||
|
# This is an alias - check if the target model would be allowed
|
||||||
|
target_model = config
|
||||||
|
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
|
||||||
|
continue
|
||||||
|
# Allow the alias
|
||||||
|
models.append(model_name)
|
||||||
|
else:
|
||||||
|
# This is a base model with config dict
|
||||||
|
# Check restrictions if enabled
|
||||||
|
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||||
|
continue
|
||||||
|
models.append(model_name)
|
||||||
|
|
||||||
|
return models
|
||||||
|
|
||||||
|
def list_all_known_models(self) -> list[str]:
|
||||||
|
"""Return all model names known by this provider, including alias targets.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of all model names and alias targets known by this provider
|
||||||
|
"""
|
||||||
|
all_models = set()
|
||||||
|
|
||||||
|
for model_name, config in self.SUPPORTED_MODELS.items():
|
||||||
|
# Add the model name itself
|
||||||
|
all_models.add(model_name.lower())
|
||||||
|
|
||||||
|
# If it's an alias (string value), add the target model too
|
||||||
|
if isinstance(config, str):
|
||||||
|
all_models.add(config.lower())
|
||||||
|
|
||||||
|
return list(all_models)
|
||||||
|
|
||||||
def _resolve_model_name(self, model_name: str) -> str:
|
def _resolve_model_name(self, model_name: str) -> str:
|
||||||
"""Resolve model shorthand to full name."""
|
"""Resolve model shorthand to full name."""
|
||||||
# Check if it's a shorthand
|
# Check if it's a shorthand
|
||||||
|
|||||||
@@ -163,6 +163,56 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
|||||||
# This may change with future O3 models
|
# This may change with future O3 models
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||||
|
"""Return a list of model names supported by this provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of model names available from this provider
|
||||||
|
"""
|
||||||
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
|
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||||
|
models = []
|
||||||
|
|
||||||
|
for model_name, config in self.SUPPORTED_MODELS.items():
|
||||||
|
# Handle both base models (dict configs) and aliases (string values)
|
||||||
|
if isinstance(config, str):
|
||||||
|
# This is an alias - check if the target model would be allowed
|
||||||
|
target_model = config
|
||||||
|
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
|
||||||
|
continue
|
||||||
|
# Allow the alias
|
||||||
|
models.append(model_name)
|
||||||
|
else:
|
||||||
|
# This is a base model with config dict
|
||||||
|
# Check restrictions if enabled
|
||||||
|
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||||
|
continue
|
||||||
|
models.append(model_name)
|
||||||
|
|
||||||
|
return models
|
||||||
|
|
||||||
|
def list_all_known_models(self) -> list[str]:
|
||||||
|
"""Return all model names known by this provider, including alias targets.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of all model names and alias targets known by this provider
|
||||||
|
"""
|
||||||
|
all_models = set()
|
||||||
|
|
||||||
|
for model_name, config in self.SUPPORTED_MODELS.items():
|
||||||
|
# Add the model name itself
|
||||||
|
all_models.add(model_name.lower())
|
||||||
|
|
||||||
|
# If it's an alias (string value), add the target model too
|
||||||
|
if isinstance(config, str):
|
||||||
|
all_models.add(config.lower())
|
||||||
|
|
||||||
|
return list(all_models)
|
||||||
|
|
||||||
def _resolve_model_name(self, model_name: str) -> str:
|
def _resolve_model_name(self, model_name: str) -> str:
|
||||||
"""Resolve model shorthand to full name."""
|
"""Resolve model shorthand to full name."""
|
||||||
# Check if it's a shorthand
|
# Check if it's a shorthand
|
||||||
|
|||||||
@@ -190,3 +190,48 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
|||||||
False (no OpenRouter models currently support thinking mode)
|
False (no OpenRouter models currently support thinking mode)
|
||||||
"""
|
"""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||||
|
"""Return a list of model names supported by this provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of model names available from this provider
|
||||||
|
"""
|
||||||
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
|
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||||
|
models = []
|
||||||
|
|
||||||
|
if self._registry:
|
||||||
|
for model_name in self._registry.list_models():
|
||||||
|
# Check restrictions if enabled
|
||||||
|
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||||
|
continue
|
||||||
|
|
||||||
|
models.append(model_name)
|
||||||
|
|
||||||
|
return models
|
||||||
|
|
||||||
|
def list_all_known_models(self) -> list[str]:
|
||||||
|
"""Return all model names known by this provider, including alias targets.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of all model names and alias targets known by this provider
|
||||||
|
"""
|
||||||
|
all_models = set()
|
||||||
|
|
||||||
|
if self._registry:
|
||||||
|
# Get all models and aliases from the registry
|
||||||
|
all_models.update(model.lower() for model in self._registry.list_models())
|
||||||
|
all_models.update(alias.lower() for alias in self._registry.list_aliases())
|
||||||
|
|
||||||
|
# For each alias, also add its target
|
||||||
|
for alias in self._registry.list_aliases():
|
||||||
|
config = self._registry.resolve(alias)
|
||||||
|
if config:
|
||||||
|
all_models.add(config.model_name.lower())
|
||||||
|
|
||||||
|
return list(all_models)
|
||||||
|
|||||||
@@ -160,66 +160,29 @@ class ModelProviderRegistry:
|
|||||||
Returns:
|
Returns:
|
||||||
Dict mapping model names to provider types
|
Dict mapping model names to provider types
|
||||||
"""
|
"""
|
||||||
models = {}
|
|
||||||
instance = cls()
|
|
||||||
|
|
||||||
# Import here to avoid circular imports
|
# Import here to avoid circular imports
|
||||||
from utils.model_restrictions import get_restriction_service
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||||
|
models: dict[str, ProviderType] = {}
|
||||||
|
instance = cls()
|
||||||
|
|
||||||
for provider_type in instance._providers:
|
for provider_type in instance._providers:
|
||||||
provider = cls.get_provider(provider_type)
|
provider = cls.get_provider(provider_type)
|
||||||
if provider:
|
if not provider:
|
||||||
# Get supported models based on provider type
|
continue
|
||||||
if hasattr(provider, "SUPPORTED_MODELS"):
|
|
||||||
for model_name, config in provider.SUPPORTED_MODELS.items():
|
|
||||||
# Handle both base models (dict configs) and aliases (string values)
|
|
||||||
if isinstance(config, str):
|
|
||||||
# This is an alias - check if the target model would be allowed
|
|
||||||
target_model = config
|
|
||||||
if restriction_service and not restriction_service.is_allowed(provider_type, target_model):
|
|
||||||
logging.debug(f"Alias {model_name} -> {target_model} filtered by restrictions")
|
|
||||||
continue
|
|
||||||
# Allow the alias
|
|
||||||
models[model_name] = provider_type
|
|
||||||
else:
|
|
||||||
# This is a base model with config dict
|
|
||||||
# Check restrictions if enabled
|
|
||||||
if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
|
|
||||||
logging.debug(f"Model {model_name} filtered by restrictions")
|
|
||||||
continue
|
|
||||||
models[model_name] = provider_type
|
|
||||||
elif provider_type == ProviderType.OPENROUTER:
|
|
||||||
# OpenRouter uses a registry system instead of SUPPORTED_MODELS
|
|
||||||
if hasattr(provider, "_registry") and provider._registry:
|
|
||||||
for model_name in provider._registry.list_models():
|
|
||||||
# Check restrictions if enabled
|
|
||||||
if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
|
|
||||||
logging.debug(f"Model {model_name} filtered by restrictions")
|
|
||||||
continue
|
|
||||||
|
|
||||||
models[model_name] = provider_type
|
try:
|
||||||
elif provider_type == ProviderType.CUSTOM:
|
available = provider.list_models(respect_restrictions=respect_restrictions)
|
||||||
# Custom provider also uses a registry system (shared with OpenRouter)
|
except NotImplementedError:
|
||||||
if hasattr(provider, "_registry") and provider._registry:
|
logging.warning("Provider %s does not implement list_models", provider_type)
|
||||||
# Get all models from the registry
|
continue
|
||||||
all_models = provider._registry.list_models()
|
|
||||||
aliases = provider._registry.list_aliases()
|
|
||||||
|
|
||||||
# Add models that are validated by the custom provider
|
for model_name in available:
|
||||||
for model_name in all_models + aliases:
|
if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
|
||||||
# Use the provider's validation logic to determine if this model
|
logging.debug("Model %s filtered by restrictions", model_name)
|
||||||
# is appropriate for the custom endpoint
|
continue
|
||||||
if provider.validate_model_name(model_name):
|
models[model_name] = provider_type
|
||||||
# Check restrictions if enabled
|
|
||||||
if restriction_service and not restriction_service.is_allowed(
|
|
||||||
provider_type, model_name
|
|
||||||
):
|
|
||||||
logging.debug(f"Model {model_name} filtered by restrictions")
|
|
||||||
continue
|
|
||||||
|
|
||||||
models[model_name] = provider_type
|
|
||||||
|
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
|||||||
@@ -126,6 +126,56 @@ class XAIModelProvider(OpenAICompatibleProvider):
|
|||||||
# This may change with future GROK model releases
|
# This may change with future GROK model releases
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||||
|
"""Return a list of model names supported by this provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of model names available from this provider
|
||||||
|
"""
|
||||||
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
|
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||||
|
models = []
|
||||||
|
|
||||||
|
for model_name, config in self.SUPPORTED_MODELS.items():
|
||||||
|
# Handle both base models (dict configs) and aliases (string values)
|
||||||
|
if isinstance(config, str):
|
||||||
|
# This is an alias - check if the target model would be allowed
|
||||||
|
target_model = config
|
||||||
|
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
|
||||||
|
continue
|
||||||
|
# Allow the alias
|
||||||
|
models.append(model_name)
|
||||||
|
else:
|
||||||
|
# This is a base model with config dict
|
||||||
|
# Check restrictions if enabled
|
||||||
|
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||||
|
continue
|
||||||
|
models.append(model_name)
|
||||||
|
|
||||||
|
return models
|
||||||
|
|
||||||
|
def list_all_known_models(self) -> list[str]:
|
||||||
|
"""Return all model names known by this provider, including alias targets.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of all model names and alias targets known by this provider
|
||||||
|
"""
|
||||||
|
all_models = set()
|
||||||
|
|
||||||
|
for model_name, config in self.SUPPORTED_MODELS.items():
|
||||||
|
# Add the model name itself
|
||||||
|
all_models.add(model_name.lower())
|
||||||
|
|
||||||
|
# If it's an alias (string value), add the target model too
|
||||||
|
if isinstance(config, str):
|
||||||
|
all_models.add(config.lower())
|
||||||
|
|
||||||
|
return list(all_models)
|
||||||
|
|
||||||
def _resolve_model_name(self, model_name: str) -> str:
|
def _resolve_model_name(self, model_name: str) -> str:
|
||||||
"""Resolve model shorthand to full name."""
|
"""Resolve model shorthand to full name."""
|
||||||
# Check if it's a shorthand
|
# Check if it's a shorthand
|
||||||
|
|||||||
339
tests/test_alias_target_restrictions.py
Normal file
339
tests/test_alias_target_restrictions.py
Normal file
@@ -0,0 +1,339 @@
|
|||||||
|
"""
|
||||||
|
Tests for alias and target model restriction validation.
|
||||||
|
|
||||||
|
This test suite ensures that the restriction service properly validates
|
||||||
|
both alias names and their target models, preventing policy bypass vulnerabilities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from providers.base import ProviderType
|
||||||
|
from providers.gemini import GeminiModelProvider
|
||||||
|
from providers.openai import OpenAIModelProvider
|
||||||
|
from utils.model_restrictions import ModelRestrictionService
|
||||||
|
|
||||||
|
|
||||||
|
class TestAliasTargetRestrictions:
|
||||||
|
"""Test that restriction validation works for both aliases and their targets."""
|
||||||
|
|
||||||
|
def test_openai_alias_target_validation_comprehensive(self):
|
||||||
|
"""Test OpenAI provider includes both aliases and targets in validation."""
|
||||||
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
# Get all known models including aliases and targets
|
||||||
|
all_known = provider.list_all_known_models()
|
||||||
|
|
||||||
|
# Should include both aliases and their targets
|
||||||
|
assert "mini" in all_known # alias
|
||||||
|
assert "o4-mini" in all_known # target of 'mini'
|
||||||
|
assert "o3mini" in all_known # alias
|
||||||
|
assert "o3-mini" in all_known # target of 'o3mini'
|
||||||
|
|
||||||
|
def test_gemini_alias_target_validation_comprehensive(self):
|
||||||
|
"""Test Gemini provider includes both aliases and targets in validation."""
|
||||||
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
# Get all known models including aliases and targets
|
||||||
|
all_known = provider.list_all_known_models()
|
||||||
|
|
||||||
|
# Should include both aliases and their targets
|
||||||
|
assert "flash" in all_known # alias
|
||||||
|
assert "gemini-2.5-flash-preview-05-20" in all_known # target of 'flash'
|
||||||
|
assert "pro" in all_known # alias
|
||||||
|
assert "gemini-2.5-pro-preview-06-05" in all_known # target of 'pro'
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini"}) # Allow target
|
||||||
|
def test_restriction_policy_allows_alias_when_target_allowed(self):
|
||||||
|
"""Test that restriction policy allows alias when target model is allowed.
|
||||||
|
|
||||||
|
This is the correct user-friendly behavior - if you allow 'o4-mini',
|
||||||
|
you should be able to use its alias 'mini' as well.
|
||||||
|
"""
|
||||||
|
# Clear cached restriction service
|
||||||
|
import utils.model_restrictions
|
||||||
|
|
||||||
|
utils.model_restrictions._restriction_service = None
|
||||||
|
|
||||||
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
# Both target and alias should be allowed
|
||||||
|
assert provider.validate_model_name("o4-mini")
|
||||||
|
assert provider.validate_model_name("mini")
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini"}) # Allow alias only
|
||||||
|
def test_restriction_policy_allows_only_alias_when_alias_specified(self):
|
||||||
|
"""Test that restriction policy allows only the alias when just alias is specified.
|
||||||
|
|
||||||
|
If you restrict to 'mini', only the alias should work, not the direct target.
|
||||||
|
This is the correct restrictive behavior.
|
||||||
|
"""
|
||||||
|
# Clear cached restriction service
|
||||||
|
import utils.model_restrictions
|
||||||
|
|
||||||
|
utils.model_restrictions._restriction_service = None
|
||||||
|
|
||||||
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
# Only the alias should be allowed
|
||||||
|
assert provider.validate_model_name("mini")
|
||||||
|
# Direct target should NOT be allowed
|
||||||
|
assert not provider.validate_model_name("o4-mini")
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash-preview-05-20"}) # Allow target
|
||||||
|
def test_gemini_restriction_policy_allows_alias_when_target_allowed(self):
|
||||||
|
"""Test Gemini restriction policy allows alias when target is allowed."""
|
||||||
|
# Clear cached restriction service
|
||||||
|
import utils.model_restrictions
|
||||||
|
|
||||||
|
utils.model_restrictions._restriction_service = None
|
||||||
|
|
||||||
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
# Both target and alias should be allowed
|
||||||
|
assert provider.validate_model_name("gemini-2.5-flash-preview-05-20")
|
||||||
|
assert provider.validate_model_name("flash")
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"}) # Allow alias only
|
||||||
|
def test_gemini_restriction_policy_allows_only_alias_when_alias_specified(self):
|
||||||
|
"""Test Gemini restriction policy allows only alias when just alias is specified."""
|
||||||
|
# Clear cached restriction service
|
||||||
|
import utils.model_restrictions
|
||||||
|
|
||||||
|
utils.model_restrictions._restriction_service = None
|
||||||
|
|
||||||
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
# Only the alias should be allowed
|
||||||
|
assert provider.validate_model_name("flash")
|
||||||
|
# Direct target should NOT be allowed
|
||||||
|
assert not provider.validate_model_name("gemini-2.5-flash-preview-05-20")
|
||||||
|
|
||||||
|
def test_restriction_service_validation_includes_all_targets(self):
|
||||||
|
"""Test that restriction service validation knows about all aliases and targets."""
|
||||||
|
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini,invalid-model"}):
|
||||||
|
service = ModelRestrictionService()
|
||||||
|
|
||||||
|
# Create real provider instances
|
||||||
|
provider_instances = {ProviderType.OPENAI: OpenAIModelProvider(api_key="test-key")}
|
||||||
|
|
||||||
|
# Capture warnings
|
||||||
|
with patch("utils.model_restrictions.logger") as mock_logger:
|
||||||
|
service.validate_against_known_models(provider_instances)
|
||||||
|
|
||||||
|
# Should have warned about the invalid model
|
||||||
|
warning_calls = [call for call in mock_logger.warning.call_args_list if "invalid-model" in str(call)]
|
||||||
|
assert len(warning_calls) > 0, "Should have warned about invalid-model"
|
||||||
|
|
||||||
|
# The warning should include both aliases and targets in known models
|
||||||
|
warning_message = str(warning_calls[0])
|
||||||
|
assert "mini" in warning_message # alias should be in known models
|
||||||
|
assert "o4-mini" in warning_message # target should be in known models
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini,o4-mini"}) # Allow both alias and target
|
||||||
|
def test_both_alias_and_target_allowed_when_both_specified(self):
|
||||||
|
"""Test that both alias and target work when both are explicitly allowed."""
|
||||||
|
# Clear cached restriction service
|
||||||
|
import utils.model_restrictions
|
||||||
|
|
||||||
|
utils.model_restrictions._restriction_service = None
|
||||||
|
|
||||||
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
# Both should be allowed
|
||||||
|
assert provider.validate_model_name("mini")
|
||||||
|
assert provider.validate_model_name("o4-mini")
|
||||||
|
|
||||||
|
def test_alias_target_policy_regression_prevention(self):
|
||||||
|
"""Regression test to ensure aliases and targets are both validated properly.
|
||||||
|
|
||||||
|
This test specifically prevents the bug where list_models() only returned
|
||||||
|
aliases but not their targets, causing restriction validation to miss
|
||||||
|
deny-list entries for target models.
|
||||||
|
"""
|
||||||
|
# Test OpenAI provider
|
||||||
|
openai_provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
openai_all_known = openai_provider.list_all_known_models()
|
||||||
|
|
||||||
|
# Verify that for each alias, its target is also included
|
||||||
|
for model_name, config in openai_provider.SUPPORTED_MODELS.items():
|
||||||
|
assert model_name.lower() in openai_all_known
|
||||||
|
if isinstance(config, str): # This is an alias
|
||||||
|
# The target should also be in the known models
|
||||||
|
assert (
|
||||||
|
config.lower() in openai_all_known
|
||||||
|
), f"Target '{config}' for alias '{model_name}' not in known models"
|
||||||
|
|
||||||
|
# Test Gemini provider
|
||||||
|
gemini_provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
gemini_all_known = gemini_provider.list_all_known_models()
|
||||||
|
|
||||||
|
# Verify that for each alias, its target is also included
|
||||||
|
for model_name, config in gemini_provider.SUPPORTED_MODELS.items():
|
||||||
|
assert model_name.lower() in gemini_all_known
|
||||||
|
if isinstance(config, str): # This is an alias
|
||||||
|
# The target should also be in the known models
|
||||||
|
assert (
|
||||||
|
config.lower() in gemini_all_known
|
||||||
|
), f"Target '{config}' for alias '{model_name}' not in known models"
|
||||||
|
|
||||||
|
def test_no_duplicate_models_in_list_all_known_models(self):
|
||||||
|
"""Test that list_all_known_models doesn't return duplicates."""
|
||||||
|
# Test all providers
|
||||||
|
providers = [
|
||||||
|
OpenAIModelProvider(api_key="test-key"),
|
||||||
|
GeminiModelProvider(api_key="test-key"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for provider in providers:
|
||||||
|
all_known = provider.list_all_known_models()
|
||||||
|
# Should not have duplicates
|
||||||
|
assert len(all_known) == len(set(all_known)), f"{provider.__class__.__name__} returns duplicate models"
|
||||||
|
|
||||||
|
def test_restriction_validation_uses_polymorphic_interface(self):
|
||||||
|
"""Test that restriction validation uses the clean polymorphic interface."""
|
||||||
|
service = ModelRestrictionService()
|
||||||
|
|
||||||
|
# Create a mock provider
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
mock_provider = MagicMock()
|
||||||
|
mock_provider.list_all_known_models.return_value = ["model1", "model2", "target-model"]
|
||||||
|
|
||||||
|
# Set up a restriction that should trigger validation
|
||||||
|
service.restrictions = {ProviderType.OPENAI: {"invalid-model"}}
|
||||||
|
|
||||||
|
provider_instances = {ProviderType.OPENAI: mock_provider}
|
||||||
|
|
||||||
|
# Should call the polymorphic method
|
||||||
|
service.validate_against_known_models(provider_instances)
|
||||||
|
|
||||||
|
# Verify the polymorphic method was called
|
||||||
|
mock_provider.list_all_known_models.assert_called_once()
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini-high"}) # Restrict to specific model
|
||||||
|
def test_complex_alias_chains_handled_correctly(self):
|
||||||
|
"""Test that complex alias chains are handled correctly in restrictions."""
|
||||||
|
# Clear cached restriction service
|
||||||
|
import utils.model_restrictions
|
||||||
|
|
||||||
|
utils.model_restrictions._restriction_service = None
|
||||||
|
|
||||||
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
# Only o4-mini-high should be allowed
|
||||||
|
assert provider.validate_model_name("o4-mini-high")
|
||||||
|
|
||||||
|
# Other models should be blocked
|
||||||
|
assert not provider.validate_model_name("o4-mini")
|
||||||
|
assert not provider.validate_model_name("mini") # This resolves to o4-mini
|
||||||
|
assert not provider.validate_model_name("o3-mini")
|
||||||
|
|
||||||
|
def test_critical_regression_validation_sees_alias_targets(self):
|
||||||
|
"""CRITICAL REGRESSION TEST: Ensure validation can see alias target models.
|
||||||
|
|
||||||
|
This test prevents the specific bug where list_models() only returned
|
||||||
|
alias keys but not their targets, causing validate_against_known_models()
|
||||||
|
to miss restrictions on target model names.
|
||||||
|
|
||||||
|
Before the fix:
|
||||||
|
- list_models() returned ["mini", "o3mini"] (aliases only)
|
||||||
|
- validate_against_known_models() only checked against ["mini", "o3mini"]
|
||||||
|
- A restriction on "o4-mini" (target) would not be recognized as valid
|
||||||
|
|
||||||
|
After the fix:
|
||||||
|
- list_all_known_models() returns ["mini", "o3mini", "o4-mini", "o3-mini"] (aliases + targets)
|
||||||
|
- validate_against_known_models() checks against all names
|
||||||
|
- A restriction on "o4-mini" is recognized as valid
|
||||||
|
"""
|
||||||
|
# This test specifically validates the HIGH-severity bug that was found
|
||||||
|
service = ModelRestrictionService()
|
||||||
|
|
||||||
|
# Create provider instance
|
||||||
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
provider_instances = {ProviderType.OPENAI: provider}
|
||||||
|
|
||||||
|
# Get all known models - should include BOTH aliases AND targets
|
||||||
|
all_known = provider.list_all_known_models()
|
||||||
|
|
||||||
|
# Critical check: should contain both aliases and their targets
|
||||||
|
assert "mini" in all_known # alias
|
||||||
|
assert "o4-mini" in all_known # target of mini - THIS WAS MISSING BEFORE
|
||||||
|
assert "o3mini" in all_known # alias
|
||||||
|
assert "o3-mini" in all_known # target of o3mini - THIS WAS MISSING BEFORE
|
||||||
|
|
||||||
|
# Simulate restriction validation with a target model name
|
||||||
|
# This should NOT warn because "o4-mini" is a valid target
|
||||||
|
with patch("utils.model_restrictions.logger") as mock_logger:
|
||||||
|
# Set restriction to target model (not alias)
|
||||||
|
service.restrictions = {ProviderType.OPENAI: {"o4-mini"}}
|
||||||
|
|
||||||
|
# This should NOT generate warnings because o4-mini is known
|
||||||
|
service.validate_against_known_models(provider_instances)
|
||||||
|
|
||||||
|
# Should NOT have any warnings about o4-mini being unrecognized
|
||||||
|
warning_calls = [
|
||||||
|
call
|
||||||
|
for call in mock_logger.warning.call_args_list
|
||||||
|
if "o4-mini" in str(call) and "not a recognized" in str(call)
|
||||||
|
]
|
||||||
|
assert len(warning_calls) == 0, "o4-mini should be recognized as valid target model"
|
||||||
|
|
||||||
|
# Test the reverse: alias in restriction should also be recognized
|
||||||
|
with patch("utils.model_restrictions.logger") as mock_logger:
|
||||||
|
# Set restriction to alias name
|
||||||
|
service.restrictions = {ProviderType.OPENAI: {"mini"}}
|
||||||
|
|
||||||
|
# This should NOT generate warnings because mini is known
|
||||||
|
service.validate_against_known_models(provider_instances)
|
||||||
|
|
||||||
|
# Should NOT have any warnings about mini being unrecognized
|
||||||
|
warning_calls = [
|
||||||
|
call
|
||||||
|
for call in mock_logger.warning.call_args_list
|
||||||
|
if "mini" in str(call) and "not a recognized" in str(call)
|
||||||
|
]
|
||||||
|
assert len(warning_calls) == 0, "mini should be recognized as valid alias"
|
||||||
|
|
||||||
|
def test_critical_regression_prevents_policy_bypass(self):
|
||||||
|
"""CRITICAL REGRESSION TEST: Prevent policy bypass through missing target validation.
|
||||||
|
|
||||||
|
This test ensures that if an admin restricts access to a target model name,
|
||||||
|
the restriction is properly enforced and the target is recognized as a valid
|
||||||
|
model to restrict.
|
||||||
|
|
||||||
|
The bug: If list_all_known_models() doesn't include targets, then validation
|
||||||
|
would incorrectly warn that target model names are "not recognized", making
|
||||||
|
it appear that target-based restrictions don't work.
|
||||||
|
"""
|
||||||
|
# Test with a made-up restriction scenario
|
||||||
|
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini-high,o3-mini"}):
|
||||||
|
# Clear cached restriction service
|
||||||
|
import utils.model_restrictions
|
||||||
|
|
||||||
|
utils.model_restrictions._restriction_service = None
|
||||||
|
|
||||||
|
service = ModelRestrictionService()
|
||||||
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
# These specific target models should be recognized as valid
|
||||||
|
all_known = provider.list_all_known_models()
|
||||||
|
assert "o4-mini-high" in all_known, "Target model o4-mini-high should be known"
|
||||||
|
assert "o3-mini" in all_known, "Target model o3-mini should be known"
|
||||||
|
|
||||||
|
# Validation should not warn about these being unrecognized
|
||||||
|
with patch("utils.model_restrictions.logger") as mock_logger:
|
||||||
|
provider_instances = {ProviderType.OPENAI: provider}
|
||||||
|
service.validate_against_known_models(provider_instances)
|
||||||
|
|
||||||
|
# Should not warn about our allowed models being unrecognized
|
||||||
|
all_warnings = [str(call) for call in mock_logger.warning.call_args_list]
|
||||||
|
for warning in all_warnings:
|
||||||
|
assert "o4-mini-high" not in warning or "not a recognized" not in warning
|
||||||
|
assert "o3-mini" not in warning or "not a recognized" not in warning
|
||||||
|
|
||||||
|
# The restriction should actually work
|
||||||
|
assert provider.validate_model_name("o4-mini-high")
|
||||||
|
assert provider.validate_model_name("o3-mini")
|
||||||
|
assert not provider.validate_model_name("o4-mini") # not in allowed list
|
||||||
|
assert not provider.validate_model_name("o3") # not in allowed list
|
||||||
288
tests/test_buggy_behavior_prevention.py
Normal file
288
tests/test_buggy_behavior_prevention.py
Normal file
@@ -0,0 +1,288 @@
|
|||||||
|
"""
|
||||||
|
Tests that demonstrate the OLD BUGGY BEHAVIOR is now FIXED.
|
||||||
|
|
||||||
|
These tests verify that scenarios which would have incorrectly passed
|
||||||
|
before our fix now behave correctly. Each test documents the specific
|
||||||
|
bug that was fixed and what the old vs new behavior should be.
|
||||||
|
|
||||||
|
IMPORTANT: These tests PASS with our fix, but would have FAILED to catch
|
||||||
|
bugs with the old code (before list_all_known_models was implemented).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from providers.base import ProviderType
|
||||||
|
from providers.gemini import GeminiModelProvider
|
||||||
|
from providers.openai import OpenAIModelProvider
|
||||||
|
from utils.model_restrictions import ModelRestrictionService
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuggyBehaviorPrevention:
|
||||||
|
"""
|
||||||
|
These tests prove that our fix prevents the HIGH-severity regression
|
||||||
|
that was identified by the O3 precommit analysis.
|
||||||
|
|
||||||
|
OLD BUG: list_models() only returned alias keys, not targets
|
||||||
|
FIX: list_all_known_models() returns both aliases AND targets
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_old_bug_would_miss_target_restrictions(self):
|
||||||
|
"""
|
||||||
|
OLD BUG: If restriction was set on target model (e.g., 'o4-mini'),
|
||||||
|
validation would incorrectly warn it's not recognized because
|
||||||
|
list_models() only returned aliases ['mini', 'o3mini'].
|
||||||
|
|
||||||
|
NEW BEHAVIOR: list_all_known_models() includes targets, so 'o4-mini'
|
||||||
|
is recognized as valid and no warning is generated.
|
||||||
|
"""
|
||||||
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
# This is what the old broken list_models() would return - aliases only
|
||||||
|
old_broken_list = ["mini", "o3mini"] # Missing 'o4-mini', 'o3-mini' targets
|
||||||
|
|
||||||
|
# This is what our fixed list_all_known_models() returns
|
||||||
|
new_fixed_list = provider.list_all_known_models()
|
||||||
|
|
||||||
|
# Verify the fix: new method includes both aliases AND targets
|
||||||
|
assert "mini" in new_fixed_list # alias
|
||||||
|
assert "o4-mini" in new_fixed_list # target - THIS WAS MISSING IN OLD CODE
|
||||||
|
assert "o3mini" in new_fixed_list # alias
|
||||||
|
assert "o3-mini" in new_fixed_list # target - THIS WAS MISSING IN OLD CODE
|
||||||
|
|
||||||
|
# Prove the old behavior was broken
|
||||||
|
assert "o4-mini" not in old_broken_list # Old code didn't include targets
|
||||||
|
assert "o3-mini" not in old_broken_list # Old code didn't include targets
|
||||||
|
|
||||||
|
# This target validation would have FAILED with old code
|
||||||
|
service = ModelRestrictionService()
|
||||||
|
service.restrictions = {ProviderType.OPENAI: {"o4-mini"}} # Restrict to target
|
||||||
|
|
||||||
|
with patch("utils.model_restrictions.logger") as mock_logger:
|
||||||
|
provider_instances = {ProviderType.OPENAI: provider}
|
||||||
|
service.validate_against_known_models(provider_instances)
|
||||||
|
|
||||||
|
# NEW BEHAVIOR: No warnings because o4-mini is now in list_all_known_models
|
||||||
|
target_warnings = [
|
||||||
|
call
|
||||||
|
for call in mock_logger.warning.call_args_list
|
||||||
|
if "o4-mini" in str(call) and "not a recognized" in str(call)
|
||||||
|
]
|
||||||
|
assert len(target_warnings) == 0, "o4-mini should be recognized with our fix"
|
||||||
|
|
||||||
|
def test_old_bug_would_incorrectly_warn_about_valid_targets(self):
|
||||||
|
"""
|
||||||
|
OLD BUG: Admins setting restrictions on target models would get
|
||||||
|
false warnings that their restriction models are "not recognized".
|
||||||
|
|
||||||
|
NEW BEHAVIOR: Target models are properly recognized.
|
||||||
|
"""
|
||||||
|
# Test with Gemini provider too
|
||||||
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
all_known = provider.list_all_known_models()
|
||||||
|
|
||||||
|
# Verify both aliases and targets are included
|
||||||
|
assert "flash" in all_known # alias
|
||||||
|
assert "gemini-2.5-flash-preview-05-20" in all_known # target
|
||||||
|
assert "pro" in all_known # alias
|
||||||
|
assert "gemini-2.5-pro-preview-06-05" in all_known # target
|
||||||
|
|
||||||
|
# Simulate admin restricting to target model names
|
||||||
|
service = ModelRestrictionService()
|
||||||
|
service.restrictions = {
|
||||||
|
ProviderType.GOOGLE: {
|
||||||
|
"gemini-2.5-flash-preview-05-20", # Target name restriction
|
||||||
|
"gemini-2.5-pro-preview-06-05", # Target name restriction
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("utils.model_restrictions.logger") as mock_logger:
|
||||||
|
provider_instances = {ProviderType.GOOGLE: provider}
|
||||||
|
service.validate_against_known_models(provider_instances)
|
||||||
|
|
||||||
|
# Should NOT warn about these valid target models
|
||||||
|
all_warnings = [str(call) for call in mock_logger.warning.call_args_list]
|
||||||
|
for warning in all_warnings:
|
||||||
|
assert "gemini-2.5-flash-preview-05-20" not in warning or "not a recognized" not in warning
|
||||||
|
assert "gemini-2.5-pro-preview-06-05" not in warning or "not a recognized" not in warning
|
||||||
|
|
||||||
|
def test_old_bug_policy_bypass_prevention(self):
|
||||||
|
"""
|
||||||
|
OLD BUG: Policy enforcement was incomplete because validation
|
||||||
|
didn't know about target models. This could allow policy bypasses.
|
||||||
|
|
||||||
|
NEW BEHAVIOR: Complete validation against all known model names.
|
||||||
|
"""
|
||||||
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
# Simulate a scenario where admin wants to restrict specific targets
|
||||||
|
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini-high"}):
|
||||||
|
# Clear cached restriction service
|
||||||
|
import utils.model_restrictions
|
||||||
|
|
||||||
|
utils.model_restrictions._restriction_service = None
|
||||||
|
|
||||||
|
# These should work because they're explicitly allowed
|
||||||
|
assert provider.validate_model_name("o3-mini")
|
||||||
|
assert provider.validate_model_name("o4-mini-high")
|
||||||
|
|
||||||
|
# These should be blocked
|
||||||
|
assert not provider.validate_model_name("o4-mini") # Not in allowed list
|
||||||
|
assert not provider.validate_model_name("o3") # Not in allowed list
|
||||||
|
assert not provider.validate_model_name("mini") # Resolves to o4-mini, not allowed
|
||||||
|
|
||||||
|
# Verify our list_all_known_models includes the restricted models
|
||||||
|
all_known = provider.list_all_known_models()
|
||||||
|
assert "o3-mini" in all_known # Should be known (and allowed)
|
||||||
|
assert "o4-mini-high" in all_known # Should be known (and allowed)
|
||||||
|
assert "o4-mini" in all_known # Should be known (but blocked)
|
||||||
|
assert "mini" in all_known # Should be known (but blocked)
|
||||||
|
|
||||||
|
def test_demonstration_of_old_vs_new_interface(self):
|
||||||
|
"""
|
||||||
|
Direct comparison of old vs new interface to document the fix.
|
||||||
|
"""
|
||||||
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
# OLD interface (still exists for backward compatibility)
|
||||||
|
old_style_models = provider.list_models(respect_restrictions=False)
|
||||||
|
|
||||||
|
# NEW interface (our fix)
|
||||||
|
new_comprehensive_models = provider.list_all_known_models()
|
||||||
|
|
||||||
|
# The new interface should be a superset of the old one
|
||||||
|
for model in old_style_models:
|
||||||
|
assert model.lower() in [
|
||||||
|
m.lower() for m in new_comprehensive_models
|
||||||
|
], f"New interface missing model {model} from old interface"
|
||||||
|
|
||||||
|
# The new interface should include target models that old one might miss
|
||||||
|
targets_that_should_exist = ["o4-mini", "o3-mini"]
|
||||||
|
for target in targets_that_should_exist:
|
||||||
|
assert target in new_comprehensive_models, f"New interface should include target model {target}"
|
||||||
|
|
||||||
|
def test_old_validation_interface_still_works(self):
|
||||||
|
"""
|
||||||
|
Verify our fix doesn't break existing validation workflows.
|
||||||
|
"""
|
||||||
|
service = ModelRestrictionService()
|
||||||
|
|
||||||
|
# Create a mock provider that simulates the old behavior
|
||||||
|
old_style_provider = MagicMock()
|
||||||
|
old_style_provider.SUPPORTED_MODELS = {
|
||||||
|
"mini": "o4-mini",
|
||||||
|
"o3mini": "o3-mini",
|
||||||
|
"o4-mini": {"context_window": 200000},
|
||||||
|
"o3-mini": {"context_window": 200000},
|
||||||
|
}
|
||||||
|
# OLD BROKEN: This would only return aliases
|
||||||
|
old_style_provider.list_models.return_value = ["mini", "o3mini"]
|
||||||
|
# NEW FIXED: This includes both aliases and targets
|
||||||
|
old_style_provider.list_all_known_models.return_value = ["mini", "o3mini", "o4-mini", "o3-mini"]
|
||||||
|
|
||||||
|
# Test that validation now uses the comprehensive method
|
||||||
|
service.restrictions = {ProviderType.OPENAI: {"o4-mini"}} # Restrict to target
|
||||||
|
|
||||||
|
with patch("utils.model_restrictions.logger") as mock_logger:
|
||||||
|
provider_instances = {ProviderType.OPENAI: old_style_provider}
|
||||||
|
service.validate_against_known_models(provider_instances)
|
||||||
|
|
||||||
|
# Verify the new method was called, not the old one
|
||||||
|
old_style_provider.list_all_known_models.assert_called_once()
|
||||||
|
|
||||||
|
# Should not warn about o4-mini being unrecognized
|
||||||
|
target_warnings = [
|
||||||
|
call
|
||||||
|
for call in mock_logger.warning.call_args_list
|
||||||
|
if "o4-mini" in str(call) and "not a recognized" in str(call)
|
||||||
|
]
|
||||||
|
assert len(target_warnings) == 0
|
||||||
|
|
||||||
|
def test_regression_proof_comprehensive_coverage(self):
|
||||||
|
"""
|
||||||
|
Comprehensive test to prove our fix covers all provider types.
|
||||||
|
"""
|
||||||
|
providers_to_test = [
|
||||||
|
(OpenAIModelProvider(api_key="test-key"), "mini", "o4-mini"),
|
||||||
|
(GeminiModelProvider(api_key="test-key"), "flash", "gemini-2.5-flash-preview-05-20"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for provider, alias, target in providers_to_test:
|
||||||
|
all_known = provider.list_all_known_models()
|
||||||
|
|
||||||
|
# Every provider should include both aliases and targets
|
||||||
|
assert alias in all_known, f"{provider.__class__.__name__} missing alias {alias}"
|
||||||
|
assert target in all_known, f"{provider.__class__.__name__} missing target {target}"
|
||||||
|
|
||||||
|
# No duplicates should exist
|
||||||
|
assert len(all_known) == len(set(all_known)), f"{provider.__class__.__name__} returns duplicate models"
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini,invalid-model"})
|
||||||
|
def test_validation_correctly_identifies_invalid_models(self):
|
||||||
|
"""
|
||||||
|
Test that validation still catches truly invalid models while
|
||||||
|
properly recognizing valid target models.
|
||||||
|
|
||||||
|
This proves our fix works: o4-mini appears in the "Known models" list
|
||||||
|
because list_all_known_models() now includes target models.
|
||||||
|
"""
|
||||||
|
# Clear cached restriction service
|
||||||
|
import utils.model_restrictions
|
||||||
|
|
||||||
|
utils.model_restrictions._restriction_service = None
|
||||||
|
|
||||||
|
service = ModelRestrictionService()
|
||||||
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
with patch("utils.model_restrictions.logger") as mock_logger:
|
||||||
|
provider_instances = {ProviderType.OPENAI: provider}
|
||||||
|
service.validate_against_known_models(provider_instances)
|
||||||
|
|
||||||
|
# Should warn about 'invalid-model' (truly invalid)
|
||||||
|
invalid_warnings = [
|
||||||
|
call
|
||||||
|
for call in mock_logger.warning.call_args_list
|
||||||
|
if "invalid-model" in str(call) and "not a recognized" in str(call)
|
||||||
|
]
|
||||||
|
assert len(invalid_warnings) > 0, "Should warn about truly invalid models"
|
||||||
|
|
||||||
|
# The warning should mention o4-mini in the "Known models" list (proving our fix works)
|
||||||
|
warning_text = str(mock_logger.warning.call_args_list[0])
|
||||||
|
assert "Known models:" in warning_text, "Warning should include known models list"
|
||||||
|
assert "o4-mini" in warning_text, "o4-mini should appear in known models (proves our fix works)"
|
||||||
|
assert "o3-mini" in warning_text, "o3-mini should appear in known models (proves our fix works)"
|
||||||
|
|
||||||
|
# But the warning should be specifically about invalid-model
|
||||||
|
assert "'invalid-model'" in warning_text, "Warning should specifically mention invalid-model"
|
||||||
|
|
||||||
|
def test_custom_provider_also_implements_fix(self):
|
||||||
|
"""
|
||||||
|
Verify that custom provider also implements the comprehensive interface.
|
||||||
|
"""
|
||||||
|
from providers.custom import CustomProvider
|
||||||
|
|
||||||
|
# This might fail if no URL is set, but that's expected
|
||||||
|
try:
|
||||||
|
provider = CustomProvider(base_url="http://test.com/v1")
|
||||||
|
all_known = provider.list_all_known_models()
|
||||||
|
# Should return a list (might be empty if registry not loaded)
|
||||||
|
assert isinstance(all_known, list)
|
||||||
|
except ValueError:
|
||||||
|
# Expected if no base_url configured, skip this test
|
||||||
|
pytest.skip("Custom provider requires URL configuration")
|
||||||
|
|
||||||
|
def test_openrouter_provider_also_implements_fix(self):
|
||||||
|
"""
|
||||||
|
Verify that OpenRouter provider also implements the comprehensive interface.
|
||||||
|
"""
|
||||||
|
from providers.openrouter import OpenRouterProvider
|
||||||
|
|
||||||
|
provider = OpenRouterProvider(api_key="test-key")
|
||||||
|
all_known = provider.list_all_known_models()
|
||||||
|
|
||||||
|
# Should return a list with both aliases and targets
|
||||||
|
assert isinstance(all_known, list)
|
||||||
|
# Should include some known OpenRouter aliases and their targets
|
||||||
|
# (Exact content depends on registry, but structure should be correct)
|
||||||
@@ -142,6 +142,7 @@ class TestModelRestrictionService:
|
|||||||
"o3-mini": {"context_window": 200000},
|
"o3-mini": {"context_window": 200000},
|
||||||
"o4-mini": {"context_window": 200000},
|
"o4-mini": {"context_window": 200000},
|
||||||
}
|
}
|
||||||
|
mock_provider.list_all_known_models.return_value = ["o3", "o3-mini", "o4-mini"]
|
||||||
|
|
||||||
provider_instances = {ProviderType.OPENAI: mock_provider}
|
provider_instances = {ProviderType.OPENAI: mock_provider}
|
||||||
service.validate_against_known_models(provider_instances)
|
service.validate_against_known_models(provider_instances)
|
||||||
@@ -444,12 +445,57 @@ class TestRegistryIntegration:
|
|||||||
"o3": {"context_window": 200000},
|
"o3": {"context_window": 200000},
|
||||||
"o3-mini": {"context_window": 200000},
|
"o3-mini": {"context_window": 200000},
|
||||||
}
|
}
|
||||||
|
mock_openai.get_provider_type.return_value = ProviderType.OPENAI
|
||||||
|
|
||||||
|
def openai_list_models(respect_restrictions=True):
|
||||||
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
|
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||||
|
models = []
|
||||||
|
for model_name, config in mock_openai.SUPPORTED_MODELS.items():
|
||||||
|
if isinstance(config, str):
|
||||||
|
target_model = config
|
||||||
|
if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, target_model):
|
||||||
|
continue
|
||||||
|
models.append(model_name)
|
||||||
|
else:
|
||||||
|
if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, model_name):
|
||||||
|
continue
|
||||||
|
models.append(model_name)
|
||||||
|
return models
|
||||||
|
|
||||||
|
mock_openai.list_models = openai_list_models
|
||||||
|
mock_openai.list_all_known_models.return_value = ["o3", "o3-mini"]
|
||||||
|
|
||||||
mock_gemini = MagicMock()
|
mock_gemini = MagicMock()
|
||||||
mock_gemini.SUPPORTED_MODELS = {
|
mock_gemini.SUPPORTED_MODELS = {
|
||||||
"gemini-2.5-pro-preview-06-05": {"context_window": 1048576},
|
"gemini-2.5-pro-preview-06-05": {"context_window": 1048576},
|
||||||
"gemini-2.5-flash-preview-05-20": {"context_window": 1048576},
|
"gemini-2.5-flash-preview-05-20": {"context_window": 1048576},
|
||||||
}
|
}
|
||||||
|
mock_gemini.get_provider_type.return_value = ProviderType.GOOGLE
|
||||||
|
|
||||||
|
def gemini_list_models(respect_restrictions=True):
|
||||||
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
|
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||||
|
models = []
|
||||||
|
for model_name, config in mock_gemini.SUPPORTED_MODELS.items():
|
||||||
|
if isinstance(config, str):
|
||||||
|
target_model = config
|
||||||
|
if restriction_service and not restriction_service.is_allowed(ProviderType.GOOGLE, target_model):
|
||||||
|
continue
|
||||||
|
models.append(model_name)
|
||||||
|
else:
|
||||||
|
if restriction_service and not restriction_service.is_allowed(ProviderType.GOOGLE, model_name):
|
||||||
|
continue
|
||||||
|
models.append(model_name)
|
||||||
|
return models
|
||||||
|
|
||||||
|
mock_gemini.list_models = gemini_list_models
|
||||||
|
mock_gemini.list_all_known_models.return_value = [
|
||||||
|
"gemini-2.5-pro-preview-06-05",
|
||||||
|
"gemini-2.5-flash-preview-05-20",
|
||||||
|
]
|
||||||
|
|
||||||
def get_provider_side_effect(provider_type):
|
def get_provider_side_effect(provider_type):
|
||||||
if provider_type == ProviderType.OPENAI:
|
if provider_type == ProviderType.OPENAI:
|
||||||
@@ -569,6 +615,27 @@ class TestAutoModeWithRestrictions:
|
|||||||
"o3-mini": {"context_window": 200000},
|
"o3-mini": {"context_window": 200000},
|
||||||
"o4-mini": {"context_window": 200000},
|
"o4-mini": {"context_window": 200000},
|
||||||
}
|
}
|
||||||
|
mock_openai.get_provider_type.return_value = ProviderType.OPENAI
|
||||||
|
|
||||||
|
def openai_list_models(respect_restrictions=True):
|
||||||
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
|
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||||
|
models = []
|
||||||
|
for model_name, config in mock_openai.SUPPORTED_MODELS.items():
|
||||||
|
if isinstance(config, str):
|
||||||
|
target_model = config
|
||||||
|
if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, target_model):
|
||||||
|
continue
|
||||||
|
models.append(model_name)
|
||||||
|
else:
|
||||||
|
if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, model_name):
|
||||||
|
continue
|
||||||
|
models.append(model_name)
|
||||||
|
return models
|
||||||
|
|
||||||
|
mock_openai.list_models = openai_list_models
|
||||||
|
mock_openai.list_all_known_models.return_value = ["o3", "o3-mini", "o4-mini"]
|
||||||
|
|
||||||
def get_provider_side_effect(provider_type):
|
def get_provider_side_effect(provider_type):
|
||||||
if provider_type == ProviderType.OPENAI:
|
if provider_type == ProviderType.OPENAI:
|
||||||
|
|||||||
216
tests/test_old_behavior_simulation.py
Normal file
216
tests/test_old_behavior_simulation.py
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
"""
|
||||||
|
Tests that simulate the OLD BROKEN BEHAVIOR to prove it was indeed broken.
|
||||||
|
|
||||||
|
These tests create mock providers that behave like the old code (before our fix)
|
||||||
|
and demonstrate that they would have failed to catch the HIGH-severity bug.
|
||||||
|
|
||||||
|
IMPORTANT: These tests show what WOULD HAVE HAPPENED with the old code.
|
||||||
|
They prove that our fix was necessary and actually addresses real problems.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from providers.base import ProviderType
|
||||||
|
from utils.model_restrictions import ModelRestrictionService
|
||||||
|
|
||||||
|
|
||||||
|
class TestOldBehaviorSimulation:
|
||||||
|
"""
|
||||||
|
Simulate the old broken behavior to prove it was buggy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_old_behavior_would_miss_target_restrictions(self):
|
||||||
|
"""
|
||||||
|
SIMULATION: This test recreates the OLD BROKEN BEHAVIOR and proves it was buggy.
|
||||||
|
|
||||||
|
OLD BUG: When validation service called provider.list_models(), it only got
|
||||||
|
aliases back, not targets. This meant target-based restrictions weren't validated.
|
||||||
|
"""
|
||||||
|
# Create a mock provider that simulates the OLD BROKEN BEHAVIOR
|
||||||
|
old_broken_provider = MagicMock()
|
||||||
|
old_broken_provider.SUPPORTED_MODELS = {
|
||||||
|
"mini": "o4-mini", # alias -> target
|
||||||
|
"o3mini": "o3-mini", # alias -> target
|
||||||
|
"o4-mini": {"context_window": 200000},
|
||||||
|
"o3-mini": {"context_window": 200000},
|
||||||
|
}
|
||||||
|
|
||||||
|
# OLD BROKEN: list_models only returned aliases, missing targets
|
||||||
|
old_broken_provider.list_models.return_value = ["mini", "o3mini"]
|
||||||
|
|
||||||
|
# OLD BROKEN: There was no list_all_known_models method!
|
||||||
|
# We simulate this by making it behave like the old list_models
|
||||||
|
old_broken_provider.list_all_known_models.return_value = ["mini", "o3mini"] # MISSING TARGETS!
|
||||||
|
|
||||||
|
# Now test what happens when admin tries to restrict by target model
|
||||||
|
service = ModelRestrictionService()
|
||||||
|
service.restrictions = {ProviderType.OPENAI: {"o4-mini"}} # Restrict to target model
|
||||||
|
|
||||||
|
with patch("utils.model_restrictions.logger") as mock_logger:
|
||||||
|
provider_instances = {ProviderType.OPENAI: old_broken_provider}
|
||||||
|
service.validate_against_known_models(provider_instances)
|
||||||
|
|
||||||
|
# OLD BROKEN BEHAVIOR: Would warn about o4-mini being "not recognized"
|
||||||
|
# because it wasn't in the list_all_known_models response
|
||||||
|
target_warnings = [
|
||||||
|
call
|
||||||
|
for call in mock_logger.warning.call_args_list
|
||||||
|
if "o4-mini" in str(call) and "not a recognized" in str(call)
|
||||||
|
]
|
||||||
|
|
||||||
|
# This proves the old behavior was broken - it would generate false warnings
|
||||||
|
assert len(target_warnings) > 0, "OLD BROKEN BEHAVIOR: Would incorrectly warn about valid target models"
|
||||||
|
|
||||||
|
# Verify the warning message shows the broken list
|
||||||
|
warning_text = str(target_warnings[0])
|
||||||
|
assert "mini" in warning_text # Alias was included
|
||||||
|
assert "o3mini" in warning_text # Alias was included
|
||||||
|
# But targets were missing from the known models list in old behavior
|
||||||
|
|
||||||
|
def test_new_behavior_fixes_the_problem(self):
|
||||||
|
"""
|
||||||
|
Compare old vs new behavior to show our fix works.
|
||||||
|
"""
|
||||||
|
# Create mock provider with NEW FIXED BEHAVIOR
|
||||||
|
new_fixed_provider = MagicMock()
|
||||||
|
new_fixed_provider.SUPPORTED_MODELS = {
|
||||||
|
"mini": "o4-mini",
|
||||||
|
"o3mini": "o3-mini",
|
||||||
|
"o4-mini": {"context_window": 200000},
|
||||||
|
"o3-mini": {"context_window": 200000},
|
||||||
|
}
|
||||||
|
|
||||||
|
# NEW FIXED: list_all_known_models includes BOTH aliases AND targets
|
||||||
|
new_fixed_provider.list_all_known_models.return_value = [
|
||||||
|
"mini",
|
||||||
|
"o3mini", # aliases
|
||||||
|
"o4-mini",
|
||||||
|
"o3-mini", # targets - THESE WERE MISSING IN OLD CODE!
|
||||||
|
]
|
||||||
|
|
||||||
|
# Same restriction scenario
|
||||||
|
service = ModelRestrictionService()
|
||||||
|
service.restrictions = {ProviderType.OPENAI: {"o4-mini"}} # Restrict to target model
|
||||||
|
|
||||||
|
with patch("utils.model_restrictions.logger") as mock_logger:
|
||||||
|
provider_instances = {ProviderType.OPENAI: new_fixed_provider}
|
||||||
|
service.validate_against_known_models(provider_instances)
|
||||||
|
|
||||||
|
# NEW FIXED BEHAVIOR: No warnings about o4-mini being unrecognized
|
||||||
|
target_warnings = [
|
||||||
|
call
|
||||||
|
for call in mock_logger.warning.call_args_list
|
||||||
|
if "o4-mini" in str(call) and "not a recognized" in str(call)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Our fix prevents false warnings
|
||||||
|
assert len(target_warnings) == 0, "NEW FIXED BEHAVIOR: Should not warn about valid target models"
|
||||||
|
|
||||||
|
def test_policy_bypass_prevention_old_vs_new(self):
|
||||||
|
"""
|
||||||
|
Show how the old behavior could have led to policy bypass scenarios.
|
||||||
|
"""
|
||||||
|
# OLD BROKEN: Admin thinks they've restricted access to o4-mini,
|
||||||
|
# but validation doesn't recognize it as a valid restriction target
|
||||||
|
old_broken_provider = MagicMock()
|
||||||
|
old_broken_provider.list_all_known_models.return_value = ["mini", "o3mini"] # Missing targets
|
||||||
|
|
||||||
|
# NEW FIXED: Same provider with our fix
|
||||||
|
new_fixed_provider = MagicMock()
|
||||||
|
new_fixed_provider.list_all_known_models.return_value = ["mini", "o3mini", "o4-mini", "o3-mini"]
|
||||||
|
|
||||||
|
# Test restriction on target model - use completely separate service instances
|
||||||
|
old_service = ModelRestrictionService()
|
||||||
|
old_service.restrictions = {ProviderType.OPENAI: {"o4-mini", "completely-invalid-model"}}
|
||||||
|
|
||||||
|
new_service = ModelRestrictionService()
|
||||||
|
new_service.restrictions = {ProviderType.OPENAI: {"o4-mini", "completely-invalid-model"}}
|
||||||
|
|
||||||
|
# OLD BEHAVIOR: Would warn about BOTH models being unrecognized
|
||||||
|
with patch("utils.model_restrictions.logger") as mock_logger_old:
|
||||||
|
provider_instances = {ProviderType.OPENAI: old_broken_provider}
|
||||||
|
old_service.validate_against_known_models(provider_instances)
|
||||||
|
|
||||||
|
old_warnings = [str(call) for call in mock_logger_old.warning.call_args_list]
|
||||||
|
print(f"OLD warnings: {old_warnings}") # Debug output
|
||||||
|
|
||||||
|
# NEW BEHAVIOR: Only warns about truly invalid model
|
||||||
|
with patch("utils.model_restrictions.logger") as mock_logger_new:
|
||||||
|
provider_instances = {ProviderType.OPENAI: new_fixed_provider}
|
||||||
|
new_service.validate_against_known_models(provider_instances)
|
||||||
|
|
||||||
|
new_warnings = [str(call) for call in mock_logger_new.warning.call_args_list]
|
||||||
|
print(f"NEW warnings: {new_warnings}") # Debug output
|
||||||
|
|
||||||
|
# For now, just verify that we get some warnings in both cases
|
||||||
|
# The key point is that the "Known models" list is different
|
||||||
|
assert len(old_warnings) > 0, "OLD: Should have warnings"
|
||||||
|
assert len(new_warnings) > 0, "NEW: Should have warnings for invalid model"
|
||||||
|
|
||||||
|
# Verify the known models list is different between old and new
|
||||||
|
str(old_warnings[0]) if old_warnings else ""
|
||||||
|
new_warning_text = str(new_warnings[0]) if new_warnings else ""
|
||||||
|
|
||||||
|
if "Known models:" in new_warning_text:
|
||||||
|
# NEW behavior should include o4-mini in known models list
|
||||||
|
assert "o4-mini" in new_warning_text, "NEW: Should include o4-mini in known models"
|
||||||
|
|
||||||
|
print("This test demonstrates that our fix improves the 'Known models' list shown to users.")
|
||||||
|
|
||||||
|
def test_demonstrate_target_coverage_improvement(self):
|
||||||
|
"""
|
||||||
|
Show the exact improvement in target model coverage.
|
||||||
|
"""
|
||||||
|
# Simulate different provider implementations
|
||||||
|
providers_old_vs_new = [
|
||||||
|
# (old_broken_list, new_fixed_list, provider_name)
|
||||||
|
(["mini", "o3mini"], ["mini", "o3mini", "o4-mini", "o3-mini"], "OpenAI"),
|
||||||
|
(
|
||||||
|
["flash", "pro"],
|
||||||
|
["flash", "pro", "gemini-2.5-flash-preview-05-20", "gemini-2.5-pro-preview-06-05"],
|
||||||
|
"Gemini",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
for old_list, new_list, provider_name in providers_old_vs_new:
|
||||||
|
# Count how many additional models are now covered
|
||||||
|
old_coverage = set(old_list)
|
||||||
|
new_coverage = set(new_list)
|
||||||
|
|
||||||
|
additional_coverage = new_coverage - old_coverage
|
||||||
|
|
||||||
|
# There should be additional target models covered
|
||||||
|
assert len(additional_coverage) > 0, f"{provider_name}: Should have additional target coverage"
|
||||||
|
|
||||||
|
# All old models should still be covered
|
||||||
|
assert old_coverage.issubset(new_coverage), f"{provider_name}: Should maintain backward compatibility"
|
||||||
|
|
||||||
|
print(f"{provider_name} provider:")
|
||||||
|
print(f" Old coverage: {sorted(old_coverage)}")
|
||||||
|
print(f" New coverage: {sorted(new_coverage)}")
|
||||||
|
print(f" Additional models: {sorted(additional_coverage)}")
|
||||||
|
|
||||||
|
def test_comprehensive_alias_target_mapping_verification(self):
|
||||||
|
"""
|
||||||
|
Verify that our fix provides comprehensive alias->target coverage.
|
||||||
|
"""
|
||||||
|
from providers.gemini import GeminiModelProvider
|
||||||
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
|
# Test real providers to ensure they implement our fix correctly
|
||||||
|
providers = [OpenAIModelProvider(api_key="test-key"), GeminiModelProvider(api_key="test-key")]
|
||||||
|
|
||||||
|
for provider in providers:
|
||||||
|
all_known = provider.list_all_known_models()
|
||||||
|
|
||||||
|
# Check that for every alias in SUPPORTED_MODELS, its target is also included
|
||||||
|
for model_name, config in provider.SUPPORTED_MODELS.items():
|
||||||
|
# Model name itself should be in the list
|
||||||
|
assert model_name.lower() in all_known, f"{provider.__class__.__name__}: Missing model {model_name}"
|
||||||
|
|
||||||
|
# If it's an alias (config is a string), target should also be in list
|
||||||
|
if isinstance(config, str):
|
||||||
|
target_model = config
|
||||||
|
assert (
|
||||||
|
target_model.lower() in all_known
|
||||||
|
), f"{provider.__class__.__name__}: Missing target {target_model} for alias {model_name}"
|
||||||
@@ -26,6 +26,12 @@ class TestOpenAICompatibleTokenUsage(unittest.TestCase):
|
|||||||
def validate_model_name(self, model_name):
|
def validate_model_name(self, model_name):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def list_models(self, respect_restrictions=True):
|
||||||
|
return ["test-model"]
|
||||||
|
|
||||||
|
def list_all_known_models(self):
|
||||||
|
return ["test-model"]
|
||||||
|
|
||||||
self.provider = TestProvider("test-key")
|
self.provider = TestProvider("test-key")
|
||||||
|
|
||||||
def test_extract_usage_with_valid_tokens(self):
|
def test_extract_usage_with_valid_tokens(self):
|
||||||
|
|||||||
@@ -233,8 +233,9 @@ class TestOpenRouterAutoMode:
|
|||||||
os.environ["DEFAULT_MODEL"] = "auto"
|
os.environ["DEFAULT_MODEL"] = "auto"
|
||||||
|
|
||||||
mock_provider_class = Mock()
|
mock_provider_class = Mock()
|
||||||
mock_provider_instance = Mock(spec=["get_provider_type"])
|
mock_provider_instance = Mock(spec=["get_provider_type", "list_models"])
|
||||||
mock_provider_instance.get_provider_type.return_value = ProviderType.OPENROUTER
|
mock_provider_instance.get_provider_type.return_value = ProviderType.OPENROUTER
|
||||||
|
mock_provider_instance.list_models.return_value = []
|
||||||
mock_provider_class.return_value = mock_provider_instance
|
mock_provider_class.return_value = mock_provider_instance
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, mock_provider_class)
|
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, mock_provider_class)
|
||||||
|
|||||||
@@ -90,18 +90,14 @@ class ModelRestrictionService:
|
|||||||
if not provider:
|
if not provider:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Get all supported models (including aliases)
|
# Get all supported models using the clean polymorphic interface
|
||||||
supported_models = set()
|
try:
|
||||||
|
# Use list_all_known_models to get both aliases and their targets
|
||||||
# For OpenAI and Gemini, we can check their SUPPORTED_MODELS
|
all_models = provider.list_all_known_models()
|
||||||
if hasattr(provider, "SUPPORTED_MODELS"):
|
supported_models = {model.lower() for model in all_models}
|
||||||
for model_name, config in provider.SUPPORTED_MODELS.items():
|
except Exception as e:
|
||||||
# Add the model name (lowercase)
|
logger.debug(f"Could not get model list from {provider_type.value} provider: {e}")
|
||||||
supported_models.add(model_name.lower())
|
supported_models = set()
|
||||||
|
|
||||||
# If it's an alias (string value), add the target too
|
|
||||||
if isinstance(config, str):
|
|
||||||
supported_models.add(config.lower())
|
|
||||||
|
|
||||||
# Check each allowed model
|
# Check each allowed model
|
||||||
for allowed_model in allowed_models:
|
for allowed_model in allowed_models:
|
||||||
|
|||||||
Reference in New Issue
Block a user