Merge branch 'main' into refactor-image-validation

Author: Beehive Innovations
Date: 2025-08-07 23:12:00 -07:00
Committed by: GitHub
55 changed files with 2491 additions and 623 deletions


@@ -15,13 +15,6 @@ parent_dir = Path(__file__).resolve().parent.parent
if str(parent_dir) not in sys.path:
sys.path.insert(0, str(parent_dir))
-# Set dummy API keys for tests if not already set or if empty
-if not os.environ.get("GEMINI_API_KEY"):
-os.environ["GEMINI_API_KEY"] = "dummy-key-for-tests"
-if not os.environ.get("OPENAI_API_KEY"):
-os.environ["OPENAI_API_KEY"] = "dummy-key-for-tests"
-if not os.environ.get("XAI_API_KEY"):
-os.environ["XAI_API_KEY"] = "dummy-key-for-tests"
# Set default model to a specific value for tests to avoid auto mode
# This prevents all tests from failing due to missing model parameter
@@ -77,11 +70,27 @@ def project_path(tmp_path):
return test_dir
+def _set_dummy_keys_if_missing():
+"""Set dummy API keys only when they are completely absent."""
+for var in ("GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY"):
+if not os.environ.get(var):
+os.environ[var] = "dummy-key-for-tests"
# Pytest configuration
def pytest_configure(config):
"""Configure pytest with custom markers"""
config.addinivalue_line("markers", "asyncio: mark test as async")
+config.addinivalue_line("markers", "no_mock_provider: disable automatic provider mocking")
+# Assume we need dummy keys until we learn otherwise
+config._needs_dummy_keys = True
+def pytest_collection_modifyitems(session, config, items):
+"""Hook that runs after test collection to check for no_mock_provider markers."""
+# Always set dummy keys if real keys are missing
+# This ensures tests work in CI even with the no_mock_provider marker
+_set_dummy_keys_if_missing()
@pytest.fixture(autouse=True)


@@ -0,0 +1,376 @@
#!/usr/bin/env python3
"""
HTTP Transport Recorder for O3-Pro Testing
Custom httpx transport solution that replaces respx for recording/replaying
HTTP interactions. Provides full control over the recording process without
respx limitations.
Key Features:
- RecordingTransport: Wraps default transport, captures real HTTP calls
- ReplayTransport: Serves saved responses from cassettes
- TransportFactory: Auto-selects record vs replay mode
- JSON cassette format with data sanitization
"""
import base64
import hashlib
import json
import logging
from pathlib import Path
from typing import Any, Optional
import httpx
from .pii_sanitizer import PIISanitizer
logger = logging.getLogger(__name__)
class RecordingTransport(httpx.HTTPTransport):
"""Transport that wraps default httpx transport and records all interactions."""
def __init__(self, cassette_path: str, capture_content: bool = True, sanitize: bool = True):
super().__init__()
self.cassette_path = Path(cassette_path)
self.recorded_interactions = []
self.capture_content = capture_content
self.sanitizer = PIISanitizer() if sanitize else None
def handle_request(self, request: httpx.Request) -> httpx.Response:
"""Handle request by recording interaction and delegating to real transport."""
logger.debug(f"RecordingTransport: Making request to {request.method} {request.url}")
# Record request BEFORE making the call
request_data = self._serialize_request(request)
# Make real HTTP call using parent transport
response = super().handle_request(request)
logger.debug(f"RecordingTransport: Got response {response.status_code}")
# Post-response content capture (proper approach)
if self.capture_content:
try:
# Consume the response stream to capture content
# Note: httpx automatically handles gzip decompression
content_bytes = response.read()
response.close() # Close the original stream
logger.debug(f"RecordingTransport: Captured {len(content_bytes)} bytes")
# Serialize response with captured content
response_data = self._serialize_response_with_content(response, content_bytes)
# Create a new response with the same metadata but buffered content
# If the original response was gzipped, we need to re-compress
response_content = content_bytes
if response.headers.get("content-encoding") == "gzip":
import gzip
response_content = gzip.compress(content_bytes)
logger.debug(f"Re-compressed content: {len(content_bytes)}{len(response_content)} bytes")
new_response = httpx.Response(
status_code=response.status_code,
headers=response.headers, # Keep original headers intact
content=response_content,
request=request,
extensions=response.extensions,
history=response.history,
)
# Record the interaction
self._record_interaction(request_data, response_data)
return new_response
except Exception:
logger.warning("Content capture failed, falling back to stub", exc_info=True)
response_data = self._serialize_response(response)
self._record_interaction(request_data, response_data)
return response
else:
# Legacy mode: record with stub content
response_data = self._serialize_response(response)
self._record_interaction(request_data, response_data)
return response
def _record_interaction(self, request_data: dict[str, Any], response_data: dict[str, Any]):
"""Helper method to record interaction and save cassette."""
interaction = {"request": request_data, "response": response_data}
self.recorded_interactions.append(interaction)
self._save_cassette()
logger.debug(f"Saved cassette to {self.cassette_path}")
def _serialize_request(self, request: httpx.Request) -> dict[str, Any]:
"""Serialize httpx.Request to JSON-compatible format."""
# For requests, we can safely read the content since it's already been prepared
# httpx.Request.content is safe to access multiple times
content = request.content
# Convert bytes to string for JSON serialization
if isinstance(content, bytes):
try:
content_str = content.decode("utf-8")
except UnicodeDecodeError:
# Handle binary content (shouldn't happen for o3-pro API)
content_str = content.hex()
else:
content_str = str(content) if content else ""
request_data = {
"method": request.method,
"url": str(request.url),
"path": request.url.path,
"headers": dict(request.headers),
"content": self._sanitize_request_content(content_str),
}
# Apply PII sanitization if enabled
if self.sanitizer:
request_data = self.sanitizer.sanitize_request(request_data)
return request_data
def _serialize_response(self, response: httpx.Response) -> dict[str, Any]:
"""Serialize httpx.Response to JSON-compatible format (legacy method without content)."""
# Legacy method for backward compatibility when content capture is disabled
return {
"status_code": response.status_code,
"headers": dict(response.headers),
"content": {"note": "Response content not recorded to avoid httpx.ResponseNotRead exception"},
"reason_phrase": response.reason_phrase,
}
def _serialize_response_with_content(self, response: httpx.Response, content_bytes: bytes) -> dict[str, Any]:
"""Serialize httpx.Response with captured content."""
try:
# Debug: check what we got
# Ensure we have bytes for base64 encoding
if not isinstance(content_bytes, bytes):
logger.warning(f"Content is not bytes, converting from {type(content_bytes)}")
if isinstance(content_bytes, str):
content_bytes = content_bytes.encode("utf-8")
else:
content_bytes = str(content_bytes).encode("utf-8")
# Encode content as base64 for JSON storage
content_b64 = base64.b64encode(content_bytes).decode("utf-8")
logger.debug(f"Base64 encoded {len(content_bytes)} bytes → {len(content_b64)} chars")
response_data = {
"status_code": response.status_code,
"headers": dict(response.headers),
"content": {"data": content_b64, "encoding": "base64", "size": len(content_bytes)},
"reason_phrase": response.reason_phrase,
}
# Apply PII sanitization if enabled
if self.sanitizer:
response_data = self.sanitizer.sanitize_response(response_data)
return response_data
except Exception as e:
logger.exception("Error in _serialize_response_with_content")
# Fall back to minimal info
return {
"status_code": response.status_code,
"headers": dict(response.headers),
"content": {"error": f"Failed to serialize content: {e}"},
"reason_phrase": response.reason_phrase,
}
def _sanitize_request_content(self, content: str) -> Any:
"""Sanitize request content to remove sensitive data."""
try:
if content.strip():
data = json.loads(content)
# Don't sanitize request content for now - it's user input
return data
except json.JSONDecodeError:
pass
return content
def _save_cassette(self):
"""Save recorded interactions to cassette file."""
# Ensure directory exists
self.cassette_path.parent.mkdir(parents=True, exist_ok=True)
# Save cassette
cassette_data = {"interactions": self.recorded_interactions}
self.cassette_path.write_text(json.dumps(cassette_data, indent=2, sort_keys=True))
class ReplayTransport(httpx.MockTransport):
"""Transport that replays saved HTTP interactions from cassettes."""
def __init__(self, cassette_path: str):
self.cassette_path = Path(cassette_path)
self.interactions = self._load_cassette()
super().__init__(self._handle_request)
def _load_cassette(self) -> list:
"""Load interactions from cassette file."""
if not self.cassette_path.exists():
raise FileNotFoundError(f"Cassette file not found: {self.cassette_path}")
try:
cassette_data = json.loads(self.cassette_path.read_text())
return cassette_data.get("interactions", [])
except json.JSONDecodeError as e:
raise ValueError(f"Invalid cassette file format: {e}")
def _handle_request(self, request: httpx.Request) -> httpx.Response:
"""Handle request by finding matching interaction and returning saved response."""
logger.debug(f"ReplayTransport: Looking for {request.method} {request.url}")
# Debug: show what we're trying to match
request_signature = self._get_request_signature(request)
logger.debug(f"Request signature: {request_signature}")
# Find matching interaction
interaction = self._find_matching_interaction(request)
if not interaction:
logger.warning("No matching interaction found in cassette")
raise ValueError(f"No matching interaction found for {request.method} {request.url}")
logger.debug("Found matching interaction in cassette")
# Build response from saved data
response_data = interaction["response"]
# Convert content back to appropriate format
content = response_data.get("content", {})
if isinstance(content, dict):
# Check if this is base64-encoded content
if content.get("encoding") == "base64" and "data" in content:
# Decode base64 content
try:
content_bytes = base64.b64decode(content["data"])
logger.debug(f"Decoded {len(content_bytes)} bytes from base64")
except Exception as e:
logger.warning(f"Failed to decode base64 content: {e}")
content_bytes = json.dumps(content).encode("utf-8")
else:
# Legacy format or stub content
content_bytes = json.dumps(content).encode("utf-8")
else:
content_bytes = str(content).encode("utf-8")
# Check if response expects gzipped content
headers = response_data.get("headers", {})
if headers.get("content-encoding") == "gzip":
# Re-compress the content for httpx
import gzip
content_bytes = gzip.compress(content_bytes)
logger.debug(f"Re-compressed for replay: {len(content_bytes)} bytes")
logger.debug(f"Returning cassette response ({len(content_bytes)} bytes)")
# Create httpx.Response
return httpx.Response(
status_code=response_data["status_code"],
headers=response_data.get("headers", {}),
content=content_bytes,
request=request,
)
def _find_matching_interaction(self, request: httpx.Request) -> Optional[dict[str, Any]]:
"""Find interaction that matches the request."""
request_signature = self._get_request_signature(request)
for interaction in self.interactions:
saved_signature = self._get_saved_request_signature(interaction["request"])
if request_signature == saved_signature:
return interaction
return None
def _get_request_signature(self, request: httpx.Request) -> str:
"""Generate signature for request matching."""
# Use method, path, and content hash for matching
content = request.content
if hasattr(content, "read"):
content = content.read()
if isinstance(content, bytes):
content_str = content.decode("utf-8", errors="ignore")
else:
content_str = str(content) if content else ""
# Parse JSON and re-serialize with sorted keys for consistent hashing
try:
if content_str.strip():
content_dict = json.loads(content_str)
content_str = json.dumps(content_dict, sort_keys=True)
except json.JSONDecodeError:
# Not JSON, use as-is
pass
# Create hash of content for stable matching
content_hash = hashlib.md5(content_str.encode()).hexdigest()
return f"{request.method}:{request.url.path}:{content_hash}"
def _get_saved_request_signature(self, saved_request: dict[str, Any]) -> str:
"""Generate signature for saved request."""
method = saved_request["method"]
path = saved_request["path"]
# Hash the saved content
content = saved_request.get("content", "")
if isinstance(content, dict):
content_str = json.dumps(content, sort_keys=True)
else:
content_str = str(content)
content_hash = hashlib.md5(content_str.encode()).hexdigest()
return f"{method}:{path}:{content_hash}"
class TransportFactory:
"""Factory for creating appropriate transport based on cassette availability."""
@staticmethod
def create_transport(cassette_path: str) -> httpx.HTTPTransport:
"""Create transport based on cassette existence and API key availability."""
cassette_file = Path(cassette_path)
# Check if we should record or replay
if cassette_file.exists():
# Cassette exists - use replay mode
return ReplayTransport(cassette_path)
else:
# No cassette - use recording mode
# Note: We'll check for API key in the test itself
return RecordingTransport(cassette_path)
@staticmethod
def should_record(cassette_path: str, api_key: Optional[str] = None) -> bool:
"""Determine if we should record based on cassette and API key availability."""
cassette_file = Path(cassette_path)
# Record if cassette doesn't exist AND we have API key
return not cassette_file.exists() and bool(api_key)
@staticmethod
def should_replay(cassette_path: str) -> bool:
"""Determine if we should replay based on cassette availability."""
cassette_file = Path(cassette_path)
return cassette_file.exists()
# Example usage:
#
# # In test setup:
# cassette_path = "tests/cassettes/o3_pro_basic_math.json"
# transport = TransportFactory.create_transport(cassette_path)
#
# # Inject into OpenAI client:
# provider._test_transport = transport
#
# # The provider's client property will detect _test_transport and use it
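For reference, a cassette written by _save_cassette has the shape below. This is a sketch inferred from the serializers above; the endpoint URL and all field values are illustrative placeholders, not a real recording.

# Cassette layout, shown as a Python literal for readability; values are illustrative.
cassette = {
    "interactions": [
        {
            "request": {
                "method": "POST",
                "url": "https://api.openai.com/v1/responses",  # illustrative URL
                "path": "/v1/responses",
                "headers": {"authorization": "Bearer SANITIZED"},
                "content": {"model": "o3-pro", "input": "..."},  # parsed JSON body
            },
            "response": {
                "status_code": 200,
                "headers": {"content-type": "application/json"},
                # Body bytes are stored base64-encoded by _serialize_response_with_content
                "content": {"data": "eyJpZCI6ICJ0ZXN0In0=", "encoding": "base64", "size": 14},
                "reason_phrase": "OK",
            },
        }
    ]
}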

File diff suppressed because one or more lines are too long

tests/pii_sanitizer.py (new file, 290 lines)

@@ -0,0 +1,290 @@
#!/usr/bin/env python3
"""
PII (Personally Identifiable Information) Sanitizer for HTTP recordings.
This module provides comprehensive sanitization of sensitive data in HTTP
request/response recordings to prevent accidental exposure of API keys,
tokens, personal information, and other sensitive data.
"""
import logging
import re
from copy import deepcopy
from dataclasses import dataclass
from re import Pattern
from typing import Any, Optional
logger = logging.getLogger(__name__)
@dataclass
class PIIPattern:
"""Defines a pattern for detecting and sanitizing PII."""
name: str
pattern: Pattern[str]
replacement: str
description: str
@classmethod
def create(cls, name: str, pattern: str, replacement: str, description: str) -> "PIIPattern":
"""Create a PIIPattern with compiled regex."""
return cls(name=name, pattern=re.compile(pattern), replacement=replacement, description=description)
class PIISanitizer:
"""Sanitizes PII from various data structures while preserving format."""
def __init__(self, patterns: Optional[list[PIIPattern]] = None):
"""Initialize with optional custom patterns."""
self.patterns: list[PIIPattern] = patterns or []
self.sanitize_enabled = True
# Add default patterns if none provided
if not patterns:
self._add_default_patterns()
def _add_default_patterns(self):
"""Add comprehensive default PII patterns."""
default_patterns = [
# API Keys - Core patterns (Bearer tokens handled in sanitize_headers)
PIIPattern.create(
name="openai_api_key_proj",
pattern=r"sk-proj-[A-Za-z0-9\-_]{48,}",
replacement="sk-proj-SANITIZED",
description="OpenAI project API keys",
),
PIIPattern.create(
name="openai_api_key",
pattern=r"sk-[A-Za-z0-9]{48,}",
replacement="sk-SANITIZED",
description="OpenAI API keys",
),
PIIPattern.create(
name="anthropic_api_key",
pattern=r"sk-ant-[A-Za-z0-9\-_]{48,}",
replacement="sk-ant-SANITIZED",
description="Anthropic API keys",
),
PIIPattern.create(
name="google_api_key",
pattern=r"AIza[A-Za-z0-9\-_]{35,}",
replacement="AIza-SANITIZED",
description="Google API keys",
),
PIIPattern.create(
name="github_tokens",
pattern=r"gh[psr]_[A-Za-z0-9]{36}",
replacement="gh_SANITIZED",
description="GitHub tokens (all types)",
),
# JWT tokens
PIIPattern.create(
name="jwt_token",
pattern=r"eyJ[A-Za-z0-9\-_]+\.eyJ[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+",
replacement="eyJ-SANITIZED",
description="JSON Web Tokens",
),
# Personal Information
PIIPattern.create(
name="email_address",
pattern=r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}",
replacement="user@example.com",
description="Email addresses",
),
PIIPattern.create(
name="ipv4_address",
pattern=r"\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b",
replacement="0.0.0.0",
description="IPv4 addresses",
),
PIIPattern.create(
name="ssn",
pattern=r"\b\d{3}-\d{2}-\d{4}\b",
replacement="XXX-XX-XXXX",
description="Social Security Numbers",
),
PIIPattern.create(
name="credit_card",
pattern=r"\b\d{4}[\s\-]?\d{4}[\s\-]?\d{4}[\s\-]?\d{4}\b",
replacement="XXXX-XXXX-XXXX-XXXX",
description="Credit card numbers",
),
PIIPattern.create(
name="phone_number",
pattern=r"(?:\+\d{1,3}[\s\-]?)?\(?\d{3}\)?[\s\-]?\d{3}[\s\-]?\d{4}\b(?![\d\.\,\]\}])",
replacement="(XXX) XXX-XXXX",
description="Phone numbers (all formats)",
),
# AWS
PIIPattern.create(
name="aws_access_key",
pattern=r"AKIA[0-9A-Z]{16}",
replacement="AKIA-SANITIZED",
description="AWS access keys",
),
# Other common patterns
PIIPattern.create(
name="slack_token",
pattern=r"xox[baprs]-[0-9]{10,13}-[0-9]{10,13}-[a-zA-Z0-9]{24,34}",
replacement="xox-SANITIZED",
description="Slack tokens",
),
PIIPattern.create(
name="stripe_key",
pattern=r"(?:sk|pk)_(?:test|live)_[0-9a-zA-Z]{24,99}",
replacement="sk_SANITIZED",
description="Stripe API keys",
),
]
self.patterns.extend(default_patterns)
def add_pattern(self, pattern: PIIPattern):
"""Add a custom PII pattern."""
self.patterns.append(pattern)
logger.info(f"Added PII pattern: {pattern.name}")
def sanitize_string(self, text: str) -> str:
"""Apply all patterns to sanitize a string."""
if not self.sanitize_enabled or not isinstance(text, str):
return text
sanitized = text
for pattern in self.patterns:
if pattern.pattern.search(sanitized):
sanitized = pattern.pattern.sub(pattern.replacement, sanitized)
logger.debug(f"Applied {pattern.name} sanitization")
return sanitized
def sanitize_headers(self, headers: dict[str, str]) -> dict[str, str]:
"""Special handling for HTTP headers."""
if not self.sanitize_enabled:
return headers
sanitized_headers = {}
for key, value in headers.items():
# Special case for Authorization headers to preserve auth type
if key.lower() == "authorization" and " " in value:
auth_type = value.split(" ", 1)[0]
if auth_type in ("Bearer", "Basic"):
sanitized_headers[key] = f"{auth_type} SANITIZED"
else:
sanitized_headers[key] = self.sanitize_string(value)
else:
# Apply standard sanitization to all other headers
sanitized_headers[key] = self.sanitize_string(value)
return sanitized_headers
def sanitize_value(self, value: Any) -> Any:
"""Recursively sanitize any value (string, dict, list, etc)."""
if not self.sanitize_enabled:
return value
if isinstance(value, str):
return self.sanitize_string(value)
elif isinstance(value, dict):
return {k: self.sanitize_value(v) for k, v in value.items()}
elif isinstance(value, list):
return [self.sanitize_value(item) for item in value]
elif isinstance(value, tuple):
return tuple(self.sanitize_value(item) for item in value)
else:
# For other types (int, float, bool, None), return as-is
return value
def sanitize_url(self, url: str) -> str:
"""Sanitize sensitive data from URLs (query params, etc)."""
if not self.sanitize_enabled:
return url
# First apply general string sanitization
url = self.sanitize_string(url)
# Parse and sanitize query parameters
if "?" in url:
base, query = url.split("?", 1)
params = []
for param in query.split("&"):
if "=" in param:
key, value = param.split("=", 1)
# Sanitize common sensitive parameter names
sensitive_params = {"key", "token", "api_key", "secret", "password"}
if key.lower() in sensitive_params:
params.append(f"{key}=SANITIZED")
else:
# Still sanitize the value for PII
params.append(f"{key}={self.sanitize_string(value)}")
else:
params.append(param)
return f"{base}?{'&'.join(params)}"
return url
def sanitize_request(self, request_data: dict[str, Any]) -> dict[str, Any]:
"""Sanitize a complete request dictionary."""
sanitized = deepcopy(request_data)
# Sanitize headers
if "headers" in sanitized:
sanitized["headers"] = self.sanitize_headers(sanitized["headers"])
# Sanitize URL
if "url" in sanitized:
sanitized["url"] = self.sanitize_url(sanitized["url"])
# Sanitize content
if "content" in sanitized:
sanitized["content"] = self.sanitize_value(sanitized["content"])
return sanitized
def sanitize_response(self, response_data: dict[str, Any]) -> dict[str, Any]:
"""Sanitize a complete response dictionary."""
sanitized = deepcopy(response_data)
# Sanitize headers
if "headers" in sanitized:
sanitized["headers"] = self.sanitize_headers(sanitized["headers"])
# Sanitize content
if "content" in sanitized:
# Handle base64 encoded content specially
if isinstance(sanitized["content"], dict) and sanitized["content"].get("encoding") == "base64":
if "data" in sanitized["content"]:
import base64
try:
# Decode, sanitize, and re-encode the actual response body
decoded_bytes = base64.b64decode(sanitized["content"]["data"])
# Attempt to decode as UTF-8 for sanitization. If it fails, it's likely binary.
try:
decoded_str = decoded_bytes.decode("utf-8")
sanitized_str = self.sanitize_string(decoded_str)
sanitized["content"]["data"] = base64.b64encode(sanitized_str.encode("utf-8")).decode(
"utf-8"
)
except UnicodeDecodeError:
# Content is not text, leave as is.
pass
except (base64.binascii.Error, TypeError):
# Handle cases where data is not valid base64
pass
# Sanitize other metadata fields
for key, value in sanitized["content"].items():
if key != "data":
sanitized["content"][key] = self.sanitize_value(value)
else:
sanitized["content"] = self.sanitize_value(sanitized["content"])
return sanitized
# Global instance for convenience
default_sanitizer = PIISanitizer()
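A minimal usage sketch for the sanitizer above. The import path and the key/token values are illustrative; the fake key is just a string shaped to match the sk- pattern.

from pii_sanitizer import PIIPattern, PIISanitizer, default_sanitizer

# Strings: the default pattern set replaces anything that looks like a secret
# (the key below is a fabricated placeholder, not a real credential)
text = "request used key sk-" + "a" * 48
assert default_sanitizer.sanitize_string(text) == "request used key sk-SANITIZED"

# Headers: Authorization keeps its auth type but drops the credential
headers = {"Authorization": "Bearer secret-token", "User-Agent": "httpx"}
assert default_sanitizer.sanitize_headers(headers)["Authorization"] == "Bearer SANITIZED"

# Custom patterns can be registered per instance
sanitizer = PIISanitizer()
sanitizer.add_pattern(
    PIIPattern.create(
        name="internal_ticket",  # hypothetical pattern, for illustration only
        pattern=r"TICKET-\d{6}",
        replacement="TICKET-SANITIZED",
        description="Internal ticket IDs",
    )
)
assert sanitizer.sanitize_string("see TICKET-123456") == "see TICKET-SANITIZED"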

tests/sanitize_cassettes.py (new executable file, 110 lines)

@@ -0,0 +1,110 @@
#!/usr/bin/env python3
"""
Script to sanitize existing cassettes by applying PII sanitization.
This script will:
1. Load existing cassettes
2. Apply PII sanitization to all interactions
3. Create backups of originals
4. Save sanitized versions
"""
import json
import shutil
import sys
from datetime import datetime
from pathlib import Path
# Add tests directory to path to import our modules
sys.path.insert(0, str(Path(__file__).parent))
from pii_sanitizer import PIISanitizer
def sanitize_cassette(cassette_path: Path, backup: bool = True) -> bool:
"""Sanitize a single cassette file."""
print(f"\n🔍 Processing: {cassette_path}")
if not cassette_path.exists():
print(f"❌ File not found: {cassette_path}")
return False
try:
# Load cassette
with open(cassette_path) as f:
cassette_data = json.load(f)
# Create backup if requested
if backup:
backup_path = cassette_path.with_suffix(f'.backup-{datetime.now().strftime("%Y%m%d-%H%M%S")}.json')
shutil.copy2(cassette_path, backup_path)
print(f"📦 Backup created: {backup_path}")
# Initialize sanitizer
sanitizer = PIISanitizer()
# Sanitize interactions
if "interactions" in cassette_data:
sanitized_interactions = []
for interaction in cassette_data["interactions"]:
sanitized_interaction = {}
# Sanitize request
if "request" in interaction:
sanitized_interaction["request"] = sanitizer.sanitize_request(interaction["request"])
# Sanitize response
if "response" in interaction:
sanitized_interaction["response"] = sanitizer.sanitize_response(interaction["response"])
sanitized_interactions.append(sanitized_interaction)
cassette_data["interactions"] = sanitized_interactions
# Save sanitized cassette
with open(cassette_path, "w") as f:
json.dump(cassette_data, f, indent=2, sort_keys=True)
print(f"✅ Sanitized: {cassette_path}")
return True
except Exception as e:
print(f"❌ Error processing {cassette_path}: {e}")
import traceback
traceback.print_exc()
return False
def main():
"""Sanitize all cassettes in the openai_cassettes directory."""
cassettes_dir = Path(__file__).parent / "openai_cassettes"
if not cassettes_dir.exists():
print(f"❌ Directory not found: {cassettes_dir}")
sys.exit(1)
# Find all JSON cassettes
cassette_files = list(cassettes_dir.glob("*.json"))
if not cassette_files:
print(f"❌ No cassette files found in {cassettes_dir}")
sys.exit(1)
print(f"🎬 Found {len(cassette_files)} cassette(s) to sanitize")
# Process each cassette
success_count = 0
for cassette_path in cassette_files:
if sanitize_cassette(cassette_path):
success_count += 1
print(f"\n✨ Sanitization complete: {success_count}/{len(cassette_files)} cassettes processed successfully")
if success_count < len(cassette_files):
sys.exit(1)
if __name__ == "__main__":
main()
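The script can also be driven programmatically from other test tooling. A sketch, assuming it is run with the tests directory on sys.path:

from pathlib import Path

from sanitize_cassettes import sanitize_cassette

# Sanitize a single cassette in place, skipping the timestamped backup
sanitize_cassette(Path("tests/openai_cassettes/o3_pro_basic_math.json"), backup=False)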


@@ -48,7 +48,8 @@ class TestAliasTargetRestrictions:
"""Test that restriction policy allows alias when target model is allowed.
This is the correct user-friendly behavior - if you allow 'o4-mini',
-you should be able to use its alias 'mini' as well.
+you should be able to use its aliases 'o4mini' and 'o4-mini'.
+Note: 'mini' is now an alias for 'gpt-5-mini', not 'o4-mini'.
"""
# Clear cached restriction service
import utils.model_restrictions
@@ -57,15 +58,16 @@ class TestAliasTargetRestrictions:
provider = OpenAIModelProvider(api_key="test-key")
-# Both target and alias should be allowed
+# Both target and its actual aliases should be allowed
assert provider.validate_model_name("o4-mini")
-assert provider.validate_model_name("mini")
+assert provider.validate_model_name("o4mini")
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini"}) # Allow alias only
def test_restriction_policy_allows_only_alias_when_alias_specified(self):
"""Test that restriction policy allows only the alias when just alias is specified.
-If you restrict to 'mini', only the alias should work, not the direct target.
+If you restrict to 'mini' (which is an alias for gpt-5-mini),
+only the alias should work, not other models.
This is the correct restrictive behavior.
"""
# Clear cached restriction service
@@ -77,7 +79,9 @@ class TestAliasTargetRestrictions:
# Only the alias should be allowed
assert provider.validate_model_name("mini")
-# Direct target should NOT be allowed
+# Direct target for this alias should NOT be allowed (mini -> gpt-5-mini)
+assert not provider.validate_model_name("gpt-5-mini")
+# Other models should NOT be allowed
assert not provider.validate_model_name("o4-mini")
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"}) # Allow target
@@ -127,12 +131,15 @@ class TestAliasTargetRestrictions:
# The warning should include both aliases and targets in known models
warning_message = str(warning_calls[0])
assert "mini" in warning_message # alias should be in known models
assert "o4-mini" in warning_message # target should be in known models
assert "o4mini" in warning_message or "o4-mini" in warning_message # aliases should be in known models
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini,o4-mini"}) # Allow both alias and target
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini,gpt-5-mini,o4-mini,o4mini"}) # Allow different models
def test_both_alias_and_target_allowed_when_both_specified(self):
"""Test that both alias and target work when both are explicitly allowed."""
"""Test that both alias and target work when both are explicitly allowed.
mini -> gpt-5-mini
o4mini -> o4-mini
"""
# Clear cached restriction service
import utils.model_restrictions
@@ -140,9 +147,11 @@ class TestAliasTargetRestrictions:
provider = OpenAIModelProvider(api_key="test-key")
-# Both should be allowed
-assert provider.validate_model_name("mini")
-assert provider.validate_model_name("o4-mini")
+# All should be allowed since we explicitly allowed them
+assert provider.validate_model_name("mini") # alias for gpt-5-mini
+assert provider.validate_model_name("gpt-5-mini") # target
+assert provider.validate_model_name("o4-mini") # target
+assert provider.validate_model_name("o4mini") # alias for o4-mini
def test_alias_target_policy_regression_prevention(self):
"""Regression test to ensure aliases and targets are both validated properly.


@@ -95,8 +95,8 @@ class TestAutoModeComprehensive:
},
{
"EXTENDED_REASONING": "o3", # O3 for deep reasoning
"FAST_RESPONSE": "o4-mini", # O4-mini for speed
"BALANCED": "o4-mini", # O4-mini as balanced
"FAST_RESPONSE": "gpt-5", # Prefer gpt-5 for speed
"BALANCED": "gpt-5", # Prefer gpt-5 for balanced
},
),
# Only X.AI API available
@@ -108,12 +108,12 @@ class TestAutoModeComprehensive:
"OPENROUTER_API_KEY": None,
},
{
"EXTENDED_REASONING": "grok-3", # GROK-3 for reasoning
"EXTENDED_REASONING": "grok-4", # GROK-4 for reasoning (now preferred)
"FAST_RESPONSE": "grok-3-fast", # GROK-3-fast for speed
"BALANCED": "grok-3", # GROK-3 as balanced
"BALANCED": "grok-4", # GROK-4 as balanced (now preferred)
},
),
-# Both Gemini and OpenAI available - should prefer based on tool category
+# Both Gemini and OpenAI available - Google comes first in priority
(
{
"GEMINI_API_KEY": "real-key",
@@ -122,12 +122,12 @@ class TestAutoModeComprehensive:
"OPENROUTER_API_KEY": None,
},
{
"EXTENDED_REASONING": "o3", # Prefer O3 for deep reasoning
"FAST_RESPONSE": "o4-mini", # Prefer O4-mini for speed
"BALANCED": "o4-mini", # Prefer OpenAI for balanced
"EXTENDED_REASONING": "gemini-2.5-pro", # Gemini comes first in priority
"FAST_RESPONSE": "gemini-2.5-flash", # Prefer flash for speed
"BALANCED": "gemini-2.5-flash", # Prefer flash for balanced
},
),
# All native APIs available - should prefer based on tool category
# All native APIs available - Google still comes first
(
{
"GEMINI_API_KEY": "real-key",
@@ -136,9 +136,9 @@ class TestAutoModeComprehensive:
"OPENROUTER_API_KEY": None,
},
{
"EXTENDED_REASONING": "o3", # Prefer O3 for deep reasoning
"FAST_RESPONSE": "o4-mini", # Prefer O4-mini for speed
"BALANCED": "o4-mini", # Prefer OpenAI for balanced
"EXTENDED_REASONING": "gemini-2.5-pro", # Gemini comes first in priority
"FAST_RESPONSE": "gemini-2.5-flash", # Prefer flash for speed
"BALANCED": "gemini-2.5-flash", # Prefer flash for balanced
},
),
],


@@ -97,10 +97,10 @@ class TestAutoModeProviderSelection:
fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
balanced = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED)
-# Should select appropriate OpenAI models
-assert extended_reasoning in ["o3", "o3-mini", "o4-mini"] # Any available OpenAI model for reasoning
-assert fast_response in ["o4-mini", "o3-mini"] # Prefer faster models
-assert balanced in ["o4-mini", "o3-mini"] # Balanced selection
+# Should select appropriate OpenAI models based on new preference order
+assert extended_reasoning == "o3" # O3 for extended reasoning
+assert fast_response == "gpt-5" # gpt-5 comes first in fast response preference
+assert balanced == "gpt-5" # gpt-5 for balanced
finally:
# Restore original environment
@@ -138,11 +138,11 @@ class TestAutoModeProviderSelection:
)
fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
-# Should prefer OpenAI for reasoning (based on fallback logic)
-assert extended_reasoning == "o3" # Should prefer O3 for extended reasoning
+# Should prefer Gemini now (based on new provider priority: Gemini before OpenAI)
+assert extended_reasoning == "gemini-2.5-pro" # Gemini has higher priority now
-# Should prefer OpenAI for fast response
-assert fast_response == "o4-mini" # Should prefer O4-mini for fast response
+# Should prefer Gemini for fast response
+assert fast_response == "gemini-2.5-flash" # Gemini has higher priority now
finally:
# Restore original environment
@@ -318,9 +318,9 @@ class TestAutoModeProviderSelection:
test_cases = [
("flash", ProviderType.GOOGLE, "gemini-2.5-flash"),
("pro", ProviderType.GOOGLE, "gemini-2.5-pro"),
("mini", ProviderType.OPENAI, "o4-mini"),
("mini", ProviderType.OPENAI, "gpt-5-mini"), # "mini" now resolves to gpt-5-mini
("o3mini", ProviderType.OPENAI, "o3-mini"),
("grok", ProviderType.XAI, "grok-3"),
("grok", ProviderType.XAI, "grok-4"),
("grokfast", ProviderType.XAI, "grok-3-fast"),
]


@@ -132,8 +132,11 @@ class TestBuggyBehaviorPrevention:
assert not provider.validate_model_name("o3-pro") # Not in allowed list
assert not provider.validate_model_name("o3") # Not in allowed list
-# This should be ALLOWED because it resolves to o4-mini which is in the allowed list
-assert provider.validate_model_name("mini") # Resolves to o4-mini, which IS allowed
+# "mini" now resolves to gpt-5-mini, not o4-mini, so it should be blocked
+assert not provider.validate_model_name("mini") # Resolves to gpt-5-mini, which is NOT allowed
+# But o4mini (the actual alias for o4-mini) should work
+assert provider.validate_model_name("o4mini") # Resolves to o4-mini, which IS allowed
# Verify our list_all_known_models includes the restricted models
all_known = provider.list_all_known_models()


@@ -93,7 +93,7 @@ class TestChallengeTool:
response_data = json.loads(result[0].text)
# Check response structure
assert response_data["status"] == "challenge_created"
assert response_data["status"] == "challenge_accepted"
assert response_data["original_statement"] == "All software bugs are caused by syntax errors"
assert "challenge_prompt" in response_data
assert "instructions" in response_data


@@ -113,7 +113,7 @@ class TestDIALProvider:
# Test temperature constraint
assert capabilities.temperature_constraint.min_temp == 0.0
assert capabilities.temperature_constraint.max_temp == 2.0
-assert capabilities.temperature_constraint.default_temp == 0.7
+assert capabilities.temperature_constraint.default_temp == 0.3
@patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
@patch("utils.model_restrictions._restriction_service", None)


@@ -37,14 +37,14 @@ class TestIntelligentFallback:
@patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False)
def test_prefers_openai_o3_mini_when_available(self):
"""Test that o4-mini is preferred when OpenAI API key is available"""
"""Test that gpt-5 is preferred when OpenAI API key is available (based on new preference order)"""
# Register only OpenAI provider for this test
from providers.openai_provider import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
assert fallback_model == "o4-mini"
assert fallback_model == "gpt-5" # Based on new preference order: gpt-5 before o4-mini
@patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-gemini-key"}, clear=False)
def test_prefers_gemini_flash_when_openai_unavailable(self):
@@ -68,7 +68,7 @@ class TestIntelligentFallback:
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
assert fallback_model == "o4-mini" # OpenAI has priority
assert fallback_model == "gemini-2.5-flash" # Gemini has priority now (based on new PROVIDER_PRIORITY_ORDER)
@patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": ""}, clear=False)
def test_fallback_when_no_keys_available(self):
@@ -147,8 +147,8 @@ class TestIntelligentFallback:
history, tokens = build_conversation_history(context, model_context=None)
-# Verify that ModelContext was called with o4-mini (the intelligent fallback)
-mock_context_class.assert_called_once_with("o4-mini")
+# Verify that ModelContext was called with gpt-5 (the intelligent fallback based on new preference order)
+mock_context_class.assert_called_once_with("gpt-5")
def test_auto_mode_with_gemini_only(self):
"""Test auto mode behavior when only Gemini API key is available"""


@@ -635,6 +635,13 @@ class TestAutoModeWithRestrictions:
mock_openai.list_models = openai_list_models
mock_openai.list_all_known_models.return_value = ["o3", "o3-mini", "o4-mini"]
+# Add get_preferred_model method to mock to match new implementation
+def get_preferred_model(category, allowed_models):
+# Simple preference logic for testing - just return first allowed model
+return allowed_models[0] if allowed_models else None
+mock_openai.get_preferred_model = get_preferred_model
def get_provider_side_effect(provider_type):
if provider_type == ProviderType.OPENAI:
return mock_openai
@@ -656,9 +663,13 @@ class TestAutoModeWithRestrictions:
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
assert model == "o4-mini"
-@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini", "GEMINI_API_KEY": "", "OPENAI_API_KEY": "test-key"})
-def test_fallback_with_shorthand_restrictions(self):
+def test_fallback_with_shorthand_restrictions(self, monkeypatch):
"""Test fallback model selection with shorthand restrictions."""
+# Use monkeypatch to set environment variables with automatic cleanup
+monkeypatch.setenv("OPENAI_ALLOWED_MODELS", "mini")
+monkeypatch.setenv("GEMINI_API_KEY", "")
+monkeypatch.setenv("OPENAI_API_KEY", "test-key")
# Clear caches and reset registry
import utils.model_restrictions
from providers.registry import ModelProviderRegistry
@@ -685,8 +696,9 @@ class TestAutoModeWithRestrictions:
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
-# The fallback will depend on how get_available_models handles aliases
-# For now, we accept either behavior and document it
-assert model in ["o4-mini", "gemini-2.5-flash"]
+# When "mini" is allowed, it's returned as the allowed model
+# "mini" is now an alias for gpt-5-mini, but the list shows "mini" itself
+assert model in ["mini", "gpt-5-mini", "o4-mini", "gemini-2.5-flash"]
finally:
# Restore original registry state
registry = ModelProviderRegistry()


@@ -0,0 +1,124 @@
"""
Tests for o3-pro output_text parsing fix using HTTP transport recording.
This test validates the fix that uses `response.output_text` convenience field
instead of manually parsing `response.output.content[].text`.
Uses HTTP transport recorder to record real o3-pro API responses at the HTTP level while allowing
the OpenAI SDK to create real response objects that we can test.
RECORDING: To record new responses, delete the cassette file and run with real API keys.
"""
import logging
import os
from pathlib import Path
from unittest.mock import patch
import pytest
from dotenv import load_dotenv
from providers import ModelProviderRegistry
from tests.transport_helpers import inject_transport
from tools.chat import ChatTool
logger = logging.getLogger(__name__)
# Load environment variables from .env file
load_dotenv()
# Use absolute path for cassette directory
cassette_dir = Path(__file__).parent / "openai_cassettes"
cassette_dir.mkdir(exist_ok=True)
@pytest.mark.asyncio
class TestO3ProOutputTextFix:
"""Test o3-pro response parsing fix using respx for HTTP recording/replay."""
def setup_method(self):
"""Set up the test by ensuring clean registry state."""
# Use the new public API for registry cleanup
ModelProviderRegistry.reset_for_testing()
# Provider registration is now handled by inject_transport helper
# Clear restriction service to ensure it re-reads environment
# This is necessary because previous tests may have set restrictions
# that are cached in the singleton
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
def teardown_method(self):
"""Clean up after test to ensure no state pollution."""
# Use the new public API for registry cleanup
ModelProviderRegistry.reset_for_testing()
@pytest.mark.no_mock_provider # Disable provider mocking for this test
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-pro", "LOCALE": ""})
async def test_o3_pro_uses_output_text_field(self, monkeypatch):
"""Test that o3-pro parsing uses the output_text convenience field via ChatTool."""
cassette_path = cassette_dir / "o3_pro_basic_math.json"
# Check if we need to record or replay
if not cassette_path.exists():
# Recording mode - check for real API key
real_api_key = os.getenv("OPENAI_API_KEY", "").strip()
if not real_api_key or real_api_key.startswith("dummy"):
pytest.fail(
f"Cassette file not found at {cassette_path}. "
"To record: Set OPENAI_API_KEY environment variable to a valid key and run this test. "
"Note: Recording will make a real API call to OpenAI."
)
# Real API key is available, we'll record the cassette
logger.debug("🎬 Recording mode: Using real API key to record cassette")
else:
# Replay mode - use dummy key
monkeypatch.setenv("OPENAI_API_KEY", "dummy-key-for-replay")
logger.debug("📼 Replay mode: Using recorded cassette")
# Simplified transport injection - just one line!
inject_transport(monkeypatch, cassette_path)
# Execute ChatTool test with custom transport
result = await self._execute_chat_tool_test()
# Verify the response works correctly
self._verify_chat_tool_response(result)
# Verify cassette exists
assert cassette_path.exists()
async def _execute_chat_tool_test(self):
"""Execute the ChatTool with o3-pro and return the result."""
chat_tool = ChatTool()
arguments = {"prompt": "What is 2 + 2?", "model": "o3-pro", "temperature": 1.0}
return await chat_tool.execute(arguments)
def _verify_chat_tool_response(self, result):
"""Verify the ChatTool response contains expected data."""
# Basic response validation
assert result is not None
assert isinstance(result, list)
assert len(result) > 0
assert result[0].type == "text"
# Parse JSON response
import json
response_data = json.loads(result[0].text)
# Debug log the response
logger.debug(f"Response data: {json.dumps(response_data, indent=2)}")
# Verify response structure - no cargo culting
if response_data["status"] == "error":
pytest.fail(f"Chat tool returned error: {response_data.get('error', 'Unknown error')}")
assert response_data["status"] in ["success", "continuation_available"]
assert "4" in response_data["content"]
# Verify o3-pro was actually used
metadata = response_data["metadata"]
assert metadata["model_used"] == "o3-pro"
assert metadata["provider_used"] == "openai"


@@ -230,7 +230,7 @@ class TestO3TemperatureParameterFixSimple:
assert temp_constraint.validate(0.5) is False
# Test regular model constraints - use gpt-4.1 which is supported
gpt41_capabilities = provider.get_capabilities("gpt-4.1-2025-04-14")
gpt41_capabilities = provider.get_capabilities("gpt-4.1")
assert gpt41_capabilities.temperature_constraint is not None
# Regular models should allow a range


@@ -48,12 +48,17 @@ class TestOpenAIProvider:
assert provider.validate_model_name("o3-pro") is True
assert provider.validate_model_name("o4-mini") is True
assert provider.validate_model_name("o4-mini") is True
assert provider.validate_model_name("gpt-5") is True
assert provider.validate_model_name("gpt-5-mini") is True
# Test valid aliases
assert provider.validate_model_name("mini") is True
assert provider.validate_model_name("o3mini") is True
assert provider.validate_model_name("o4mini") is True
assert provider.validate_model_name("o4mini") is True
assert provider.validate_model_name("gpt5") is True
assert provider.validate_model_name("gpt5-mini") is True
assert provider.validate_model_name("gpt5mini") is True
# Test invalid model
assert provider.validate_model_name("invalid-model") is False
@@ -65,17 +70,22 @@ class TestOpenAIProvider:
provider = OpenAIModelProvider("test-key")
# Test shorthand resolution
assert provider._resolve_model_name("mini") == "o4-mini"
assert provider._resolve_model_name("mini") == "gpt-5-mini" # "mini" now resolves to gpt-5-mini
assert provider._resolve_model_name("o3mini") == "o3-mini"
assert provider._resolve_model_name("o4mini") == "o4-mini"
assert provider._resolve_model_name("o4mini") == "o4-mini"
assert provider._resolve_model_name("gpt5") == "gpt-5"
assert provider._resolve_model_name("gpt5-mini") == "gpt-5-mini"
assert provider._resolve_model_name("gpt5mini") == "gpt-5-mini"
# Test full name passthrough
assert provider._resolve_model_name("o3") == "o3"
assert provider._resolve_model_name("o3-mini") == "o3-mini"
assert provider._resolve_model_name("o3-pro") == "o3-pro-2025-06-10"
assert provider._resolve_model_name("o3-pro") == "o3-pro"
assert provider._resolve_model_name("o4-mini") == "o4-mini"
assert provider._resolve_model_name("o4-mini") == "o4-mini"
assert provider._resolve_model_name("gpt-5") == "gpt-5"
assert provider._resolve_model_name("gpt-5-mini") == "gpt-5-mini"
def test_get_capabilities_o3(self):
"""Test getting model capabilities for O3."""
@@ -99,11 +109,43 @@ class TestOpenAIProvider:
provider = OpenAIModelProvider("test-key")
capabilities = provider.get_capabilities("mini")
assert capabilities.model_name == "o4-mini" # Capabilities should show resolved model name
assert capabilities.friendly_name == "OpenAI (O4-mini)"
assert capabilities.context_window == 200_000
assert capabilities.model_name == "gpt-5-mini" # "mini" now resolves to gpt-5-mini
assert capabilities.friendly_name == "OpenAI (GPT-5-mini)"
assert capabilities.context_window == 400_000
assert capabilities.provider == ProviderType.OPENAI
def test_get_capabilities_gpt5(self):
"""Test getting model capabilities for GPT-5."""
provider = OpenAIModelProvider("test-key")
capabilities = provider.get_capabilities("gpt-5")
assert capabilities.model_name == "gpt-5"
assert capabilities.friendly_name == "OpenAI (GPT-5)"
assert capabilities.context_window == 400_000
assert capabilities.max_output_tokens == 128_000
assert capabilities.provider == ProviderType.OPENAI
assert capabilities.supports_extended_thinking is True
assert capabilities.supports_system_prompts is True
assert capabilities.supports_streaming is True
assert capabilities.supports_function_calling is True
assert capabilities.supports_temperature is True
def test_get_capabilities_gpt5_mini(self):
"""Test getting model capabilities for GPT-5-mini."""
provider = OpenAIModelProvider("test-key")
capabilities = provider.get_capabilities("gpt-5-mini")
assert capabilities.model_name == "gpt-5-mini"
assert capabilities.friendly_name == "OpenAI (GPT-5-mini)"
assert capabilities.context_window == 400_000
assert capabilities.max_output_tokens == 128_000
assert capabilities.provider == ProviderType.OPENAI
assert capabilities.supports_extended_thinking is True
assert capabilities.supports_system_prompts is True
assert capabilities.supports_streaming is True
assert capabilities.supports_function_calling is True
assert capabilities.supports_temperature is True
@patch("providers.openai_compatible.OpenAI")
def test_generate_content_resolves_alias_before_api_call(self, mock_openai_class):
"""Test that generate_content resolves aliases before making API calls.
@@ -132,21 +174,19 @@ class TestOpenAIProvider:
provider = OpenAIModelProvider("test-key")
-# Call generate_content with alias 'gpt4.1' (resolves to gpt-4.1-2025-04-14, supports temperature)
+# Call generate_content with alias 'gpt4.1' (resolves to gpt-4.1, supports temperature)
result = provider.generate_content(
prompt="Test prompt",
model_name="gpt4.1",
temperature=1.0, # This should be resolved to "gpt-4.1-2025-04-14"
temperature=1.0, # This should be resolved to "gpt-4.1"
)
# Verify the API was called with the RESOLVED model name
mock_client.chat.completions.create.assert_called_once()
call_kwargs = mock_client.chat.completions.create.call_args[1]
# CRITICAL ASSERTION: The API should receive "gpt-4.1-2025-04-14", not "gpt4.1"
assert (
call_kwargs["model"] == "gpt-4.1-2025-04-14"
), f"Expected 'gpt-4.1-2025-04-14' but API received '{call_kwargs['model']}'"
# CRITICAL ASSERTION: The API should receive "gpt-4.1", not "gpt4.1"
assert call_kwargs["model"] == "gpt-4.1", f"Expected 'gpt-4.1' but API received '{call_kwargs['model']}'"
# Verify other parameters (gpt-4.1 supports temperature unlike O3/O4 models)
assert call_kwargs["temperature"] == 1.0
@@ -156,7 +196,7 @@ class TestOpenAIProvider:
# Verify response
assert result.content == "Test response"
assert result.model_name == "gpt-4.1-2025-04-14" # Should be the resolved name
assert result.model_name == "gpt-4.1" # Should be the resolved name
@patch("providers.openai_compatible.OpenAI")
def test_generate_content_other_aliases(self, mock_openai_class):
@@ -213,14 +253,22 @@ class TestOpenAIProvider:
assert call_kwargs["model"] == "o3-mini" # Should be unchanged
def test_supports_thinking_mode(self):
"""Test thinking mode support (currently False for all OpenAI models)."""
"""Test thinking mode support."""
provider = OpenAIModelProvider("test-key")
-# All OpenAI models currently don't support thinking mode
+# GPT-5 models support thinking mode (reasoning tokens)
+assert provider.supports_thinking_mode("gpt-5") is True
+assert provider.supports_thinking_mode("gpt-5-mini") is True
+assert provider.supports_thinking_mode("gpt5") is True # Test with alias
+assert provider.supports_thinking_mode("gpt5mini") is True # Test with alias
+# O3/O4 models don't support thinking mode
assert provider.supports_thinking_mode("o3") is False
assert provider.supports_thinking_mode("o3-mini") is False
assert provider.supports_thinking_mode("o4-mini") is False
assert provider.supports_thinking_mode("mini") is False # Test with alias too
assert (
provider.supports_thinking_mode("mini") is True
) # "mini" now resolves to gpt-5-mini which supports thinking
@patch("providers.openai_compatible.OpenAI")
def test_o3_pro_routes_to_responses_endpoint(self, mock_openai_class):
@@ -230,11 +278,9 @@ class TestOpenAIProvider:
mock_openai_class.return_value = mock_client
mock_response = MagicMock()
-mock_response.output = MagicMock()
-mock_response.output.content = [MagicMock()]
-mock_response.output.content[0].type = "output_text"
-mock_response.output.content[0].text = "4"
-mock_response.model = "o3-pro-2025-06-10"
+# New o3-pro format: direct output_text field
+mock_response.output_text = "4"
+mock_response.model = "o3-pro"
mock_response.id = "test-id"
mock_response.created_at = 1234567890
mock_response.usage = MagicMock()
@@ -252,13 +298,13 @@ class TestOpenAIProvider:
# Verify responses.create was called
mock_client.responses.create.assert_called_once()
call_args = mock_client.responses.create.call_args[1]
assert call_args["model"] == "o3-pro-2025-06-10"
assert call_args["model"] == "o3-pro"
assert call_args["input"][0]["role"] == "user"
assert "What is 2 + 2?" in call_args["input"][0]["content"][0]["text"]
# Verify the response
assert result.content == "4"
assert result.model_name == "o3-pro-2025-06-10"
assert result.model_name == "o3-pro"
assert result.metadata["endpoint"] == "responses"
@patch("providers.openai_compatible.OpenAI")


@@ -3,6 +3,7 @@ Test per-tool model default selection functionality
"""
import json
+import os
from unittest.mock import MagicMock, patch
import pytest
@@ -73,154 +74,194 @@ class TestToolModelCategories:
class TestModelSelection:
"""Test model selection based on tool categories."""
+def teardown_method(self):
+"""Clean up after each test to prevent state pollution."""
+ModelProviderRegistry.clear_cache()
+# Unregister all providers
+for provider_type in list(ProviderType):
+ModelProviderRegistry.unregister_provider(provider_type)
def test_extended_reasoning_with_openai(self):
"""Test EXTENDED_REASONING prefers o3 when OpenAI is available."""
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available:
# Mock OpenAI models available
mock_get_available.return_value = {
"o3": ProviderType.OPENAI,
"o3-mini": ProviderType.OPENAI,
"o4-mini": ProviderType.OPENAI,
}
"""Test EXTENDED_REASONING with OpenAI provider."""
# Setup with only OpenAI provider
ModelProviderRegistry.clear_cache()
# First unregister all providers to ensure isolation
for provider_type in list(ProviderType):
ModelProviderRegistry.unregister_provider(provider_type)
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}, clear=False):
from providers.openai_provider import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
# OpenAI prefers o3 for extended reasoning
assert model == "o3"
def test_extended_reasoning_with_gemini_only(self):
"""Test EXTENDED_REASONING prefers pro when only Gemini is available."""
-with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available:
-# Mock only Gemini models available
-mock_get_available.return_value = {
-"gemini-2.5-pro": ProviderType.GOOGLE,
-"gemini-2.5-flash": ProviderType.GOOGLE,
-}
+# Clear cache and unregister all providers first
+ModelProviderRegistry.clear_cache()
+for provider_type in list(ProviderType):
+ModelProviderRegistry.unregister_provider(provider_type)
+# Register only Gemini provider
+with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}, clear=False):
+from providers.gemini import GeminiModelProvider
+ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
-# Should find the pro model for extended reasoning
-assert "pro" in model or model == "gemini-2.5-pro"
+# Gemini should return one of its models for extended reasoning
+# The default behavior may return flash when pro is not explicitly preferred
+assert model in ["gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.0-flash"]
def test_fast_response_with_openai(self):
"""Test FAST_RESPONSE prefers o4-mini when OpenAI is available."""
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available:
# Mock OpenAI models available
mock_get_available.return_value = {
"o3": ProviderType.OPENAI,
"o3-mini": ProviderType.OPENAI,
"o4-mini": ProviderType.OPENAI,
}
"""Test FAST_RESPONSE with OpenAI provider."""
# Setup with only OpenAI provider
ModelProviderRegistry.clear_cache()
# First unregister all providers to ensure isolation
for provider_type in list(ProviderType):
ModelProviderRegistry.unregister_provider(provider_type)
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}, clear=False):
from providers.openai_provider import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
assert model == "o4-mini"
# OpenAI now prefers gpt-5 for fast response (based on our new preference order)
assert model == "gpt-5"
def test_fast_response_with_gemini_only(self):
"""Test FAST_RESPONSE prefers flash when only Gemini is available."""
-with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available:
-# Mock only Gemini models available
-mock_get_available.return_value = {
-"gemini-2.5-pro": ProviderType.GOOGLE,
-"gemini-2.5-flash": ProviderType.GOOGLE,
-}
+# Clear cache and unregister all providers first
+ModelProviderRegistry.clear_cache()
+for provider_type in list(ProviderType):
+ModelProviderRegistry.unregister_provider(provider_type)
+# Register only Gemini provider
+with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}, clear=False):
+from providers.gemini import GeminiModelProvider
+ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
-# Should find the flash model for fast response
-assert "flash" in model or model == "gemini-2.5-flash"
+# Gemini should return one of its models for fast response
+assert model in ["gemini-2.5-flash", "gemini-2.0-flash", "gemini-2.5-pro"]
def test_balanced_category_fallback(self):
"""Test BALANCED category uses existing logic."""
-with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available:
-# Mock OpenAI models available
-mock_get_available.return_value = {
-"o3": ProviderType.OPENAI,
-"o3-mini": ProviderType.OPENAI,
-"o4-mini": ProviderType.OPENAI,
-}
+# Setup with only OpenAI provider
+ModelProviderRegistry.clear_cache()
+# First unregister all providers to ensure isolation
+for provider_type in list(ProviderType):
+ModelProviderRegistry.unregister_provider(provider_type)
+with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}, clear=False):
+from providers.openai_provider import OpenAIModelProvider
+ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED)
-assert model == "o4-mini" # Balanced prefers o4-mini when OpenAI available
+# OpenAI prefers gpt-5 for balanced (based on our new preference order)
+assert model == "gpt-5"
def test_no_category_uses_balanced_logic(self):
"""Test that no category specified uses balanced logic."""
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available:
# Mock only Gemini models available
mock_get_available.return_value = {
"gemini-2.5-pro": ProviderType.GOOGLE,
"gemini-2.5-flash": ProviderType.GOOGLE,
}
# Setup with only Gemini provider
with patch.dict(os.environ, {"GEMINI_API_KEY": "test-key"}, clear=False):
from providers.gemini import GeminiModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
model = ModelProviderRegistry.get_preferred_fallback_model()
# Should pick flash for balanced use
assert model == "gemini-2.5-flash"
class TestFlexibleModelSelection:
"""Test that model selection handles various naming scenarios."""
def test_fallback_handles_mixed_model_names(self):
"""Test that fallback selection works with mix of full names and shorthands."""
# Test with mix of full names and shorthands
"""Test that fallback selection works with different providers."""
# Test with different provider configurations
test_cases = [
# Case 1: OpenAI provider for extended reasoning
{
"available": {"o3": ProviderType.OPENAI, "o4-mini": ProviderType.OPENAI},
"env": {"OPENAI_API_KEY": "test-key"},
"provider_type": ProviderType.OPENAI,
"category": ToolModelCategory.EXTENDED_REASONING,
"expected": "o3",
},
# Case 2: Gemini provider for fast response
{
"available": {
"gemini-2.5-flash": ProviderType.GOOGLE,
"gemini-2.5-pro": ProviderType.GOOGLE,
},
"env": {"GEMINI_API_KEY": "test-key"},
"provider_type": ProviderType.GOOGLE,
"category": ToolModelCategory.FAST_RESPONSE,
"expected_contains": "flash",
"expected": "gemini-2.5-flash",
},
# Case 3: OpenAI provider for fast response
{
"available": {"o4-mini": ProviderType.OPENAI, "o3-mini": ProviderType.OPENAI},
"env": {"OPENAI_API_KEY": "test-key"},
"provider_type": ProviderType.OPENAI,
"category": ToolModelCategory.FAST_RESPONSE,
"expected": "o4-mini",
"expected": "gpt-5", # Based on new preference order
},
]
for case in test_cases:
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available:
mock_get_available.return_value = case["available"]
# Clear registry for clean test
ModelProviderRegistry.clear_cache()
# First unregister all providers to ensure isolation
for provider_type in list(ProviderType):
ModelProviderRegistry.unregister_provider(provider_type)
with patch.dict(os.environ, case["env"], clear=False):
# Register the appropriate provider
if case["provider_type"] == ProviderType.OPENAI:
from providers.openai_provider import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
elif case["provider_type"] == ProviderType.GOOGLE:
from providers.gemini import GeminiModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
model = ModelProviderRegistry.get_preferred_fallback_model(case["category"])
if "expected" in case:
assert model == case["expected"], f"Failed for case: {case}"
elif "expected_contains" in case:
assert (
case["expected_contains"] in model
), f"Expected '{case['expected_contains']}' in '{model}' for case: {case}"
assert model == case["expected"], f"Failed for case: {case}, got {model}"
class TestCustomProviderFallback:
"""Test fallback to custom/openrouter providers."""
def test_extended_reasoning_custom_fallback(self):
"""Test EXTENDED_REASONING with custom provider."""
# Setup with custom provider
ModelProviderRegistry.clear_cache()
with patch.dict(os.environ, {"CUSTOM_API_URL": "http://localhost:11434", "CUSTOM_API_KEY": ""}, clear=False):
from providers.custom import CustomProvider
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider)
provider = ModelProviderRegistry.get_provider(ProviderType.CUSTOM)
if provider:
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
# Should get a model from custom provider
assert model is not None
def test_extended_reasoning_final_fallback(self):
"""Test EXTENDED_REASONING falls back to default when no providers."""
# Clear all providers
ModelProviderRegistry.clear_cache()
for provider_type in list(
ModelProviderRegistry._instance._providers.keys() if ModelProviderRegistry._instance else []
):
ModelProviderRegistry.unregister_provider(provider_type)
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
# Should fall back to hardcoded default
assert model == "gemini-2.5-flash"
class TestAutoModeErrorMessages:
@@ -266,42 +307,45 @@ class TestAutoModeErrorMessages:
class TestProviderHelperMethods:
"""Test the helper methods for finding models from custom/openrouter."""
def test_extended_reasoning_with_custom_provider(self):
"""Test extended reasoning model selection with custom provider."""
# Setup with custom provider
with patch.dict(os.environ, {"CUSTOM_API_URL": "http://localhost:11434", "CUSTOM_API_KEY": ""}, clear=False):
from providers.custom import CustomProvider
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider)
provider = ModelProviderRegistry.get_provider(ProviderType.CUSTOM)
if provider:
# Custom provider should return a model for extended reasoning
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
assert model is not None
def test_extended_reasoning_with_openrouter(self):
"""Test extended reasoning model selection with OpenRouter."""
# Setup with OpenRouter provider
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}, clear=False):
from providers.openrouter import OpenRouterProvider
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
# OpenRouter should provide a model for extended reasoning
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
# Should return first available OpenRouter model
assert model is not None
def test_fallback_when_no_providers_available(self):
"""Test fallback when no providers are available."""
# Clear all providers
ModelProviderRegistry.clear_cache()
for provider_type in list(
ModelProviderRegistry._instance._providers.keys() if ModelProviderRegistry._instance else []
):
ModelProviderRegistry.unregister_provider(provider_type)
# Should return hardcoded fallback
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
assert model == "gemini-2.5-flash"
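The clear-cache/unregister boilerplate recurs in nearly every test above. A fixture along these lines (hypothetical, not part of this diff) could centralize the isolation:

import pytest

from providers.base import ProviderType
from providers.registry import ModelProviderRegistry

@pytest.fixture
def clean_registry():
    """Hypothetical fixture: start each test with an empty provider registry."""
    ModelProviderRegistry.clear_cache()
    for provider_type in list(ProviderType):
        ModelProviderRegistry.unregister_provider(provider_type)
    yield ModelProviderRegistry
    # Clean up again so later tests see no leftover providers
    ModelProviderRegistry.clear_cache()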
class TestEffectiveAutoMode:

tests/test_pii_sanitizer.py (new file)
View File

@@ -0,0 +1,143 @@
#!/usr/bin/env python3
"""Test cases for PII sanitizer."""
import unittest
from .pii_sanitizer import PIIPattern, PIISanitizer
class TestPIISanitizer(unittest.TestCase):
"""Test PII sanitization functionality."""
def setUp(self):
"""Set up test sanitizer."""
self.sanitizer = PIISanitizer()
def test_api_key_sanitization(self):
"""Test various API key formats are sanitized."""
test_cases = [
# OpenAI keys
("sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12", "sk-proj-SANITIZED"),
("sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN", "sk-SANITIZED"),
# Anthropic keys
("sk-ant-abcd1234567890ABCD1234567890abcd1234567890ABCD12", "sk-ant-SANITIZED"),
# Google keys
("AIzaSyD-1234567890abcdefghijklmnopqrstuv", "AIza-SANITIZED"),
# GitHub tokens
("ghp_1234567890abcdefghijklmnopqrstuvwxyz", "gh_SANITIZED"),
("ghs_1234567890abcdefghijklmnopqrstuvwxyz", "gh_SANITIZED"),
]
for original, expected in test_cases:
with self.subTest(original=original):
result = self.sanitizer.sanitize_string(original)
self.assertEqual(result, expected)
def test_personal_info_sanitization(self):
"""Test personal information is sanitized."""
test_cases = [
# Email addresses
("john.doe@example.com", "user@example.com"),
("test123@company.org", "user@example.com"),
# Phone numbers (all now use the same pattern)
("(555) 123-4567", "(XXX) XXX-XXXX"),
("555-123-4567", "(XXX) XXX-XXXX"),
("+1-555-123-4567", "(XXX) XXX-XXXX"),
# SSN
("123-45-6789", "XXX-XX-XXXX"),
# Credit card
("1234 5678 9012 3456", "XXXX-XXXX-XXXX-XXXX"),
("1234-5678-9012-3456", "XXXX-XXXX-XXXX-XXXX"),
]
for original, expected in test_cases:
with self.subTest(original=original):
result = self.sanitizer.sanitize_string(original)
self.assertEqual(result, expected)
def test_header_sanitization(self):
"""Test HTTP header sanitization."""
headers = {
"Authorization": "Bearer sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12",
"API-Key": "sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN",
"Content-Type": "application/json",
"User-Agent": "MyApp/1.0",
"Cookie": "session=abc123; user=john.doe@example.com",
}
sanitized = self.sanitizer.sanitize_headers(headers)
self.assertEqual(sanitized["Authorization"], "Bearer SANITIZED")
self.assertEqual(sanitized["API-Key"], "sk-SANITIZED")
self.assertEqual(sanitized["Content-Type"], "application/json")
self.assertEqual(sanitized["User-Agent"], "MyApp/1.0")
self.assertIn("user@example.com", sanitized["Cookie"])
def test_nested_structure_sanitization(self):
"""Test sanitization of nested data structures."""
data = {
"user": {
"email": "john.doe@example.com",
"api_key": "sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12",
},
"tokens": [
"ghp_1234567890abcdefghijklmnopqrstuvwxyz",
"Bearer sk-ant-abcd1234567890ABCD1234567890abcd1234567890ABCD12",
],
"metadata": {"ip": "192.168.1.100", "phone": "(555) 123-4567"},
}
sanitized = self.sanitizer.sanitize_value(data)
self.assertEqual(sanitized["user"]["email"], "user@example.com")
self.assertEqual(sanitized["user"]["api_key"], "sk-proj-SANITIZED")
self.assertEqual(sanitized["tokens"][0], "gh_SANITIZED")
self.assertEqual(sanitized["tokens"][1], "Bearer sk-ant-SANITIZED")
self.assertEqual(sanitized["metadata"]["ip"], "0.0.0.0")
self.assertEqual(sanitized["metadata"]["phone"], "(XXX) XXX-XXXX")
def test_url_sanitization(self):
"""Test URL parameter sanitization."""
urls = [
(
"https://api.example.com/v1/users?api_key=sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN",
"https://api.example.com/v1/users?api_key=SANITIZED",
),
(
"https://example.com/login?token=ghp_1234567890abcdefghijklmnopqrstuvwxyz&user=test",
"https://example.com/login?token=SANITIZED&user=test",
),
]
for original, expected in urls:
with self.subTest(url=original):
result = self.sanitizer.sanitize_url(original)
self.assertEqual(result, expected)
def test_disable_sanitization(self):
"""Test that sanitization can be disabled."""
self.sanitizer.sanitize_enabled = False
sensitive_data = "sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12"
result = self.sanitizer.sanitize_string(sensitive_data)
# Should return original when disabled
self.assertEqual(result, sensitive_data)
def test_custom_pattern(self):
"""Test adding custom PII patterns."""
# Add custom pattern for internal employee IDs
custom_pattern = PIIPattern.create(
name="employee_id", pattern=r"EMP\d{6}", replacement="EMP-REDACTED", description="Internal employee IDs"
)
self.sanitizer.add_pattern(custom_pattern)
text = "Employee EMP123456 has access to the system"
result = self.sanitizer.sanitize_string(text)
self.assertEqual(result, "Employee EMP-REDACTED has access to the system")
if __name__ == "__main__":
unittest.main()
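For orientation, the surface these tests exercise looks roughly like the sketch below. It is a reconstruction from the assertions above, assuming PIIPattern is a small dataclass wrapping a compiled regex; the real tests/pii_sanitizer.py may differ.

import re
from dataclasses import dataclass

@dataclass
class PIIPattern:
    # Assumed field names; only the create() signature is confirmed by the tests
    name: str
    regex: "re.Pattern"
    replacement: str
    description: str = ""

    @classmethod
    def create(cls, name, pattern, replacement, description=""):
        return cls(name, re.compile(pattern), replacement, description)

class PIISanitizer:
    def __init__(self, patterns=None):
        self.patterns = list(patterns or [])
        self.sanitize_enabled = True

    def add_pattern(self, pattern):
        self.patterns.append(pattern)

    def sanitize_string(self, text):
        if not self.sanitize_enabled:
            return text
        # Apply every pattern in order; later patterns see earlier replacements
        for p in self.patterns:
            text = p.regex.sub(p.replacement, text)
        return text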

View File

@@ -126,7 +126,7 @@ class TestProviderUTF8Encoding(unittest.TestCase):
mock_response.usage = Mock()
mock_response.usage.input_tokens = 50
mock_response.usage.output_tokens = 25
mock_response.model = "o3-pro-2025-06-10"
mock_response.model = "o3-pro"
mock_response.id = "test-id"
mock_response.created_at = 1234567890
@@ -141,7 +141,7 @@ class TestProviderUTF8Encoding(unittest.TestCase):
with patch("logging.info") as mock_logging:
response = provider.generate_content(
prompt="Analyze this Python code for issues",
model_name="o3-pro-2025-06-10",
model_name="o3-pro",
system_prompt="You are a code review expert.",
)
@@ -351,7 +351,7 @@ class TestLocaleModelIntegration(unittest.TestCase):
def test_model_name_resolution_utf8(self):
"""Test model name resolution with UTF-8."""
provider = OpenAIModelProvider(api_key="test")
model_names = ["gpt-4", "gemini-2.5-flash", "claude-3-opus", "o3-pro-2025-06-10"]
model_names = ["gpt-4", "gemini-2.5-flash", "claude-3-opus", "o3-pro"]
for model_name in model_names:
resolved = provider._resolve_model_name(model_name)
self.assertIsInstance(resolved, str)

View File

@@ -47,22 +47,23 @@ class TestSupportedModelsAliases:
assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
# Test specific aliases
assert "mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
# "mini" is now an alias for gpt-5-mini, not o4-mini
assert "mini" in provider.SUPPORTED_MODELS["gpt-5-mini"].aliases
assert "o4mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
assert "o4-mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
assert "o3mini" in provider.SUPPORTED_MODELS["o3-mini"].aliases
assert "o3-pro" in provider.SUPPORTED_MODELS["o3-pro-2025-06-10"].aliases
assert "o4mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
assert "gpt4.1" in provider.SUPPORTED_MODELS["gpt-4.1-2025-04-14"].aliases
assert "o3-pro" in provider.SUPPORTED_MODELS["o3-pro"].aliases
assert "gpt4.1" in provider.SUPPORTED_MODELS["gpt-4.1"].aliases
# Test alias resolution
assert provider._resolve_model_name("mini") == "o4-mini"
assert provider._resolve_model_name("mini") == "gpt-5-mini" # mini -> gpt-5-mini now
assert provider._resolve_model_name("o3mini") == "o3-mini"
assert provider._resolve_model_name("o3-pro") == "o3-pro-2025-06-10"
assert provider._resolve_model_name("o3-pro") == "o3-pro" # o3-pro is already the base model name
assert provider._resolve_model_name("o4mini") == "o4-mini"
assert provider._resolve_model_name("gpt4.1") == "gpt-4.1-2025-04-14"
assert provider._resolve_model_name("gpt4.1") == "gpt-4.1" # gpt4.1 resolves to gpt-4.1
# Test case insensitive resolution
assert provider._resolve_model_name("Mini") == "o4-mini"
assert provider._resolve_model_name("Mini") == "gpt-5-mini" # mini -> gpt-5-mini now
assert provider._resolve_model_name("O3MINI") == "o3-mini"
def test_xai_provider_aliases(self):
@@ -75,19 +76,21 @@ class TestSupportedModelsAliases:
assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
# Test specific aliases
assert "grok" in provider.SUPPORTED_MODELS["grok-3"].aliases
assert "grok" in provider.SUPPORTED_MODELS["grok-4"].aliases
assert "grok4" in provider.SUPPORTED_MODELS["grok-4"].aliases
assert "grok3" in provider.SUPPORTED_MODELS["grok-3"].aliases
assert "grok3fast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases
assert "grokfast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases
# Test alias resolution
assert provider._resolve_model_name("grok") == "grok-3"
assert provider._resolve_model_name("grok") == "grok-4"
assert provider._resolve_model_name("grok4") == "grok-4"
assert provider._resolve_model_name("grok3") == "grok-3"
assert provider._resolve_model_name("grok3fast") == "grok-3-fast"
assert provider._resolve_model_name("grokfast") == "grok-3-fast"
# Test case insensitive resolution
assert provider._resolve_model_name("Grok") == "grok-3"
assert provider._resolve_model_name("Grok") == "grok-4"
assert provider._resolve_model_name("GROKFAST") == "grok-3-fast"
def test_dial_provider_aliases(self):

View File

@@ -45,6 +45,8 @@ class TestXAIProvider:
provider = XAIModelProvider("test-key")
# Test valid models
assert provider.validate_model_name("grok-4") is True
assert provider.validate_model_name("grok4") is True
assert provider.validate_model_name("grok-3") is True
assert provider.validate_model_name("grok-3-fast") is True
assert provider.validate_model_name("grok") is True
@@ -62,12 +64,14 @@ class TestXAIProvider:
provider = XAIModelProvider("test-key")
# Test shorthand resolution
assert provider._resolve_model_name("grok") == "grok-3"
assert provider._resolve_model_name("grok") == "grok-4"
assert provider._resolve_model_name("grok4") == "grok-4"
assert provider._resolve_model_name("grok3") == "grok-3"
assert provider._resolve_model_name("grokfast") == "grok-3-fast"
assert provider._resolve_model_name("grok3fast") == "grok-3-fast"
# Test full name passthrough
assert provider._resolve_model_name("grok-4") == "grok-4"
assert provider._resolve_model_name("grok-3") == "grok-3"
assert provider._resolve_model_name("grok-3-fast") == "grok-3-fast"
@@ -88,7 +92,28 @@ class TestXAIProvider:
# Test temperature range
assert capabilities.temperature_constraint.min_temp == 0.0
assert capabilities.temperature_constraint.max_temp == 2.0
assert capabilities.temperature_constraint.default_temp == 0.3
def test_get_capabilities_grok4(self):
"""Test getting model capabilities for GROK-4."""
provider = XAIModelProvider("test-key")
capabilities = provider.get_capabilities("grok-4")
assert capabilities.model_name == "grok-4"
assert capabilities.friendly_name == "X.AI (Grok 4)"
assert capabilities.context_window == 256_000
assert capabilities.provider == ProviderType.XAI
assert capabilities.supports_extended_thinking is True
assert capabilities.supports_system_prompts is True
assert capabilities.supports_streaming is True
assert capabilities.supports_function_calling is True
assert capabilities.supports_json_mode is True
assert capabilities.supports_images is True
# Test temperature range
assert capabilities.temperature_constraint.min_temp == 0.0
assert capabilities.temperature_constraint.max_temp == 2.0
assert capabilities.temperature_constraint.default_temp == 0.3
def test_get_capabilities_grok3_fast(self):
"""Test getting model capabilities for GROK-3 Fast."""
@@ -106,8 +131,8 @@ class TestXAIProvider:
provider = XAIModelProvider("test-key")
capabilities = provider.get_capabilities("grok")
assert capabilities.model_name == "grok-3" # Should resolve to full name
assert capabilities.context_window == 131_072
assert capabilities.model_name == "grok-4" # Should resolve to full name
assert capabilities.context_window == 256_000
capabilities_fast = provider.get_capabilities("grokfast")
assert capabilities_fast.model_name == "grok-3-fast" # Should resolve to full name
@@ -119,13 +144,20 @@ class TestXAIProvider:
with pytest.raises(ValueError, match="Unsupported X.AI model"):
provider.get_capabilities("invalid-model")
def test_thinking_mode_support(self):
"""Test thinking mode support for X.AI models."""
provider = XAIModelProvider("test-key")
# Grok-4 supports thinking mode
assert provider.supports_thinking_mode("grok-4") is True
assert provider.supports_thinking_mode("grok") is True  # Resolves to grok-4
assert provider.supports_thinking_mode("grok4") is True  # Resolves to grok-4
# Grok-3 models don't support thinking mode
assert not provider.supports_thinking_mode("grok-3")
assert not provider.supports_thinking_mode("grok-3-fast")
assert not provider.supports_thinking_mode("grokfast")
def test_provider_type(self):
@@ -145,7 +177,10 @@ class TestXAIProvider:
# grok-3 should be allowed
assert provider.validate_model_name("grok-3") is True
assert provider.validate_model_name("grok") is True # Shorthand for grok-3
assert provider.validate_model_name("grok3") is True # Shorthand for grok-3
# grok should be blocked (resolves to grok-4 which is not allowed)
assert provider.validate_model_name("grok") is False
# grok-3-fast should be blocked by restrictions
assert provider.validate_model_name("grok-3-fast") is False
@@ -161,10 +196,13 @@ class TestXAIProvider:
provider = XAIModelProvider("test-key")
# Shorthand "grok" should be allowed (resolves to grok-3)
# Shorthand "grok" should be allowed (resolves to grok-4)
assert provider.validate_model_name("grok") is True
# Full name "grok-3" should NOT be allowed (only shorthand "grok" is in restriction list)
# Full name "grok-4" should NOT be allowed (only shorthand "grok" is in restriction list)
assert provider.validate_model_name("grok-4") is False
# "grok-3" should NOT be allowed (not in restriction list)
assert provider.validate_model_name("grok-3") is False
# "grok-3-fast" should be allowed (explicitly listed)
@@ -173,7 +211,7 @@ class TestXAIProvider:
# Shorthand "grokfast" should be allowed (resolves to grok-3-fast)
assert provider.validate_model_name("grokfast") is True
@patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok,grok-3"})
@patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok,grok-3,grok-4"})
def test_both_shorthand_and_full_name_allowed(self):
"""Test that both shorthand and full name can be allowed."""
# Clear cached restriction service
@@ -184,8 +222,9 @@ class TestXAIProvider:
provider = XAIModelProvider("test-key")
# Both shorthand and full name should be allowed
assert provider.validate_model_name("grok") is True
assert provider.validate_model_name("grok") is True # Resolves to grok-4
assert provider.validate_model_name("grok-3") is True
assert provider.validate_model_name("grok-4") is True
# Other models should not be allowed
assert provider.validate_model_name("grok-3-fast") is False
@@ -201,10 +240,12 @@ class TestXAIProvider:
provider = XAIModelProvider("test-key")
assert provider.validate_model_name("grok-4") is True
assert provider.validate_model_name("grok-3") is True
assert provider.validate_model_name("grok-3-fast") is True
assert provider.validate_model_name("grok") is True
assert provider.validate_model_name("grokfast") is True
assert provider.validate_model_name("grok4") is True
def test_friendly_name(self):
"""Test friendly name constant."""
@@ -219,23 +260,36 @@ class TestXAIProvider:
provider = XAIModelProvider("test-key")
# Check that all expected base models are present
assert "grok-4" in provider.SUPPORTED_MODELS
assert "grok-3" in provider.SUPPORTED_MODELS
assert "grok-3-fast" in provider.SUPPORTED_MODELS
# Check model configs have required fields
from providers.base import ModelCapabilities
grok4_config = provider.SUPPORTED_MODELS["grok-4"]
assert isinstance(grok4_config, ModelCapabilities)
assert hasattr(grok4_config, "context_window")
assert hasattr(grok4_config, "supports_extended_thinking")
assert hasattr(grok4_config, "aliases")
assert grok4_config.context_window == 256_000
assert grok4_config.supports_extended_thinking is True
grok3_config = provider.SUPPORTED_MODELS["grok-3"]
assert grok3_config.context_window == 131_072
assert grok3_config.supports_extended_thinking is False
# Check aliases are correctly structured
assert "grok" in grok4_config.aliases  # grok resolves to grok-4
assert "grok-4" in grok4_config.aliases
assert "grok4" in grok4_config.aliases
assert "grok3" in grok3_config.aliases  # grok3 resolves to grok-3
grok3fast_config = provider.SUPPORTED_MODELS["grok-3-fast"]
assert "grok3fast" in grok3fast_config.aliases
@@ -246,7 +300,7 @@ class TestXAIProvider:
"""Test that generate_content resolves aliases before making API calls.
This is the CRITICAL test that ensures aliases like 'grok' get resolved
to 'grok-4' before being sent to X.AI API.
"""
# Set up mock OpenAI client
mock_client = MagicMock()
@@ -257,7 +311,7 @@ class TestXAIProvider:
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Test response"
mock_response.choices[0].finish_reason = "stop"
mock_response.model = "grok-3" # API returns the resolved model name
mock_response.model = "grok-4" # API returns the resolved model name
mock_response.id = "test-id"
mock_response.created = 1234567890
mock_response.usage = MagicMock()
@@ -271,15 +325,15 @@ class TestXAIProvider:
# Call generate_content with alias 'grok'
result = provider.generate_content(
prompt="Test prompt", model_name="grok", temperature=0.7 # This should be resolved to "grok-3"
prompt="Test prompt", model_name="grok", temperature=0.7 # This should be resolved to "grok-4"
)
# Verify the API was called with the RESOLVED model name
mock_client.chat.completions.create.assert_called_once()
call_kwargs = mock_client.chat.completions.create.call_args[1]
# CRITICAL ASSERTION: The API should receive "grok-3", not "grok"
assert call_kwargs["model"] == "grok-3", f"Expected 'grok-3' but API received '{call_kwargs['model']}'"
# CRITICAL ASSERTION: The API should receive "grok-4", not "grok"
assert call_kwargs["model"] == "grok-4", f"Expected 'grok-4' but API received '{call_kwargs['model']}'"
# Verify other parameters
assert call_kwargs["temperature"] == 0.7
@@ -289,7 +343,7 @@ class TestXAIProvider:
# Verify response
assert result.content == "Test response"
assert result.model_name == "grok-3" # Should be the resolved name
assert result.model_name == "grok-4" # Should be the resolved name
@patch("providers.openai_compatible.OpenAI")
def test_generate_content_other_aliases(self, mock_openai_class):
@@ -311,6 +365,17 @@ class TestXAIProvider:
provider = XAIModelProvider("test-key")
# Test grok4 -> grok-4
mock_response.model = "grok-4"
provider.generate_content(prompt="Test", model_name="grok4", temperature=0.7)
call_kwargs = mock_client.chat.completions.create.call_args[1]
assert call_kwargs["model"] == "grok-4"
# Test grok-4 -> grok-4
provider.generate_content(prompt="Test", model_name="grok-4", temperature=0.7)
call_kwargs = mock_client.chat.completions.create.call_args[1]
assert call_kwargs["model"] == "grok-4"
# Test grok3 -> grok-3
mock_response.model = "grok-3"
provider.generate_content(prompt="Test", model_name="grok3", temperature=0.7)

View File

@@ -0,0 +1,47 @@
"""Helper functions for HTTP transport injection in tests."""
from tests.http_transport_recorder import TransportFactory
def inject_transport(monkeypatch, cassette_path: str):
"""Inject HTTP transport into OpenAICompatibleProvider for testing.
This helper simplifies the monkey patching pattern used across tests
to inject custom HTTP transports for recording/replaying API calls.
Also ensures OpenAI provider is properly registered for tests that need it.
Args:
monkeypatch: pytest monkeypatch fixture
cassette_path: Path to cassette file for recording/replay
Returns:
The created transport instance
Example:
transport = inject_transport(monkeypatch, "path/to/cassette.json")
"""
# Ensure OpenAI provider is registered - always needed for transport injection
from providers.base import ProviderType
from providers.openai_provider import OpenAIModelProvider
from providers.registry import ModelProviderRegistry
# Always register OpenAI provider for transport tests (API key might be dummy)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
# Create transport
transport = TransportFactory.create_transport(str(cassette_path))
# Inject transport using the established pattern
from providers.openai_compatible import OpenAICompatibleProvider
original_client_property = OpenAICompatibleProvider.client
def patched_client_getter(self):
if self._client is None:
self._test_transport = transport
return original_client_property.fget(self)
monkeypatch.setattr(OpenAICompatibleProvider, "client", property(patched_client_getter))
return transport
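A typical call site might look like this; the test name, cassette path, prompt, and assertion are illustrative only:

def test_o3_pro_replay(monkeypatch):
    # Record on first run (real API key required), replay from the
    # cassette on subsequent runs.
    transport = inject_transport(monkeypatch, "tests/openai_cassettes/o3_pro_basic_math.json")

    from providers.base import ProviderType
    from providers.registry import ModelProviderRegistry

    provider = ModelProviderRegistry.get_provider(ProviderType.OPENAI)
    result = provider.generate_content(prompt="What is 2 + 2?", model_name="o3-pro")
    assert result.content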