"""
|
|
Tests that simulate the OLD BROKEN BEHAVIOR to prove it was indeed broken.
|
|
|
|
These tests create mock providers that behave like the old code (before our fix)
|
|
and demonstrate that they would have failed to catch the HIGH-severity bug.
|
|
|
|
IMPORTANT: These tests show what WOULD HAVE HAPPENED with the old code.
|
|
They prove that our fix was necessary and actually addresses real problems.
|
|
"""
|
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from providers.base import ProviderType
|
|
from utils.model_restrictions import ModelRestrictionService
|
|
|
|
|
|
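
# For reference, a minimal sketch of what the FIXED list_all_known_models is
# assumed to look like (a hypothetical reconstruction for illustration only;
# the real method lives on the provider classes): walk SUPPORTED_MODELS and
# collect every key (aliases and targets alike) plus, for alias entries whose
# value is a string, the target model that the alias points to.
def _sketch_list_all_known_models(supported_models: dict) -> list[str]:
    known: set[str] = set()
    for name, config in supported_models.items():
        known.add(name.lower())  # the key itself: an alias or a target name
        if isinstance(config, str):  # alias entry: the value names its target
            known.add(config.lower())  # the old code missed these targets
    return sorted(known)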


class TestOldBehaviorSimulation:
    """
    Simulate the old broken behavior to prove it was buggy.
    """

    def test_old_behavior_would_miss_target_restrictions(self):
        """
        SIMULATION: This test recreates the OLD BROKEN BEHAVIOR and proves it was buggy.

        OLD BUG: When the validation service called provider.list_models(), it got only
        aliases back, not targets. This meant target-based restrictions weren't validated.
        """
        # Create a mock provider that simulates the OLD BROKEN BEHAVIOR
        old_broken_provider = MagicMock()
        old_broken_provider.SUPPORTED_MODELS = {
            "mini": "o4-mini",  # alias -> target
            "o3mini": "o3-mini",  # alias -> target
            "o4-mini": {"context_window": 200000},
            "o3-mini": {"context_window": 200000},
        }

        # OLD BROKEN: list_models only returned aliases, missing targets
        old_broken_provider.list_models.return_value = ["mini", "o3mini"]

        # OLD BROKEN: There was no list_all_known_models method!
        # We simulate this by making it behave like the old list_models
        old_broken_provider.list_all_known_models.return_value = ["mini", "o3mini"]  # MISSING TARGETS!

        # Now test what happens when an admin tries to restrict by target model
        service = ModelRestrictionService()
        service.restrictions = {ProviderType.OPENAI: {"o4-mini"}}  # Restrict to target model

        with patch("utils.model_restrictions.logger") as mock_logger:
            provider_instances = {ProviderType.OPENAI: old_broken_provider}
            service.validate_against_known_models(provider_instances)

            # OLD BROKEN BEHAVIOR: Would warn about o4-mini being "not recognized"
            # because it wasn't in the list_all_known_models response
            target_warnings = [
                call
                for call in mock_logger.warning.call_args_list
                if "o4-mini" in str(call) and "not a recognized" in str(call)
            ]

            # This proves the old behavior was broken - it would generate false warnings
            assert len(target_warnings) > 0, "OLD BROKEN BEHAVIOR: Would incorrectly warn about valid target models"

            # Verify the warning message shows the broken list
            warning_text = str(target_warnings[0])
            assert "mini" in warning_text  # Alias was included
            assert "o3mini" in warning_text  # Alias was included
            # But targets were missing from the known models list in old behavior

    def test_new_behavior_fixes_the_problem(self):
        """
        Compare old vs new behavior to show our fix works.
        """
        # Create mock provider with NEW FIXED BEHAVIOR
        new_fixed_provider = MagicMock()
        new_fixed_provider.SUPPORTED_MODELS = {
            "mini": "o4-mini",
            "o3mini": "o3-mini",
            "o4-mini": {"context_window": 200000},
            "o3-mini": {"context_window": 200000},
        }

        # NEW FIXED: list_all_known_models includes BOTH aliases AND targets
        new_fixed_provider.list_all_known_models.return_value = [
            "mini",
            "o3mini",  # aliases
            "o4-mini",
            "o3-mini",  # targets - THESE WERE MISSING IN OLD CODE!
        ]

        # Same restriction scenario
        service = ModelRestrictionService()
        service.restrictions = {ProviderType.OPENAI: {"o4-mini"}}  # Restrict to target model

        with patch("utils.model_restrictions.logger") as mock_logger:
            provider_instances = {ProviderType.OPENAI: new_fixed_provider}
            service.validate_against_known_models(provider_instances)

            # NEW FIXED BEHAVIOR: No warnings about o4-mini being unrecognized
            target_warnings = [
                call
                for call in mock_logger.warning.call_args_list
                if "o4-mini" in str(call) and "not a recognized" in str(call)
            ]

            # Our fix prevents false warnings
            assert len(target_warnings) == 0, "NEW FIXED BEHAVIOR: Should not warn about valid target models"

    def test_policy_bypass_prevention_old_vs_new(self):
        """
        Show how the old behavior could have led to policy bypass scenarios.
        """
        # OLD BROKEN: Admin thinks they've restricted access to o4-mini,
        # but validation doesn't recognize it as a valid restriction target
        old_broken_provider = MagicMock()
        old_broken_provider.list_all_known_models.return_value = ["mini", "o3mini"]  # Missing targets

        # NEW FIXED: Same provider with our fix
        new_fixed_provider = MagicMock()
        new_fixed_provider.list_all_known_models.return_value = ["mini", "o3mini", "o4-mini", "o3-mini"]

        # Test restriction on target model - use completely separate service instances
        old_service = ModelRestrictionService()
        old_service.restrictions = {ProviderType.OPENAI: {"o4-mini", "completely-invalid-model"}}

        new_service = ModelRestrictionService()
        new_service.restrictions = {ProviderType.OPENAI: {"o4-mini", "completely-invalid-model"}}

        # OLD BEHAVIOR: Would warn about BOTH models being unrecognized
        with patch("utils.model_restrictions.logger") as mock_logger_old:
            provider_instances = {ProviderType.OPENAI: old_broken_provider}
            old_service.validate_against_known_models(provider_instances)

            old_warnings = [str(call) for call in mock_logger_old.warning.call_args_list]
            print(f"OLD warnings: {old_warnings}")  # Debug output

        # NEW BEHAVIOR: Only warns about the truly invalid model
        with patch("utils.model_restrictions.logger") as mock_logger_new:
            provider_instances = {ProviderType.OPENAI: new_fixed_provider}
            new_service.validate_against_known_models(provider_instances)

            new_warnings = [str(call) for call in mock_logger_new.warning.call_args_list]
            print(f"NEW warnings: {new_warnings}")  # Debug output

        # Both runs should produce warnings: the old code because it fails to
        # recognize the valid target model, the new code because of the truly
        # invalid model. The key difference is the "Known models" list.
        assert len(old_warnings) > 0, "OLD: Should have warnings"
        assert len(new_warnings) > 0, "NEW: Should have warnings for invalid model"

        # Verify the known models list differs between old and new
        old_warning_text = str(old_warnings[0]) if old_warnings else ""
        new_warning_text = str(new_warnings[0]) if new_warnings else ""

        if "Known models:" in new_warning_text:
            # NEW behavior should include o4-mini in known models list
            assert "o4-mini" in new_warning_text, "NEW: Should include o4-mini in known models"

        if "Known models:" in old_warning_text and "Known models:" in new_warning_text:
            # OLD behavior was missing the target models, so the lists must differ
            assert old_warning_text != new_warning_text, "Old and new 'Known models' lists should differ"

        print("This test demonstrates that our fix improves the 'Known models' list shown to users.")

    def test_demonstrate_target_coverage_improvement(self):
        """
        Show the exact improvement in target model coverage.
        """
        # Simulate different provider implementations
        providers_old_vs_new = [
            # (old_broken_list, new_fixed_list, provider_name)
            (["mini", "o3mini"], ["mini", "o3mini", "o4-mini", "o3-mini"], "OpenAI"),
            (
                ["flash", "pro"],
                ["flash", "pro", "gemini-2.5-flash", "gemini-2.5-pro"],
                "Gemini",
            ),
        ]

        for old_list, new_list, provider_name in providers_old_vs_new:
            # Count how many additional models are now covered
            old_coverage = set(old_list)
            new_coverage = set(new_list)

            additional_coverage = new_coverage - old_coverage

            # There should be additional target models covered
            assert len(additional_coverage) > 0, f"{provider_name}: Should have additional target coverage"

            # All old models should still be covered
            assert old_coverage.issubset(new_coverage), f"{provider_name}: Should maintain backward compatibility"

            print(f"{provider_name} provider:")
            print(f" Old coverage: {sorted(old_coverage)}")
            print(f" New coverage: {sorted(new_coverage)}")
            print(f" Additional models: {sorted(additional_coverage)}")

    def test_comprehensive_alias_target_mapping_verification(self):
        """
        Verify that our fix provides comprehensive alias->target coverage.
        """
        from providers.gemini import GeminiModelProvider
        from providers.openai_provider import OpenAIModelProvider

        # Test real providers to ensure they implement our fix correctly
        providers = [OpenAIModelProvider(api_key="test-key"), GeminiModelProvider(api_key="test-key")]

        for provider in providers:
            all_known = provider.list_all_known_models()

            # Check that for every alias in SUPPORTED_MODELS, its target is also included
            for model_name, config in provider.SUPPORTED_MODELS.items():
                # Model name itself should be in the list
                assert model_name.lower() in all_known, f"{provider.__class__.__name__}: Missing model {model_name}"

                # If it's an alias (config is a string), target should also be in list
                if isinstance(config, str):
                    target_model = config
                    assert (
                        target_model.lower() in all_known
                    ), f"{provider.__class__.__name__}: Missing target {target_model} for alias {model_name}"