fix: Resolve o3-pro response parsing and test execution issues

- Fix lint errors: trailing whitespace and deprecated typing imports
- Update test mock for o3-pro response format (output.content[] → output_text)
- Implement robust test isolation with monkeypatch fixture
- Clear provider registry cache to prevent test interference
- Ensure o3-pro tests pass in both individual and full suite execution

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
Josh Vera
2025-07-12 20:24:34 -06:00
parent ae5e43b792
commit 3db49413ff
8 changed files with 328 additions and 320 deletions
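For orientation before the per-file diffs: the second bullet above refers to the OpenAI Responses API object for o3-pro exposing its generated text directly as output_text rather than nesting it under output.content[]. A minimal sketch of the parsing the provider relies on (it mirrors the provider hunk further down and is an illustration, not the full method):

def extract_o3_pro_text(response) -> str:
    """Return the generated text from an o3-pro response object (sketch of the checks shown below)."""
    if not hasattr(response, "output_text"):
        raise ValueError(f"o3-pro response missing output_text field. Response type: {type(response).__name__}")
    content = response.output_text
    if content is None:
        raise ValueError("o3-pro returned None for output_text")
    return content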

View File

@@ -221,7 +221,7 @@ class OpenAICompatibleProvider(ModelProvider):
        # Create httpx client with minimal config to avoid proxy conflicts
        # Note: proxies parameter was removed in httpx 0.28.0
        # Check for test transport injection
-        if hasattr(self, '_test_transport'):
+        if hasattr(self, "_test_transport"):
            # Use custom transport for testing (HTTP recording/replay)
            http_client = httpx.Client(
                transport=self._test_transport,
@@ -318,13 +318,13 @@ class OpenAICompatibleProvider(ModelProvider):
""" """
logging.debug(f"Response object type: {type(response)}") logging.debug(f"Response object type: {type(response)}")
logging.debug(f"Response attributes: {dir(response)}") logging.debug(f"Response attributes: {dir(response)}")
if not hasattr(response, "output_text"): if not hasattr(response, "output_text"):
raise ValueError(f"o3-pro response missing output_text field. Response type: {type(response).__name__}") raise ValueError(f"o3-pro response missing output_text field. Response type: {type(response).__name__}")
content = response.output_text content = response.output_text
logging.debug(f"Extracted output_text: '{content}' (type: {type(content)})") logging.debug(f"Extracted output_text: '{content}' (type: {type(content)})")
if content is None: if content is None:
raise ValueError("o3-pro returned None for output_text") raise ValueError("o3-pro returned None for output_text")

View File

@@ -93,7 +93,7 @@ def pytest_collection_modifyitems(session, config, items):
        if item.get_closest_marker("no_mock_provider"):
            config._needs_dummy_keys = False
            break

    # Set dummy keys only if no test needs real keys
    if config._needs_dummy_keys:
        _set_dummy_keys_if_missing()
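For context, the marker checked above is the one the o3-pro transport tests use to opt out of dummy-key injection. A minimal sketch of how a test opts out (pytest assumed; the class and test names here are hypothetical):

import pytest

@pytest.mark.no_mock_provider  # tells the collection hook above not to install dummy API keys
class TestWithRecordedHTTP:
    def test_replays_cassette(self):
        ...  # would exercise a real or cassette-backed HTTP transport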

View File

@@ -2,7 +2,7 @@
""" """
HTTP Transport Recorder for O3-Pro Testing HTTP Transport Recorder for O3-Pro Testing
Custom httpx transport solution that replaces respx for recording/replaying Custom httpx transport solution that replaces respx for recording/replaying
HTTP interactions. Provides full control over the recording process without HTTP interactions. Provides full control over the recording process without
respx limitations. respx limitations.
@@ -13,40 +13,40 @@ Key Features:
- JSON cassette format with data sanitization
"""

-import json
-import hashlib
-import copy
import base64
+import copy
+import hashlib
+import json
from pathlib import Path
-from typing import Dict, Any, Optional
+from typing import Any, Optional

-import httpx
-from io import BytesIO
-from .pii_sanitizer import PIISanitizer
+import httpx
+
+from .pii_sanitizer import PIISanitizer


class RecordingTransport(httpx.HTTPTransport):
    """Transport that wraps default httpx transport and records all interactions."""

    def __init__(self, cassette_path: str, capture_content: bool = True, sanitize: bool = True):
        super().__init__()
        self.cassette_path = Path(cassette_path)
        self.recorded_interactions = []
        self.capture_content = capture_content
        self.sanitizer = PIISanitizer() if sanitize else None

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """Handle request by recording interaction and delegating to real transport."""
        print(f"🎬 RecordingTransport: Making request to {request.method} {request.url}")

        # Record request BEFORE making the call
        request_data = self._serialize_request(request)

        # Make real HTTP call using parent transport
        response = super().handle_request(request)
        print(f"🎬 RecordingTransport: Got response {response.status_code}")

        # Post-response content capture (proper approach)
        if self.capture_content:
            try:
@@ -55,19 +55,20 @@ class RecordingTransport(httpx.HTTPTransport):
                content_bytes = response.read()
                response.close()  # Close the original stream
                print(f"🎬 RecordingTransport: Captured {len(content_bytes)} bytes of decompressed content")

                # Serialize response with captured content
                response_data = self._serialize_response_with_content(response, content_bytes)

                # Create a new response with the same metadata but buffered content
                # If the original response was gzipped, we need to re-compress
                response_content = content_bytes
-                if response.headers.get('content-encoding') == 'gzip':
+                if response.headers.get("content-encoding") == "gzip":
                    import gzip
                    print(f"🗜️ Re-compressing {len(content_bytes)} bytes with gzip...")
                    response_content = gzip.compress(content_bytes)
                    print(f"🗜️ Compressed to {len(response_content)} bytes")

                new_response = httpx.Response(
                    status_code=response.status_code,
                    headers=response.headers,  # Keep original headers intact
@@ -76,15 +77,16 @@ class RecordingTransport(httpx.HTTPTransport):
                    extensions=response.extensions,
                    history=response.history,
                )

                # Record the interaction
                self._record_interaction(request_data, response_data)

                return new_response

            except Exception as e:
                print(f"⚠️ Content capture failed: {e}, falling back to stub")
                import traceback
                print(f"⚠️ Full exception traceback:\n{traceback.format_exc()}")
                response_data = self._serialize_response(response)
                self._record_interaction(request_data, response_data)
@@ -94,105 +96,99 @@ class RecordingTransport(httpx.HTTPTransport):
            response_data = self._serialize_response(response)
            self._record_interaction(request_data, response_data)
            return response

-    def _record_interaction(self, request_data: Dict[str, Any], response_data: Dict[str, Any]):
+    def _record_interaction(self, request_data: dict[str, Any], response_data: dict[str, Any]):
        """Helper method to record interaction and save cassette."""
-        interaction = {
-            "request": request_data,
-            "response": response_data
-        }
+        interaction = {"request": request_data, "response": response_data}
        self.recorded_interactions.append(interaction)
        self._save_cassette()
        print(f"🎬 RecordingTransport: Saved cassette to {self.cassette_path}")

-    def _serialize_request(self, request: httpx.Request) -> Dict[str, Any]:
+    def _serialize_request(self, request: httpx.Request) -> dict[str, Any]:
        """Serialize httpx.Request to JSON-compatible format."""
        # For requests, we can safely read the content since it's already been prepared
        # httpx.Request.content is safe to access multiple times
        content = request.content

        # Convert bytes to string for JSON serialization
        if isinstance(content, bytes):
            try:
-                content_str = content.decode('utf-8')
+                content_str = content.decode("utf-8")
            except UnicodeDecodeError:
                # Handle binary content (shouldn't happen for o3-pro API)
                content_str = content.hex()
        else:
            content_str = str(content) if content else ""

        request_data = {
            "method": request.method,
            "url": str(request.url),
            "path": request.url.path,
            "headers": dict(request.headers),
-            "content": self._sanitize_request_content(content_str)
+            "content": self._sanitize_request_content(content_str),
        }

        # Apply PII sanitization if enabled
        if self.sanitizer:
            request_data = self.sanitizer.sanitize_request(request_data)

        return request_data

-    def _serialize_response(self, response: httpx.Response) -> Dict[str, Any]:
+    def _serialize_response(self, response: httpx.Response) -> dict[str, Any]:
        """Serialize httpx.Response to JSON-compatible format (legacy method without content)."""
        # Legacy method for backward compatibility when content capture is disabled
        return {
            "status_code": response.status_code,
            "headers": dict(response.headers),
            "content": {"note": "Response content not recorded to avoid httpx.ResponseNotRead exception"},
-            "reason_phrase": response.reason_phrase
+            "reason_phrase": response.reason_phrase,
        }

-    def _serialize_response_with_content(self, response: httpx.Response, content_bytes: bytes) -> Dict[str, Any]:
+    def _serialize_response_with_content(self, response: httpx.Response, content_bytes: bytes) -> dict[str, Any]:
        """Serialize httpx.Response with captured content."""
        try:
            # Debug: check what we got
            print(f"🔍 Content type: {type(content_bytes)}, size: {len(content_bytes)}")
            print(f"🔍 First 100 chars: {content_bytes[:100]}")

            # Ensure we have bytes for base64 encoding
            if not isinstance(content_bytes, bytes):
                print(f"⚠️ Content is not bytes, converting from {type(content_bytes)}")
                if isinstance(content_bytes, str):
-                    content_bytes = content_bytes.encode('utf-8')
+                    content_bytes = content_bytes.encode("utf-8")
                else:
-                    content_bytes = str(content_bytes).encode('utf-8')
+                    content_bytes = str(content_bytes).encode("utf-8")

            # Encode content as base64 for JSON storage
            print(f"🔍 Base64 encoding {len(content_bytes)} bytes...")
-            content_b64 = base64.b64encode(content_bytes).decode('utf-8')
+            content_b64 = base64.b64encode(content_bytes).decode("utf-8")
            print(f"✅ Base64 encoded successfully, result length: {len(content_b64)}")

            response_data = {
                "status_code": response.status_code,
                "headers": dict(response.headers),
-                "content": {
-                    "data": content_b64,
-                    "encoding": "base64",
-                    "size": len(content_bytes)
-                },
-                "reason_phrase": response.reason_phrase
+                "content": {"data": content_b64, "encoding": "base64", "size": len(content_bytes)},
+                "reason_phrase": response.reason_phrase,
            }

            # Apply PII sanitization if enabled
            if self.sanitizer:
                response_data = self.sanitizer.sanitize_response(response_data)

            return response_data

        except Exception as e:
            print(f"🔍 Error in _serialize_response_with_content: {e}")
            import traceback
            print(f"🔍 Full traceback: {traceback.format_exc()}")
            # Fall back to minimal info
            return {
                "status_code": response.status_code,
                "headers": dict(response.headers),
                "content": {"error": f"Failed to serialize content: {e}"},
-                "reason_phrase": response.reason_phrase
+                "reason_phrase": response.reason_phrase,
            }

    def _sanitize_request_content(self, content: str) -> Any:
        """Sanitize request content to remove sensitive data."""
        try:
@@ -203,14 +199,14 @@ class RecordingTransport(httpx.HTTPTransport):
        except json.JSONDecodeError:
            pass
        return content

    def _sanitize_response_content(self, data: Any) -> Any:
        """Sanitize response content to remove sensitive data."""
        if not isinstance(data, dict):
            return data

        sanitized = copy.deepcopy(data)

        # Sensitive fields to sanitize
        sensitive_fields = {
            "id": "resp_SANITIZED",
@@ -218,7 +214,7 @@ class RecordingTransport(httpx.HTTPTransport):
"created_at": 0, "created_at": 0,
"system_fingerprint": "fp_SANITIZED", "system_fingerprint": "fp_SANITIZED",
} }
def sanitize_dict(obj): def sanitize_dict(obj):
if isinstance(obj, dict): if isinstance(obj, dict):
for key, value in obj.items(): for key, value in obj.items():
@@ -230,82 +226,76 @@ class RecordingTransport(httpx.HTTPTransport):
                for item in obj:
                    if isinstance(item, (dict, list)):
                        sanitize_dict(item)

        sanitize_dict(sanitized)
        return sanitized

    def _save_cassette(self):
        """Save recorded interactions to cassette file."""
        # Ensure directory exists
        self.cassette_path.parent.mkdir(parents=True, exist_ok=True)

        # Save cassette
-        cassette_data = {
-            "interactions": self.recorded_interactions
-        }
-        self.cassette_path.write_text(
-            json.dumps(cassette_data, indent=2, sort_keys=True)
-        )
+        cassette_data = {"interactions": self.recorded_interactions}
+        self.cassette_path.write_text(json.dumps(cassette_data, indent=2, sort_keys=True))


class ReplayTransport(httpx.MockTransport):
    """Transport that replays saved HTTP interactions from cassettes."""

    def __init__(self, cassette_path: str):
        self.cassette_path = Path(cassette_path)
        self.interactions = self._load_cassette()
        super().__init__(self._handle_request)

    def _load_cassette(self) -> list:
        """Load interactions from cassette file."""
        if not self.cassette_path.exists():
            raise FileNotFoundError(f"Cassette file not found: {self.cassette_path}")

        try:
            cassette_data = json.loads(self.cassette_path.read_text())
            return cassette_data.get("interactions", [])
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid cassette file format: {e}")

    def _handle_request(self, request: httpx.Request) -> httpx.Response:
        """Handle request by finding matching interaction and returning saved response."""
        print(f"🔍 ReplayTransport: Looking for {request.method} {request.url}")

        # Debug: show what we're trying to match
        request_signature = self._get_request_signature(request)
        print(f"🔍 Request signature: {request_signature}")

        # Debug: show actual request content
        content = request.content
-        if hasattr(content, 'read'):
+        if hasattr(content, "read"):
            content = content.read()
        if isinstance(content, bytes):
-            content_str = content.decode('utf-8', errors='ignore')
+            content_str = content.decode("utf-8", errors="ignore")
        else:
            content_str = str(content) if content else ""
        print(f"🔍 Actual request content: {content_str}")

        # Debug: show available signatures
        for i, interaction in enumerate(self.interactions):
            saved_signature = self._get_saved_request_signature(interaction["request"])
            saved_content = interaction["request"].get("content", {})
            print(f"🔍 Available signature {i}: {saved_signature}")
            print(f"🔍 Saved content {i}: {saved_content}")

        # Find matching interaction
        interaction = self._find_matching_interaction(request)
        if not interaction:
            print("🚨 MYSTERY SOLVED: No matching interaction found! This should fail...")
-            raise ValueError(
-                f"No matching interaction found for {request.method} {request.url}"
-            )
+            raise ValueError(f"No matching interaction found for {request.method} {request.url}")

-        print(f"✅ Found matching interaction from cassette!")
+        print("✅ Found matching interaction from cassette!")

        # Build response from saved data
        response_data = interaction["response"]

        # Convert content back to appropriate format
        content = response_data.get("content", {})
        if isinstance(content, dict):
@@ -317,55 +307,56 @@ class ReplayTransport(httpx.MockTransport):
print(f"🎬 ReplayTransport: Decoded {len(content_bytes)} bytes from base64") print(f"🎬 ReplayTransport: Decoded {len(content_bytes)} bytes from base64")
except Exception as e: except Exception as e:
print(f"⚠️ Failed to decode base64 content: {e}") print(f"⚠️ Failed to decode base64 content: {e}")
content_bytes = json.dumps(content).encode('utf-8') content_bytes = json.dumps(content).encode("utf-8")
else: else:
# Legacy format or stub content # Legacy format or stub content
content_bytes = json.dumps(content).encode('utf-8') content_bytes = json.dumps(content).encode("utf-8")
else: else:
content_bytes = str(content).encode('utf-8') content_bytes = str(content).encode("utf-8")
# Check if response expects gzipped content # Check if response expects gzipped content
headers = response_data.get("headers", {}) headers = response_data.get("headers", {})
if headers.get('content-encoding') == 'gzip': if headers.get("content-encoding") == "gzip":
# Re-compress the content for httpx # Re-compress the content for httpx
import gzip import gzip
print(f"🗜️ ReplayTransport: Re-compressing {len(content_bytes)} bytes with gzip...") print(f"🗜️ ReplayTransport: Re-compressing {len(content_bytes)} bytes with gzip...")
content_bytes = gzip.compress(content_bytes) content_bytes = gzip.compress(content_bytes)
print(f"🗜️ ReplayTransport: Compressed to {len(content_bytes)} bytes") print(f"🗜️ ReplayTransport: Compressed to {len(content_bytes)} bytes")
print(f"🎬 ReplayTransport: Returning cassette response with content: {content_bytes[:100]}...") print(f"🎬 ReplayTransport: Returning cassette response with content: {content_bytes[:100]}...")
# Create httpx.Response # Create httpx.Response
return httpx.Response( return httpx.Response(
status_code=response_data["status_code"], status_code=response_data["status_code"],
headers=response_data.get("headers", {}), headers=response_data.get("headers", {}),
content=content_bytes, content=content_bytes,
request=request request=request,
) )
def _find_matching_interaction(self, request: httpx.Request) -> Optional[Dict[str, Any]]: def _find_matching_interaction(self, request: httpx.Request) -> Optional[dict[str, Any]]:
"""Find interaction that matches the request.""" """Find interaction that matches the request."""
request_signature = self._get_request_signature(request) request_signature = self._get_request_signature(request)
for interaction in self.interactions: for interaction in self.interactions:
saved_signature = self._get_saved_request_signature(interaction["request"]) saved_signature = self._get_saved_request_signature(interaction["request"])
if request_signature == saved_signature: if request_signature == saved_signature:
return interaction return interaction
return None return None
def _get_request_signature(self, request: httpx.Request) -> str: def _get_request_signature(self, request: httpx.Request) -> str:
"""Generate signature for request matching.""" """Generate signature for request matching."""
# Use method, path, and content hash for matching # Use method, path, and content hash for matching
content = request.content content = request.content
if hasattr(content, 'read'): if hasattr(content, "read"):
content = content.read() content = content.read()
if isinstance(content, bytes): if isinstance(content, bytes):
content_str = content.decode('utf-8', errors='ignore') content_str = content.decode("utf-8", errors="ignore")
else: else:
content_str = str(content) if content else "" content_str = str(content) if content else ""
# Parse JSON and re-serialize with sorted keys for consistent hashing # Parse JSON and re-serialize with sorted keys for consistent hashing
try: try:
if content_str.strip(): if content_str.strip():
@@ -374,37 +365,37 @@ class ReplayTransport(httpx.MockTransport):
        except json.JSONDecodeError:
            # Not JSON, use as-is
            pass

        # Create hash of content for stable matching
        content_hash = hashlib.md5(content_str.encode()).hexdigest()
        return f"{request.method}:{request.url.path}:{content_hash}"

-    def _get_saved_request_signature(self, saved_request: Dict[str, Any]) -> str:
+    def _get_saved_request_signature(self, saved_request: dict[str, Any]) -> str:
        """Generate signature for saved request."""
        method = saved_request["method"]
        path = saved_request["path"]

        # Hash the saved content
        content = saved_request.get("content", "")
        if isinstance(content, dict):
            content_str = json.dumps(content, sort_keys=True)
        else:
            content_str = str(content)

        content_hash = hashlib.md5(content_str.encode()).hexdigest()
        return f"{method}:{path}:{content_hash}"


class TransportFactory:
    """Factory for creating appropriate transport based on cassette availability."""

    @staticmethod
    def create_transport(cassette_path: str) -> httpx.HTTPTransport:
        """Create transport based on cassette existence and API key availability."""
        cassette_file = Path(cassette_path)

        # Check if we should record or replay
        if cassette_file.exists():
            # Cassette exists - use replay mode
@@ -413,15 +404,15 @@ class TransportFactory:
        # No cassette - use recording mode
        # Note: We'll check for API key in the test itself
        return RecordingTransport(cassette_path)

    @staticmethod
    def should_record(cassette_path: str, api_key: Optional[str] = None) -> bool:
        """Determine if we should record based on cassette and API key availability."""
        cassette_file = Path(cassette_path)
        # Record if cassette doesn't exist AND we have API key
        return not cassette_file.exists() and bool(api_key)

    @staticmethod
    def should_replay(cassette_path: str) -> bool:
        """Determine if we should replay based on cassette availability."""
@@ -434,8 +425,8 @@ class TransportFactory:
# # In test setup:
# cassette_path = "tests/cassettes/o3_pro_basic_math.json"
# transport = TransportFactory.create_transport(cassette_path)
#
# # Inject into OpenAI client:
# provider._test_transport = transport
#
# # The provider's client property will detect _test_transport and use it
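Read together with the usage notes above, the intended wiring in a test looks roughly like this sketch (the helper name is invented for illustration; the cassette path and the _test_transport attribute come from the comments and code above):

from tests.http_transport_recorder import TransportFactory

def install_test_transport(provider, cassette_path: str) -> None:
    """Give the provider a record-or-replay transport for its httpx client (sketch only)."""
    # ReplayTransport when the cassette exists, RecordingTransport otherwise.
    transport = TransportFactory.create_transport(cassette_path)
    # The provider's client property checks for this attribute and builds
    # its httpx.Client around the injected transport.
    provider._test_transport = transport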

View File

@@ -7,11 +7,12 @@ request/response recordings to prevent accidental exposure of API keys,
tokens, personal information, and other sensitive data.
"""

-import re
-from typing import Any, Dict, List, Optional, Pattern
-from dataclasses import dataclass
-from copy import deepcopy
import logging
+import re
+from copy import deepcopy
+from dataclasses import dataclass
+from re import Pattern
+from typing import Any, Optional

logger = logging.getLogger(__name__)
@@ -19,178 +20,170 @@ logger = logging.getLogger(__name__)
@dataclass
class PIIPattern:
    """Defines a pattern for detecting and sanitizing PII."""

    name: str
    pattern: Pattern[str]
    replacement: str
    description: str

    @classmethod
-    def create(cls, name: str, pattern: str, replacement: str, description: str) -> 'PIIPattern':
+    def create(cls, name: str, pattern: str, replacement: str, description: str) -> "PIIPattern":
        """Create a PIIPattern with compiled regex."""
-        return cls(
-            name=name,
-            pattern=re.compile(pattern),
-            replacement=replacement,
-            description=description
-        )
+        return cls(name=name, pattern=re.compile(pattern), replacement=replacement, description=description)


class PIISanitizer:
    """Sanitizes PII from various data structures while preserving format."""

-    def __init__(self, patterns: Optional[List[PIIPattern]] = None):
+    def __init__(self, patterns: Optional[list[PIIPattern]] = None):
        """Initialize with optional custom patterns."""
-        self.patterns: List[PIIPattern] = patterns or []
+        self.patterns: list[PIIPattern] = patterns or []
        self.sanitize_enabled = True

        # Add default patterns if none provided
        if not patterns:
            self._add_default_patterns()
    def _add_default_patterns(self):
        """Add comprehensive default PII patterns."""
        default_patterns = [
            # API Keys - Core patterns (Bearer tokens handled in sanitize_headers)
            PIIPattern.create(
                name="openai_api_key_proj",
-                pattern=r'sk-proj-[A-Za-z0-9\-_]{48,}',
+                pattern=r"sk-proj-[A-Za-z0-9\-_]{48,}",
                replacement="sk-proj-SANITIZED",
-                description="OpenAI project API keys"
+                description="OpenAI project API keys",
            ),
            PIIPattern.create(
                name="openai_api_key",
-                pattern=r'sk-[A-Za-z0-9]{48,}',
+                pattern=r"sk-[A-Za-z0-9]{48,}",
                replacement="sk-SANITIZED",
-                description="OpenAI API keys"
+                description="OpenAI API keys",
            ),
            PIIPattern.create(
                name="anthropic_api_key",
-                pattern=r'sk-ant-[A-Za-z0-9\-_]{48,}',
+                pattern=r"sk-ant-[A-Za-z0-9\-_]{48,}",
                replacement="sk-ant-SANITIZED",
-                description="Anthropic API keys"
+                description="Anthropic API keys",
            ),
            PIIPattern.create(
                name="google_api_key",
-                pattern=r'AIza[A-Za-z0-9\-_]{35,}',
+                pattern=r"AIza[A-Za-z0-9\-_]{35,}",
                replacement="AIza-SANITIZED",
-                description="Google API keys"
+                description="Google API keys",
            ),
            PIIPattern.create(
                name="github_tokens",
-                pattern=r'gh[psr]_[A-Za-z0-9]{36}',
+                pattern=r"gh[psr]_[A-Za-z0-9]{36}",
                replacement="gh_SANITIZED",
-                description="GitHub tokens (all types)"
+                description="GitHub tokens (all types)",
            ),
            # JWT tokens
            PIIPattern.create(
                name="jwt_token",
-                pattern=r'eyJ[A-Za-z0-9\-_]+\.eyJ[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+',
+                pattern=r"eyJ[A-Za-z0-9\-_]+\.eyJ[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+",
                replacement="eyJ-SANITIZED",
-                description="JSON Web Tokens"
+                description="JSON Web Tokens",
            ),
            # Personal Information
            PIIPattern.create(
                name="email_address",
-                pattern=r'[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}',
+                pattern=r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}",
                replacement="user@example.com",
-                description="Email addresses"
+                description="Email addresses",
            ),
            PIIPattern.create(
                name="ipv4_address",
-                pattern=r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b',
+                pattern=r"\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b",
                replacement="0.0.0.0",
-                description="IPv4 addresses"
+                description="IPv4 addresses",
            ),
            PIIPattern.create(
                name="ssn",
-                pattern=r'\b\d{3}-\d{2}-\d{4}\b',
+                pattern=r"\b\d{3}-\d{2}-\d{4}\b",
                replacement="XXX-XX-XXXX",
-                description="Social Security Numbers"
+                description="Social Security Numbers",
            ),
            PIIPattern.create(
                name="credit_card",
-                pattern=r'\b\d{4}[\s\-]?\d{4}[\s\-]?\d{4}[\s\-]?\d{4}\b',
+                pattern=r"\b\d{4}[\s\-]?\d{4}[\s\-]?\d{4}[\s\-]?\d{4}\b",
                replacement="XXXX-XXXX-XXXX-XXXX",
-                description="Credit card numbers"
+                description="Credit card numbers",
            ),
            PIIPattern.create(
                name="phone_number",
-                pattern=r'(?:\+\d{1,3}[\s\-]?)?\(?\d{3}\)?[\s\-]?\d{3}[\s\-]?\d{4}',
+                pattern=r"(?:\+\d{1,3}[\s\-]?)?\(?\d{3}\)?[\s\-]?\d{3}[\s\-]?\d{4}",
                replacement="(XXX) XXX-XXXX",
-                description="Phone numbers (all formats)"
+                description="Phone numbers (all formats)",
            ),
            # AWS
            PIIPattern.create(
                name="aws_access_key",
-                pattern=r'AKIA[0-9A-Z]{16}',
+                pattern=r"AKIA[0-9A-Z]{16}",
                replacement="AKIA-SANITIZED",
-                description="AWS access keys"
+                description="AWS access keys",
            ),
            # Other common patterns
            PIIPattern.create(
                name="slack_token",
-                pattern=r'xox[baprs]-[0-9]{10,13}-[0-9]{10,13}-[a-zA-Z0-9]{24,34}',
+                pattern=r"xox[baprs]-[0-9]{10,13}-[0-9]{10,13}-[a-zA-Z0-9]{24,34}",
                replacement="xox-SANITIZED",
-                description="Slack tokens"
+                description="Slack tokens",
            ),
            PIIPattern.create(
                name="stripe_key",
-                pattern=r'(?:sk|pk)_(?:test|live)_[0-9a-zA-Z]{24,99}',
+                pattern=r"(?:sk|pk)_(?:test|live)_[0-9a-zA-Z]{24,99}",
                replacement="sk_SANITIZED",
-                description="Stripe API keys"
+                description="Stripe API keys",
            ),
        ]

        self.patterns.extend(default_patterns)
    def add_pattern(self, pattern: PIIPattern):
        """Add a custom PII pattern."""
        self.patterns.append(pattern)
        logger.info(f"Added PII pattern: {pattern.name}")

    def sanitize_string(self, text: str) -> str:
        """Apply all patterns to sanitize a string."""
        if not self.sanitize_enabled or not isinstance(text, str):
            return text

        sanitized = text
        for pattern in self.patterns:
            if pattern.pattern.search(sanitized):
                sanitized = pattern.pattern.sub(pattern.replacement, sanitized)
                logger.debug(f"Applied {pattern.name} sanitization")

        return sanitized

-    def sanitize_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
+    def sanitize_headers(self, headers: dict[str, str]) -> dict[str, str]:
        """Special handling for HTTP headers."""
        if not self.sanitize_enabled:
            return headers

        sanitized_headers = {}
        for key, value in headers.items():
            # Special case for Authorization headers to preserve auth type
-            if key.lower() == 'authorization' and ' ' in value:
-                auth_type = value.split(' ', 1)[0]
-                if auth_type in ('Bearer', 'Basic'):
-                    sanitized_headers[key] = f'{auth_type} SANITIZED'
+            if key.lower() == "authorization" and " " in value:
+                auth_type = value.split(" ", 1)[0]
+                if auth_type in ("Bearer", "Basic"):
+                    sanitized_headers[key] = f"{auth_type} SANITIZED"
                else:
                    sanitized_headers[key] = self.sanitize_string(value)
            else:
                # Apply standard sanitization to all other headers
                sanitized_headers[key] = self.sanitize_string(value)

        return sanitized_headers

    def sanitize_value(self, value: Any) -> Any:
        """Recursively sanitize any value (string, dict, list, etc)."""
        if not self.sanitize_enabled:
            return value

        if isinstance(value, str):
            return self.sanitize_string(value)
        elif isinstance(value, dict):
@@ -202,25 +195,25 @@ class PIISanitizer:
        else:
            # For other types (int, float, bool, None), return as-is
            return value

    def sanitize_url(self, url: str) -> str:
        """Sanitize sensitive data from URLs (query params, etc)."""
        if not self.sanitize_enabled:
            return url

        # First apply general string sanitization
        url = self.sanitize_string(url)

        # Parse and sanitize query parameters
-        if '?' in url:
-            base, query = url.split('?', 1)
+        if "?" in url:
+            base, query = url.split("?", 1)
            params = []
-            for param in query.split('&'):
-                if '=' in param:
-                    key, value = param.split('=', 1)
+            for param in query.split("&"):
+                if "=" in param:
+                    key, value = param.split("=", 1)
                    # Sanitize common sensitive parameter names
-                    sensitive_params = {'key', 'token', 'api_key', 'secret', 'password'}
+                    sensitive_params = {"key", "token", "api_key", "secret", "password"}
                    if key.lower() in sensitive_params:
                        params.append(f"{key}=SANITIZED")
                    else:
@@ -228,54 +221,53 @@ class PIISanitizer:
params.append(f"{key}={self.sanitize_string(value)}") params.append(f"{key}={self.sanitize_string(value)}")
else: else:
params.append(param) params.append(param)
return f"{base}?{'&'.join(params)}" return f"{base}?{'&'.join(params)}"
return url return url
def sanitize_request(self, request_data: dict[str, Any]) -> dict[str, Any]:
def sanitize_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""Sanitize a complete request dictionary.""" """Sanitize a complete request dictionary."""
sanitized = deepcopy(request_data) sanitized = deepcopy(request_data)
# Sanitize headers # Sanitize headers
if 'headers' in sanitized: if "headers" in sanitized:
sanitized['headers'] = self.sanitize_headers(sanitized['headers']) sanitized["headers"] = self.sanitize_headers(sanitized["headers"])
# Sanitize URL # Sanitize URL
if 'url' in sanitized: if "url" in sanitized:
sanitized['url'] = self.sanitize_url(sanitized['url']) sanitized["url"] = self.sanitize_url(sanitized["url"])
# Sanitize content # Sanitize content
if 'content' in sanitized: if "content" in sanitized:
sanitized['content'] = self.sanitize_value(sanitized['content']) sanitized["content"] = self.sanitize_value(sanitized["content"])
return sanitized return sanitized
def sanitize_response(self, response_data: Dict[str, Any]) -> Dict[str, Any]: def sanitize_response(self, response_data: dict[str, Any]) -> dict[str, Any]:
"""Sanitize a complete response dictionary.""" """Sanitize a complete response dictionary."""
sanitized = deepcopy(response_data) sanitized = deepcopy(response_data)
# Sanitize headers # Sanitize headers
if 'headers' in sanitized: if "headers" in sanitized:
sanitized['headers'] = self.sanitize_headers(sanitized['headers']) sanitized["headers"] = self.sanitize_headers(sanitized["headers"])
# Sanitize content # Sanitize content
if 'content' in sanitized: if "content" in sanitized:
# Handle base64 encoded content specially # Handle base64 encoded content specially
if isinstance(sanitized['content'], dict) and sanitized['content'].get('encoding') == 'base64': if isinstance(sanitized["content"], dict) and sanitized["content"].get("encoding") == "base64":
# Don't decode/re-encode the actual response body # Don't decode/re-encode the actual response body
# but sanitize any metadata # but sanitize any metadata
if 'data' in sanitized['content']: if "data" in sanitized["content"]:
# Keep the data as-is but sanitize other fields # Keep the data as-is but sanitize other fields
for key, value in sanitized['content'].items(): for key, value in sanitized["content"].items():
if key != 'data': if key != "data":
sanitized['content'][key] = self.sanitize_value(value) sanitized["content"][key] = self.sanitize_value(value)
else: else:
sanitized['content'] = self.sanitize_value(sanitized['content']) sanitized["content"] = self.sanitize_value(sanitized["content"])
return sanitized return sanitized
# Global instance for convenience # Global instance for convenience
default_sanitizer = PIISanitizer() default_sanitizer = PIISanitizer()
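A quick usage sketch of the sanitizer API defined above (the expected values mirror the test cases further down; the extra pattern is a hypothetical example):

from tests.pii_sanitizer import PIIPattern, PIISanitizer

sanitizer = PIISanitizer()

# Strings are scrubbed pattern by pattern; these expectations match the tests below.
assert sanitizer.sanitize_string("john.doe@example.com") == "user@example.com"
assert sanitizer.sanitize_headers({"Authorization": "Bearer sk-abc123"})["Authorization"] == "Bearer SANITIZED"

# Project-specific secrets can be added as extra patterns.
sanitizer.add_pattern(
    PIIPattern.create(
        name="internal_token",  # hypothetical pattern, for illustration only
        pattern=r"int_[A-Za-z0-9]{20}",
        replacement="int_SANITIZED",
        description="Internal service tokens",
    )
)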

View File

@@ -10,10 +10,10 @@ This script will:
""" """
import json import json
import sys
from pathlib import Path
import shutil import shutil
import sys
from datetime import datetime from datetime import datetime
from pathlib import Path
# Add tests directory to path to import our modules # Add tests directory to path to import our modules
sys.path.insert(0, str(Path(__file__).parent)) sys.path.insert(0, str(Path(__file__).parent))
@@ -24,54 +24,55 @@ from pii_sanitizer import PIISanitizer
def sanitize_cassette(cassette_path: Path, backup: bool = True) -> bool:
    """Sanitize a single cassette file."""
    print(f"\n🔍 Processing: {cassette_path}")

    if not cassette_path.exists():
        print(f"❌ File not found: {cassette_path}")
        return False

    try:
        # Load cassette
-        with open(cassette_path, 'r') as f:
+        with open(cassette_path) as f:
            cassette_data = json.load(f)

        # Create backup if requested
        if backup:
            backup_path = cassette_path.with_suffix(f'.backup-{datetime.now().strftime("%Y%m%d-%H%M%S")}.json')
            shutil.copy2(cassette_path, backup_path)
            print(f"📦 Backup created: {backup_path}")

        # Initialize sanitizer
        sanitizer = PIISanitizer()

        # Sanitize interactions
-        if 'interactions' in cassette_data:
+        if "interactions" in cassette_data:
            sanitized_interactions = []
-            for interaction in cassette_data['interactions']:
+            for interaction in cassette_data["interactions"]:
                sanitized_interaction = {}

                # Sanitize request
-                if 'request' in interaction:
-                    sanitized_interaction['request'] = sanitizer.sanitize_request(interaction['request'])
+                if "request" in interaction:
+                    sanitized_interaction["request"] = sanitizer.sanitize_request(interaction["request"])

                # Sanitize response
-                if 'response' in interaction:
-                    sanitized_interaction['response'] = sanitizer.sanitize_response(interaction['response'])
+                if "response" in interaction:
+                    sanitized_interaction["response"] = sanitizer.sanitize_response(interaction["response"])

                sanitized_interactions.append(sanitized_interaction)

-            cassette_data['interactions'] = sanitized_interactions
+            cassette_data["interactions"] = sanitized_interactions

        # Save sanitized cassette
-        with open(cassette_path, 'w') as f:
+        with open(cassette_path, "w") as f:
            json.dump(cassette_data, f, indent=2, sort_keys=True)

        print(f"✅ Sanitized: {cassette_path}")
        return True

    except Exception as e:
        print(f"❌ Error processing {cassette_path}: {e}")
        import traceback
        traceback.print_exc()
        return False
@@ -79,31 +80,31 @@ def sanitize_cassette(cassette_path: Path, backup: bool = True) -> bool:
def main():
    """Sanitize all cassettes in the openai_cassettes directory."""
    cassettes_dir = Path(__file__).parent / "openai_cassettes"

    if not cassettes_dir.exists():
        print(f"❌ Directory not found: {cassettes_dir}")
        sys.exit(1)

    # Find all JSON cassettes
    cassette_files = list(cassettes_dir.glob("*.json"))

    if not cassette_files:
        print(f"❌ No cassette files found in {cassettes_dir}")
        sys.exit(1)

    print(f"🎬 Found {len(cassette_files)} cassette(s) to sanitize")

    # Process each cassette
    success_count = 0
    for cassette_path in cassette_files:
        if sanitize_cassette(cassette_path):
            success_count += 1

    print(f"\n✨ Sanitization complete: {success_count}/{len(cassette_files)} cassettes processed successfully")

    if success_count < len(cassette_files):
        sys.exit(1)


if __name__ == "__main__":
    main()
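For a single file, sanitize_cassette can also be called directly; a small sketch (the module name sanitize_cassettes and the cassette path are assumptions for illustration):

from pathlib import Path

from sanitize_cassettes import sanitize_cassette  # assumed filename for the script above

# Re-sanitize one cassette in place; a timestamped backup is written first by default.
ok = sanitize_cassette(Path("tests/openai_cassettes/o3_pro_basic_math.json"), backup=True)
print("sanitized" if ok else "failed")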

View File

@@ -18,11 +18,11 @@ from pathlib import Path
import pytest
from dotenv import load_dotenv

-from tools.chat import ChatTool
from providers import ModelProviderRegistry
from providers.base import ProviderType
from providers.openai_provider import OpenAIModelProvider
from tests.http_transport_recorder import TransportFactory
+from tools.chat import ChatTool

# Load environment variables from .env file
load_dotenv()
@@ -32,54 +32,87 @@ cassette_dir = Path(__file__).parent / "openai_cassettes"
cassette_dir.mkdir(exist_ok=True)


+@pytest.fixture
+def allow_all_models(monkeypatch):
+    """Allow all models by resetting the restriction service singleton."""
+    # Import here to avoid circular imports
+    from utils.model_restrictions import _restriction_service
+
+    # Store original state
+    original_service = _restriction_service
+    original_allowed_models = os.getenv("ALLOWED_MODELS")
+    original_openai_key = os.getenv("OPENAI_API_KEY")
+
+    # Reset the singleton so it will re-read env vars inside this fixture
+    monkeypatch.setattr("utils.model_restrictions._restriction_service", None)
+    monkeypatch.setenv("ALLOWED_MODELS", "")  # empty string = no restrictions
+    monkeypatch.setenv("OPENAI_API_KEY", "dummy-key-for-replay")  # transport layer expects a key
+
+    # Also clear the provider registry cache to ensure clean state
+    from providers.registry import ModelProviderRegistry
+
+    ModelProviderRegistry.clear_cache()
+
+    yield
+
+    # Clean up: reset singleton again so other tests don't see the unrestricted version
+    monkeypatch.setattr("utils.model_restrictions._restriction_service", None)
+
+    # Clear registry cache again for other tests
+    ModelProviderRegistry.clear_cache()
+
+
@pytest.mark.no_mock_provider  # Disable provider mocking for this test
class TestO3ProOutputTextFix(unittest.IsolatedAsyncioTestCase):
    """Test o3-pro response parsing fix using respx for HTTP recording/replay."""

    def setUp(self):
        """Set up the test by ensuring OpenAI provider is registered."""
+        # Clear any cached providers to ensure clean state
+        ModelProviderRegistry.clear_cache()
+
        # Manually register the OpenAI provider to ensure it's available
        ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

+    @pytest.mark.usefixtures("allow_all_models")
    async def test_o3_pro_uses_output_text_field(self):
        """Test that o3-pro parsing uses the output_text convenience field via ChatTool."""
        cassette_path = cassette_dir / "o3_pro_basic_math.json"

        # Skip if no API key available and cassette doesn't exist
        if not cassette_path.exists() and not os.getenv("OPENAI_API_KEY"):
            pytest.skip("Set real OPENAI_API_KEY to record cassettes")

        # Create transport (automatically selects record vs replay mode)
        transport = TransportFactory.create_transport(str(cassette_path))

        # Get provider and inject custom transport
        provider = ModelProviderRegistry.get_provider_for_model("o3-pro")
        if not provider:
            self.fail("OpenAI provider not available for o3-pro model")

        # Inject transport for this test
-        original_transport = getattr(provider, '_test_transport', None)
+        original_transport = getattr(provider, "_test_transport", None)
        provider._test_transport = transport

        try:
            # Execute ChatTool test with custom transport
            result = await self._execute_chat_tool_test()

            # Verify the response works correctly
            self._verify_chat_tool_response(result)

            # Verify cassette was created/used
            if not cassette_path.exists():
                self.fail(f"Cassette should exist at {cassette_path}")

-            print(f"✅ HTTP transport {'recorded' if isinstance(transport, type(transport).__bases__[0]) else 'replayed'} o3-pro interaction")
+            print(
+                f"✅ HTTP transport {'recorded' if isinstance(transport, type(transport).__bases__[0]) else 'replayed'} o3-pro interaction"
+            )

        finally:
            # Restore original transport (if any)
            if original_transport:
                provider._test_transport = original_transport
-            elif hasattr(provider, '_test_transport'):
-                delattr(provider, '_test_transport')
+            elif hasattr(provider, "_test_transport"):
+                delattr(provider, "_test_transport")

    async def _execute_chat_tool_test(self):
        """Execute the ChatTool with o3-pro and return the result."""

View File

@@ -230,10 +230,8 @@ class TestOpenAIProvider:
        mock_openai_class.return_value = mock_client

        mock_response = MagicMock()
-        mock_response.output = MagicMock()
-        mock_response.output.content = [MagicMock()]
-        mock_response.output.content[0].type = "output_text"
-        mock_response.output.content[0].text = "4"
+        # New o3-pro format: direct output_text field
+        mock_response.output_text = "4"
        mock_response.model = "o3-pro-2025-06-10"
        mock_response.id = "test-id"
        mock_response.created_at = 1234567890
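Why the simpler mock is enough, as a standalone sketch (values copied from the hunk above): with the field assigned directly, both the provider's hasattr check and the attribute read succeed.

from unittest.mock import MagicMock

mock_response = MagicMock()
mock_response.output_text = "4"  # the Responses API convenience field the provider now reads
mock_response.model = "o3-pro-2025-06-10"

# Mirrors the checks in the provider's o3-pro parsing path.
assert hasattr(mock_response, "output_text")
assert mock_response.output_text == "4"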

View File

@@ -2,64 +2,59 @@
"""Test cases for PII sanitizer.""" """Test cases for PII sanitizer."""
import unittest import unittest
from tests.pii_sanitizer import PIISanitizer, PIIPattern
from tests.pii_sanitizer import PIIPattern, PIISanitizer
class TestPIISanitizer(unittest.TestCase): class TestPIISanitizer(unittest.TestCase):
"""Test PII sanitization functionality.""" """Test PII sanitization functionality."""
def setUp(self): def setUp(self):
"""Set up test sanitizer.""" """Set up test sanitizer."""
self.sanitizer = PIISanitizer() self.sanitizer = PIISanitizer()
def test_api_key_sanitization(self): def test_api_key_sanitization(self):
"""Test various API key formats are sanitized.""" """Test various API key formats are sanitized."""
test_cases = [ test_cases = [
# OpenAI keys # OpenAI keys
("sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12", "sk-proj-SANITIZED"), ("sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12", "sk-proj-SANITIZED"),
("sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN", "sk-SANITIZED"), ("sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN", "sk-SANITIZED"),
# Anthropic keys # Anthropic keys
("sk-ant-abcd1234567890ABCD1234567890abcd1234567890ABCD12", "sk-ant-SANITIZED"), ("sk-ant-abcd1234567890ABCD1234567890abcd1234567890ABCD12", "sk-ant-SANITIZED"),
# Google keys # Google keys
("AIzaSyD-1234567890abcdefghijklmnopqrstuv", "AIza-SANITIZED"), ("AIzaSyD-1234567890abcdefghijklmnopqrstuv", "AIza-SANITIZED"),
# GitHub tokens # GitHub tokens
("ghp_1234567890abcdefghijklmnopqrstuvwxyz", "gh_SANITIZED"), ("ghp_1234567890abcdefghijklmnopqrstuvwxyz", "gh_SANITIZED"),
("ghs_1234567890abcdefghijklmnopqrstuvwxyz", "gh_SANITIZED"), ("ghs_1234567890abcdefghijklmnopqrstuvwxyz", "gh_SANITIZED"),
] ]
for original, expected in test_cases: for original, expected in test_cases:
with self.subTest(original=original): with self.subTest(original=original):
result = self.sanitizer.sanitize_string(original) result = self.sanitizer.sanitize_string(original)
self.assertEqual(result, expected) self.assertEqual(result, expected)
def test_personal_info_sanitization(self): def test_personal_info_sanitization(self):
"""Test personal information is sanitized.""" """Test personal information is sanitized."""
test_cases = [ test_cases = [
# Email addresses # Email addresses
("john.doe@example.com", "user@example.com"), ("john.doe@example.com", "user@example.com"),
("test123@company.org", "user@example.com"), ("test123@company.org", "user@example.com"),
# Phone numbers (all now use the same pattern) # Phone numbers (all now use the same pattern)
("(555) 123-4567", "(XXX) XXX-XXXX"), ("(555) 123-4567", "(XXX) XXX-XXXX"),
("555-123-4567", "(XXX) XXX-XXXX"), ("555-123-4567", "(XXX) XXX-XXXX"),
("+1-555-123-4567", "(XXX) XXX-XXXX"), ("+1-555-123-4567", "(XXX) XXX-XXXX"),
# SSN # SSN
("123-45-6789", "XXX-XX-XXXX"), ("123-45-6789", "XXX-XX-XXXX"),
# Credit card # Credit card
("1234 5678 9012 3456", "XXXX-XXXX-XXXX-XXXX"), ("1234 5678 9012 3456", "XXXX-XXXX-XXXX-XXXX"),
("1234-5678-9012-3456", "XXXX-XXXX-XXXX-XXXX"), ("1234-5678-9012-3456", "XXXX-XXXX-XXXX-XXXX"),
] ]
for original, expected in test_cases: for original, expected in test_cases:
with self.subTest(original=original): with self.subTest(original=original):
result = self.sanitizer.sanitize_string(original) result = self.sanitizer.sanitize_string(original)
self.assertEqual(result, expected) self.assertEqual(result, expected)
def test_header_sanitization(self): def test_header_sanitization(self):
"""Test HTTP header sanitization.""" """Test HTTP header sanitization."""
headers = { headers = {
@@ -67,84 +62,82 @@ class TestPIISanitizer(unittest.TestCase):
"API-Key": "sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN", "API-Key": "sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN",
"Content-Type": "application/json", "Content-Type": "application/json",
"User-Agent": "MyApp/1.0", "User-Agent": "MyApp/1.0",
"Cookie": "session=abc123; user=john.doe@example.com" "Cookie": "session=abc123; user=john.doe@example.com",
} }
sanitized = self.sanitizer.sanitize_headers(headers) sanitized = self.sanitizer.sanitize_headers(headers)
self.assertEqual(sanitized["Authorization"], "Bearer SANITIZED") self.assertEqual(sanitized["Authorization"], "Bearer SANITIZED")
self.assertEqual(sanitized["API-Key"], "sk-SANITIZED") self.assertEqual(sanitized["API-Key"], "sk-SANITIZED")
self.assertEqual(sanitized["Content-Type"], "application/json") self.assertEqual(sanitized["Content-Type"], "application/json")
self.assertEqual(sanitized["User-Agent"], "MyApp/1.0") self.assertEqual(sanitized["User-Agent"], "MyApp/1.0")
self.assertIn("user@example.com", sanitized["Cookie"]) self.assertIn("user@example.com", sanitized["Cookie"])
def test_nested_structure_sanitization(self): def test_nested_structure_sanitization(self):
"""Test sanitization of nested data structures.""" """Test sanitization of nested data structures."""
data = { data = {
"user": { "user": {
"email": "john.doe@example.com", "email": "john.doe@example.com",
"api_key": "sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12" "api_key": "sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12",
}, },
"tokens": [ "tokens": [
"ghp_1234567890abcdefghijklmnopqrstuvwxyz", "ghp_1234567890abcdefghijklmnopqrstuvwxyz",
"Bearer sk-ant-abcd1234567890ABCD1234567890abcd1234567890ABCD12" "Bearer sk-ant-abcd1234567890ABCD1234567890abcd1234567890ABCD12",
], ],
"metadata": { "metadata": {"ip": "192.168.1.100", "phone": "(555) 123-4567"},
"ip": "192.168.1.100",
"phone": "(555) 123-4567"
}
} }
sanitized = self.sanitizer.sanitize_value(data) sanitized = self.sanitizer.sanitize_value(data)
self.assertEqual(sanitized["user"]["email"], "user@example.com") self.assertEqual(sanitized["user"]["email"], "user@example.com")
self.assertEqual(sanitized["user"]["api_key"], "sk-proj-SANITIZED") self.assertEqual(sanitized["user"]["api_key"], "sk-proj-SANITIZED")
self.assertEqual(sanitized["tokens"][0], "gh_SANITIZED") self.assertEqual(sanitized["tokens"][0], "gh_SANITIZED")
self.assertEqual(sanitized["tokens"][1], "Bearer sk-ant-SANITIZED") self.assertEqual(sanitized["tokens"][1], "Bearer sk-ant-SANITIZED")
self.assertEqual(sanitized["metadata"]["ip"], "0.0.0.0") self.assertEqual(sanitized["metadata"]["ip"], "0.0.0.0")
self.assertEqual(sanitized["metadata"]["phone"], "(XXX) XXX-XXXX") self.assertEqual(sanitized["metadata"]["phone"], "(XXX) XXX-XXXX")
def test_url_sanitization(self): def test_url_sanitization(self):
"""Test URL parameter sanitization.""" """Test URL parameter sanitization."""
urls = [ urls = [
("https://api.example.com/v1/users?api_key=sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN", (
"https://api.example.com/v1/users?api_key=SANITIZED"), "https://api.example.com/v1/users?api_key=sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN",
("https://example.com/login?token=ghp_1234567890abcdefghijklmnopqrstuvwxyz&user=test", "https://api.example.com/v1/users?api_key=SANITIZED",
"https://example.com/login?token=SANITIZED&user=test"), ),
(
"https://example.com/login?token=ghp_1234567890abcdefghijklmnopqrstuvwxyz&user=test",
"https://example.com/login?token=SANITIZED&user=test",
),
] ]
for original, expected in urls: for original, expected in urls:
with self.subTest(url=original): with self.subTest(url=original):
result = self.sanitizer.sanitize_url(original) result = self.sanitizer.sanitize_url(original)
self.assertEqual(result, expected) self.assertEqual(result, expected)
def test_disable_sanitization(self): def test_disable_sanitization(self):
"""Test that sanitization can be disabled.""" """Test that sanitization can be disabled."""
self.sanitizer.sanitize_enabled = False self.sanitizer.sanitize_enabled = False
sensitive_data = "sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12" sensitive_data = "sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12"
result = self.sanitizer.sanitize_string(sensitive_data) result = self.sanitizer.sanitize_string(sensitive_data)
# Should return original when disabled # Should return original when disabled
self.assertEqual(result, sensitive_data) self.assertEqual(result, sensitive_data)
def test_custom_pattern(self): def test_custom_pattern(self):
"""Test adding custom PII patterns.""" """Test adding custom PII patterns."""
# Add custom pattern for internal employee IDs # Add custom pattern for internal employee IDs
custom_pattern = PIIPattern.create( custom_pattern = PIIPattern.create(
name="employee_id", name="employee_id", pattern=r"EMP\d{6}", replacement="EMP-REDACTED", description="Internal employee IDs"
pattern=r'EMP\d{6}',
replacement="EMP-REDACTED",
description="Internal employee IDs"
) )
self.sanitizer.add_pattern(custom_pattern) self.sanitizer.add_pattern(custom_pattern)
text = "Employee EMP123456 has access to the system" text = "Employee EMP123456 has access to the system"
result = self.sanitizer.sanitize_string(text) result = self.sanitizer.sanitize_string(text)
self.assertEqual(result, "Employee EMP-REDACTED has access to the system") self.assertEqual(result, "Employee EMP-REDACTED has access to the system")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
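The suite above covers string, header, URL, and nested-structure sanitization plus user-defined patterns. As a usage illustration, the sketch below combines those same calls to scrub one recorded HTTP interaction before it would be written to a cassette; the `interaction` layout and the `scrub_interaction` name are assumptions, and only the `PIISanitizer`/`PIIPattern` APIs come from the tests above.

# Hypothetical glue code built from the APIs exercised in the tests above.
# The interaction dict layout and the function name are assumptions, not repository code.
from tests.pii_sanitizer import PIIPattern, PIISanitizer


def scrub_interaction(interaction: dict) -> dict:
    """Sanitize one recorded request/response pair before persisting it."""
    sanitizer = PIISanitizer()
    # Project-specific identifiers can be layered on top of the built-in patterns.
    sanitizer.add_pattern(
        PIIPattern.create(
            name="employee_id",
            pattern=r"EMP\d{6}",
            replacement="EMP-REDACTED",
            description="Internal employee IDs",
        )
    )
    return {
        "url": sanitizer.sanitize_url(interaction["url"]),
        "headers": sanitizer.sanitize_headers(interaction["headers"]),
        "body": sanitizer.sanitize_value(interaction["body"]),
    }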