fix: Resolve o3-pro response parsing and test execution issues

- Fix lint errors: trailing whitespace and deprecated typing imports
- Update test mock for o3-pro response format (output.content[] → output_text)
- Implement robust test isolation with monkeypatch fixture
- Clear provider registry cache to prevent test interference
- Ensure o3-pro tests pass in both individual and full suite execution

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Josh Vera
2025-07-12 20:24:34 -06:00
parent ae5e43b792
commit 3db49413ff
8 changed files with 328 additions and 320 deletions

View File

@@ -221,7 +221,7 @@ class OpenAICompatibleProvider(ModelProvider):
# Create httpx client with minimal config to avoid proxy conflicts
# Note: proxies parameter was removed in httpx 0.28.0
# Check for test transport injection
if hasattr(self, '_test_transport'):
if hasattr(self, "_test_transport"):
# Use custom transport for testing (HTTP recording/replay)
http_client = httpx.Client(
transport=self._test_transport,
@@ -318,13 +318,13 @@ class OpenAICompatibleProvider(ModelProvider):
"""
logging.debug(f"Response object type: {type(response)}")
logging.debug(f"Response attributes: {dir(response)}")
if not hasattr(response, "output_text"):
raise ValueError(f"o3-pro response missing output_text field. Response type: {type(response).__name__}")
content = response.output_text
logging.debug(f"Extracted output_text: '{content}' (type: {type(content)})")
if content is None:
raise ValueError("o3-pro returned None for output_text")

View File

@@ -93,7 +93,7 @@ def pytest_collection_modifyitems(session, config, items):
if item.get_closest_marker("no_mock_provider"):
config._needs_dummy_keys = False
break
# Set dummy keys only if no test needs real keys
if config._needs_dummy_keys:
_set_dummy_keys_if_missing()

View File

@@ -2,7 +2,7 @@
"""
HTTP Transport Recorder for O3-Pro Testing
Custom httpx transport solution that replaces respx for recording/replaying
Custom httpx transport solution that replaces respx for recording/replaying
HTTP interactions. Provides full control over the recording process without
respx limitations.
@@ -13,40 +13,40 @@ Key Features:
- JSON cassette format with data sanitization
"""
import json
import hashlib
import copy
import base64
import copy
import hashlib
import json
from pathlib import Path
from typing import Dict, Any, Optional
import httpx
from io import BytesIO
from .pii_sanitizer import PIISanitizer
from typing import Any, Optional
import httpx
from .pii_sanitizer import PIISanitizer
class RecordingTransport(httpx.HTTPTransport):
"""Transport that wraps default httpx transport and records all interactions."""
def __init__(self, cassette_path: str, capture_content: bool = True, sanitize: bool = True):
    """Create a recording transport that persists to ``cassette_path``.

    Args:
        cassette_path: Location of the JSON cassette file to write.
        capture_content: Stored on the instance; ``handle_request`` checks it
            before reading and recording response bodies.
        sanitize: When true a ``PIISanitizer`` is attached; otherwise the
            ``sanitizer`` attribute is ``None``.
    """
    super().__init__()
    if sanitize:
        scrubber = PIISanitizer()
    else:
        scrubber = None
    self.sanitizer = scrubber
    self.capture_content = capture_content
    self.recorded_interactions = []
    self.cassette_path = Path(cassette_path)
def handle_request(self, request: httpx.Request) -> httpx.Response:
"""Handle request by recording interaction and delegating to real transport."""
print(f"🎬 RecordingTransport: Making request to {request.method} {request.url}")
# Record request BEFORE making the call
request_data = self._serialize_request(request)
# Make real HTTP call using parent transport
response = super().handle_request(request)
print(f"🎬 RecordingTransport: Got response {response.status_code}")
# Post-response content capture (proper approach)
if self.capture_content:
try:
@@ -55,19 +55,20 @@ class RecordingTransport(httpx.HTTPTransport):
content_bytes = response.read()
response.close() # Close the original stream
print(f"🎬 RecordingTransport: Captured {len(content_bytes)} bytes of decompressed content")
# Serialize response with captured content
response_data = self._serialize_response_with_content(response, content_bytes)
# Create a new response with the same metadata but buffered content
# If the original response was gzipped, we need to re-compress
response_content = content_bytes
if response.headers.get('content-encoding') == 'gzip':
if response.headers.get("content-encoding") == "gzip":
import gzip
print(f"🗜️ Re-compressing {len(content_bytes)} bytes with gzip...")
response_content = gzip.compress(content_bytes)
print(f"🗜️ Compressed to {len(response_content)} bytes")
new_response = httpx.Response(
status_code=response.status_code,
headers=response.headers, # Keep original headers intact
@@ -76,15 +77,16 @@ class RecordingTransport(httpx.HTTPTransport):
extensions=response.extensions,
history=response.history,
)
# Record the interaction
self._record_interaction(request_data, response_data)
return new_response
except Exception as e:
print(f"⚠️ Content capture failed: {e}, falling back to stub")
import traceback
print(f"⚠️ Full exception traceback:\n{traceback.format_exc()}")
response_data = self._serialize_response(response)
self._record_interaction(request_data, response_data)
@@ -94,105 +96,99 @@ class RecordingTransport(httpx.HTTPTransport):
response_data = self._serialize_response(response)
self._record_interaction(request_data, response_data)
return response
def _record_interaction(self, request_data: dict[str, Any], response_data: dict[str, Any]):
    """Append one request/response pair to the recording and save the cassette.

    Args:
        request_data: Already-serialized request dictionary.
        response_data: Already-serialized response dictionary.
    """
    self.recorded_interactions.append(dict(request=request_data, response=response_data))
    # The cassette is rewritten after every interaction.
    self._save_cassette()
    print(f"🎬 RecordingTransport: Saved cassette to {self.cassette_path}")
def _serialize_request(self, request: httpx.Request) -> dict[str, Any]:
    """Turn an ``httpx.Request`` into a JSON-compatible dictionary.

    The body is decoded as UTF-8 text; bytes that are not valid UTF-8 are
    stored as a hex string instead. The result is passed through the PII
    sanitizer when one is configured.

    Returns:
        A dict with ``method``, ``url``, ``path``, ``headers`` and ``content``.
    """
    raw = request.content
    if not isinstance(raw, bytes):
        body_text = str(raw) if raw else ""
    else:
        try:
            body_text = raw.decode("utf-8")
        except UnicodeDecodeError:
            # Binary payload: keep a lossless textual form.
            body_text = raw.hex()

    url = request.url
    serialized = {
        "method": request.method,
        "url": str(url),
        "path": url.path,
        "headers": dict(request.headers),
        "content": self._sanitize_request_content(body_text),
    }

    if not self.sanitizer:
        return serialized
    return self.sanitizer.sanitize_request(serialized)
def _serialize_response(self, response: httpx.Response) -> dict[str, Any]:
    """Serialize an ``httpx.Response`` without its body (stub format).

    Used when content capture is disabled or has failed. Status, headers and
    reason phrase are recorded; the body is replaced by an explanatory note.

    Returns:
        A JSON-compatible dict, sanitized when a sanitizer is configured.
    """
    response_data = {
        "status_code": response.status_code,
        "headers": dict(response.headers),
        "content": {"note": "Response content not recorded to avoid httpx.ResponseNotRead exception"},
        "reason_phrase": response.reason_phrase,
    }

    # Headers end up in the cassette on this path as well, so they need the
    # same scrubbing that _serialize_response_with_content applies.
    if self.sanitizer:
        response_data = self.sanitizer.sanitize_response(response_data)

    return response_data
def _serialize_response_with_content(self, response: httpx.Response, content_bytes: bytes) -> dict[str, Any]:
    """Serialize an ``httpx.Response`` together with its captured body.

    The body is stored base64-encoded so arbitrary bytes survive the JSON
    round trip. Any failure falls back to a minimal record that carries an
    error note instead of the body.

    Args:
        response: The response whose status, headers and reason are recorded.
        content_bytes: The already-read response body.

    Returns:
        A JSON-compatible dict, sanitized when a sanitizer is configured.
    """
    try:
        # Debug: check what we got
        print(f"🔍 Content type: {type(content_bytes)}, size: {len(content_bytes)}")
        print(f"🔍 First 100 chars: {content_bytes[:100]}")

        # Ensure we have bytes for base64 encoding
        if not isinstance(content_bytes, bytes):
            print(f"⚠️ Content is not bytes, converting from {type(content_bytes)}")
            if isinstance(content_bytes, str):
                content_bytes = content_bytes.encode("utf-8")
            else:
                content_bytes = str(content_bytes).encode("utf-8")

        # Encode content as base64 for JSON storage
        print(f"🔍 Base64 encoding {len(content_bytes)} bytes...")
        content_b64 = base64.b64encode(content_bytes).decode("utf-8")
        print(f"✅ Base64 encoded successfully, result length: {len(content_b64)}")

        response_data = {
            "status_code": response.status_code,
            "headers": dict(response.headers),
            "content": {"data": content_b64, "encoding": "base64", "size": len(content_bytes)},
            "reason_phrase": response.reason_phrase,
        }

        # Apply PII sanitization if enabled
        if self.sanitizer:
            response_data = self.sanitizer.sanitize_response(response_data)

        return response_data
    except Exception as e:
        print(f"🔍 Error in _serialize_response_with_content: {e}")
        import traceback

        print(f"🔍 Full traceback: {traceback.format_exc()}")
        # Fall back to minimal info
        fallback = {
            "status_code": response.status_code,
            "headers": dict(response.headers),
            "content": {"error": f"Failed to serialize content: {e}"},
            "reason_phrase": response.reason_phrase,
        }
        # The fallback is written to the cassette too, so it must not carry
        # raw headers when sanitization was requested.
        if self.sanitizer:
            try:
                fallback = self.sanitizer.sanitize_response(fallback)
            except Exception:
                # The sanitizer itself is what failed: withhold the headers
                # rather than persist them unscrubbed.
                fallback["headers"] = {}
        return fallback
def _sanitize_request_content(self, content: str) -> Any:
"""Sanitize request content to remove sensitive data."""
try:
@@ -203,14 +199,14 @@ class RecordingTransport(httpx.HTTPTransport):
except json.JSONDecodeError:
pass
return content
def _sanitize_response_content(self, data: Any) -> Any:
"""Sanitize response content to remove sensitive data."""
if not isinstance(data, dict):
return data
sanitized = copy.deepcopy(data)
# Sensitive fields to sanitize
sensitive_fields = {
"id": "resp_SANITIZED",
@@ -218,7 +214,7 @@ class RecordingTransport(httpx.HTTPTransport):
"created_at": 0,
"system_fingerprint": "fp_SANITIZED",
}
def sanitize_dict(obj):
if isinstance(obj, dict):
for key, value in obj.items():
@@ -230,82 +226,76 @@ class RecordingTransport(httpx.HTTPTransport):
for item in obj:
if isinstance(item, (dict, list)):
sanitize_dict(item)
sanitize_dict(sanitized)
return sanitized
def _save_cassette(self):
    """Write every recorded interaction to the cassette file as JSON.

    Missing parent directories are created first. The file holds a single
    object with an ``interactions`` list, indented and with sorted keys.
    """
    target = self.cassette_path
    target.parent.mkdir(parents=True, exist_ok=True)

    payload = json.dumps({"interactions": self.recorded_interactions}, indent=2, sort_keys=True)
    target.write_text(payload)
class ReplayTransport(httpx.MockTransport):
"""Transport that replays saved HTTP interactions from cassettes."""
def __init__(self, cassette_path: str):
    """Load the cassette at ``cassette_path`` and install the replay handler.

    The cassette is loaded before the base class is initialised, so any error
    raised by ``_load_cassette`` propagates out of construction.
    """
    cassette = Path(cassette_path)
    self.cassette_path = cassette
    self.interactions = self._load_cassette()
    super().__init__(self._handle_request)
def _load_cassette(self) -> list:
    """Load the recorded interactions from the cassette file.

    Returns:
        The value stored under ``interactions``, or an empty list when the
        key is absent.

    Raises:
        FileNotFoundError: The cassette file does not exist.
        ValueError: The file is not valid JSON, or its top-level value is not
            a JSON object.
    """
    if not self.cassette_path.exists():
        raise FileNotFoundError(f"Cassette file not found: {self.cassette_path}")

    try:
        cassette_data = json.loads(self.cassette_path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as e:
        # Keep the decoder error as the cause so its position info survives.
        raise ValueError(f"Invalid cassette file format: {e}") from e

    if not isinstance(cassette_data, dict):
        # A bare list or scalar would otherwise fail on .get() with AttributeError.
        raise ValueError(
            f"Invalid cassette file format: expected a JSON object, got {type(cassette_data).__name__}"
        )
    return cassette_data.get("interactions", [])
def _handle_request(self, request: httpx.Request) -> httpx.Response:
"""Handle request by finding matching interaction and returning saved response."""
print(f"🔍 ReplayTransport: Looking for {request.method} {request.url}")
# Debug: show what we're trying to match
request_signature = self._get_request_signature(request)
print(f"🔍 Request signature: {request_signature}")
# Debug: show actual request content
content = request.content
if hasattr(content, 'read'):
if hasattr(content, "read"):
content = content.read()
if isinstance(content, bytes):
content_str = content.decode('utf-8', errors='ignore')
content_str = content.decode("utf-8", errors="ignore")
else:
content_str = str(content) if content else ""
print(f"🔍 Actual request content: {content_str}")
# Debug: show available signatures
for i, interaction in enumerate(self.interactions):
saved_signature = self._get_saved_request_signature(interaction["request"])
saved_content = interaction["request"].get("content", {})
print(f"🔍 Available signature {i}: {saved_signature}")
print(f"🔍 Saved content {i}: {saved_content}")
# Find matching interaction
interaction = self._find_matching_interaction(request)
if not interaction:
print("🚨 MYSTERY SOLVED: No matching interaction found! This should fail...")
raise ValueError(
f"No matching interaction found for {request.method} {request.url}"
)
print(f"✅ Found matching interaction from cassette!")
raise ValueError(f"No matching interaction found for {request.method} {request.url}")
print("✅ Found matching interaction from cassette!")
# Build response from saved data
response_data = interaction["response"]
# Convert content back to appropriate format
content = response_data.get("content", {})
if isinstance(content, dict):
@@ -317,55 +307,56 @@ class ReplayTransport(httpx.MockTransport):
print(f"🎬 ReplayTransport: Decoded {len(content_bytes)} bytes from base64")
except Exception as e:
print(f"⚠️ Failed to decode base64 content: {e}")
content_bytes = json.dumps(content).encode('utf-8')
content_bytes = json.dumps(content).encode("utf-8")
else:
# Legacy format or stub content
content_bytes = json.dumps(content).encode('utf-8')
content_bytes = json.dumps(content).encode("utf-8")
else:
content_bytes = str(content).encode('utf-8')
content_bytes = str(content).encode("utf-8")
# Check if response expects gzipped content
headers = response_data.get("headers", {})
if headers.get('content-encoding') == 'gzip':
if headers.get("content-encoding") == "gzip":
# Re-compress the content for httpx
import gzip
print(f"🗜️ ReplayTransport: Re-compressing {len(content_bytes)} bytes with gzip...")
content_bytes = gzip.compress(content_bytes)
print(f"🗜️ ReplayTransport: Compressed to {len(content_bytes)} bytes")
print(f"🎬 ReplayTransport: Returning cassette response with content: {content_bytes[:100]}...")
# Create httpx.Response
return httpx.Response(
status_code=response_data["status_code"],
headers=response_data.get("headers", {}),
content=content_bytes,
request=request
request=request,
)
def _find_matching_interaction(self, request: httpx.Request) -> Optional[dict[str, Any]]:
    """Return the first saved interaction whose signature equals the request's.

    Returns:
        The matching interaction dict, or ``None`` when nothing matches.
    """
    wanted = self._get_request_signature(request)
    candidates = (
        saved
        for saved in self.interactions
        if self._get_saved_request_signature(saved["request"]) == wanted
    )
    return next(candidates, None)
def _get_request_signature(self, request: httpx.Request) -> str:
"""Generate signature for request matching."""
# Use method, path, and content hash for matching
content = request.content
if hasattr(content, 'read'):
if hasattr(content, "read"):
content = content.read()
if isinstance(content, bytes):
content_str = content.decode('utf-8', errors='ignore')
content_str = content.decode("utf-8", errors="ignore")
else:
content_str = str(content) if content else ""
# Parse JSON and re-serialize with sorted keys for consistent hashing
try:
if content_str.strip():
@@ -374,37 +365,37 @@ class ReplayTransport(httpx.MockTransport):
except json.JSONDecodeError:
# Not JSON, use as-is
pass
# Create hash of content for stable matching
content_hash = hashlib.md5(content_str.encode()).hexdigest()
return f"{request.method}:{request.url.path}:{content_hash}"
def _get_saved_request_signature(self, saved_request: dict[str, Any]) -> str:
    """Build the ``METHOD:path:md5`` signature for a request stored in a cassette.

    Dict content is serialized with sorted keys before hashing so that key
    order does not affect the signature; any other content is hashed via
    ``str()``.
    """
    method, path = saved_request["method"], saved_request["path"]

    body = saved_request.get("content", "")
    body_text = json.dumps(body, sort_keys=True) if isinstance(body, dict) else str(body)

    digest = hashlib.md5(body_text.encode()).hexdigest()
    return f"{method}:{path}:{digest}"
class TransportFactory:
"""Factory for creating appropriate transport based on cassette availability."""
@staticmethod
def create_transport(cassette_path: str) -> httpx.HTTPTransport:
"""Create transport based on cassette existence and API key availability."""
cassette_file = Path(cassette_path)
# Check if we should record or replay
if cassette_file.exists():
# Cassette exists - use replay mode
@@ -413,15 +404,15 @@ class TransportFactory:
# No cassette - use recording mode
# Note: We'll check for API key in the test itself
return RecordingTransport(cassette_path)
@staticmethod
def should_record(cassette_path: str, api_key: Optional[str] = None) -> bool:
"""Determine if we should record based on cassette and API key availability."""
cassette_file = Path(cassette_path)
# Record if cassette doesn't exist AND we have API key
return not cassette_file.exists() and bool(api_key)
@staticmethod
def should_replay(cassette_path: str) -> bool:
"""Determine if we should replay based on cassette availability."""
@@ -434,8 +425,8 @@ class TransportFactory:
# # In test setup:
# cassette_path = "tests/cassettes/o3_pro_basic_math.json"
# transport = TransportFactory.create_transport(cassette_path)
#
#
# # Inject into OpenAI client:
# provider._test_transport = transport
#
# # The provider's client property will detect _test_transport and use it
#
# # The provider's client property will detect _test_transport and use it

View File

@@ -7,11 +7,12 @@ request/response recordings to prevent accidental exposure of API keys,
tokens, personal information, and other sensitive data.
"""
import re
from typing import Any, Dict, List, Optional, Pattern
from dataclasses import dataclass
from copy import deepcopy
import logging
import re
from copy import deepcopy
from dataclasses import dataclass
from re import Pattern
from typing import Any, Optional
logger = logging.getLogger(__name__)
@@ -19,178 +20,170 @@ logger = logging.getLogger(__name__)
@dataclass
class PIIPattern:
    """A named, compiled regex together with the text that replaces its matches."""

    # Short identifier for the pattern.
    name: str
    # Compiled regular expression that detects the sensitive text.
    pattern: Pattern[str]
    # Text substituted for every match.
    replacement: str
    # Human-readable explanation of what the pattern covers.
    description: str

    @classmethod
    def create(cls, name: str, pattern: str, replacement: str, description: str) -> "PIIPattern":
        """Build a PIIPattern, compiling ``pattern`` from its string form."""
        compiled = re.compile(pattern)
        return cls(name, compiled, replacement, description)
class PIISanitizer:
"""Sanitizes PII from various data structures while preserving format."""
def __init__(self, patterns: Optional[list[PIIPattern]] = None):
    """Initialize the sanitizer.

    Args:
        patterns: Custom patterns to use. When omitted or empty, the default
            pattern set is installed instead.
    """
    # Copy the argument: add_pattern() appends to self.patterns and must not
    # modify a list the caller still owns.
    self.patterns: list[PIIPattern] = list(patterns) if patterns else []
    self.sanitize_enabled = True

    # Add default patterns if none provided
    if not patterns:
        self._add_default_patterns()
def _add_default_patterns(self):
    """Append the built-in PII patterns to ``self.patterns``.

    Patterns are added in the order listed; ``sanitize_string`` applies them
    in that same order.
    """
    # (name, regex, replacement, description)
    specs = [
        # API Keys - Core patterns (Bearer tokens handled in sanitize_headers)
        ("openai_api_key_proj", r"sk-proj-[A-Za-z0-9\-_]{48,}", "sk-proj-SANITIZED", "OpenAI project API keys"),
        ("openai_api_key", r"sk-[A-Za-z0-9]{48,}", "sk-SANITIZED", "OpenAI API keys"),
        ("anthropic_api_key", r"sk-ant-[A-Za-z0-9\-_]{48,}", "sk-ant-SANITIZED", "Anthropic API keys"),
        ("google_api_key", r"AIza[A-Za-z0-9\-_]{35,}", "AIza-SANITIZED", "Google API keys"),
        ("github_tokens", r"gh[psr]_[A-Za-z0-9]{36}", "gh_SANITIZED", "GitHub tokens (all types)"),
        # JWT tokens
        (
            "jwt_token",
            r"eyJ[A-Za-z0-9\-_]+\.eyJ[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+",
            "eyJ-SANITIZED",
            "JSON Web Tokens",
        ),
        # Personal Information
        (
            "email_address",
            r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}",
            "user@example.com",
            "Email addresses",
        ),
        (
            "ipv4_address",
            r"\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b",
            "0.0.0.0",
            "IPv4 addresses",
        ),
        ("ssn", r"\b\d{3}-\d{2}-\d{4}\b", "XXX-XX-XXXX", "Social Security Numbers"),
        (
            "credit_card",
            r"\b\d{4}[\s\-]?\d{4}[\s\-]?\d{4}[\s\-]?\d{4}\b",
            "XXXX-XXXX-XXXX-XXXX",
            "Credit card numbers",
        ),
        (
            "phone_number",
            r"(?:\+\d{1,3}[\s\-]?)?\(?\d{3}\)?[\s\-]?\d{3}[\s\-]?\d{4}",
            "(XXX) XXX-XXXX",
            "Phone numbers (all formats)",
        ),
        # AWS
        ("aws_access_key", r"AKIA[0-9A-Z]{16}", "AKIA-SANITIZED", "AWS access keys"),
        # Other common patterns
        (
            "slack_token",
            r"xox[baprs]-[0-9]{10,13}-[0-9]{10,13}-[a-zA-Z0-9]{24,34}",
            "xox-SANITIZED",
            "Slack tokens",
        ),
        ("stripe_key", r"(?:sk|pk)_(?:test|live)_[0-9a-zA-Z]{24,99}", "sk_SANITIZED", "Stripe API keys"),
    ]

    # Build the full list first so nothing is appended if a pattern fails to compile.
    default_patterns = [
        PIIPattern.create(name=name, pattern=regex, replacement=replacement, description=description)
        for name, regex, replacement, description in specs
    ]
    self.patterns.extend(default_patterns)
def add_pattern(self, pattern: PIIPattern):
    """Register an additional PII pattern.

    The pattern is appended, so it is applied after those already registered.
    """
    registered = self.patterns
    registered.append(pattern)
    logger.info(f"Added PII pattern: {pattern.name}")
def sanitize_string(self, text: str) -> str:
    """Run every registered pattern over ``text`` and return the scrubbed result.

    Non-string input, or any input while sanitization is disabled, is returned
    unchanged. Patterns are applied in registration order, each one operating
    on the output of the previous.
    """
    if not (self.sanitize_enabled and isinstance(text, str)):
        return text

    result = text
    for rule in self.patterns:
        # subn reports how many replacements were made, so one pass both
        # rewrites the text and tells us whether the pattern matched.
        result, hits = rule.pattern.subn(rule.replacement, result)
        if hits:
            logger.debug(f"Applied {rule.name} sanitization")

    return result
def sanitize_headers(self, headers: dict[str, str]) -> dict[str, str]:
    """Sanitize HTTP headers, with special handling for ``Authorization``.

    For ``Authorization`` values using the Bearer or Basic scheme, the
    credential is replaced with ``SANITIZED`` while the scheme is kept as
    written. Every other header value goes through ``sanitize_string``.

    Returns:
        A new dict of sanitized headers, or ``headers`` itself when
        sanitization is disabled.
    """
    if not self.sanitize_enabled:
        return headers

    sanitized_headers = {}
    for key, value in headers.items():
        # Special case for Authorization headers to preserve auth type
        if key.lower() == "authorization" and " " in value:
            auth_type = value.split(" ", 1)[0]
            # Auth scheme names are case-insensitive, so "bearer <token>"
            # must be scrubbed exactly like "Bearer <token>".
            if auth_type.lower() in ("bearer", "basic"):
                sanitized_headers[key] = f"{auth_type} SANITIZED"
                continue

        # Apply standard sanitization to everything else
        sanitized_headers[key] = self.sanitize_string(value)

    return sanitized_headers
def sanitize_value(self, value: Any) -> Any:
"""Recursively sanitize any value (string, dict, list, etc)."""
if not self.sanitize_enabled:
return value
if isinstance(value, str):
return self.sanitize_string(value)
elif isinstance(value, dict):
@@ -202,25 +195,25 @@ class PIISanitizer:
else:
# For other types (int, float, bool, None), return as-is
return value
def sanitize_url(self, url: str) -> str:
"""Sanitize sensitive data from URLs (query params, etc)."""
if not self.sanitize_enabled:
return url
# First apply general string sanitization
url = self.sanitize_string(url)
# Parse and sanitize query parameters
if '?' in url:
base, query = url.split('?', 1)
if "?" in url:
base, query = url.split("?", 1)
params = []
for param in query.split('&'):
if '=' in param:
key, value = param.split('=', 1)
for param in query.split("&"):
if "=" in param:
key, value = param.split("=", 1)
# Sanitize common sensitive parameter names
sensitive_params = {'key', 'token', 'api_key', 'secret', 'password'}
sensitive_params = {"key", "token", "api_key", "secret", "password"}
if key.lower() in sensitive_params:
params.append(f"{key}=SANITIZED")
else:
@@ -228,54 +221,53 @@ class PIISanitizer:
params.append(f"{key}={self.sanitize_string(value)}")
else:
params.append(param)
return f"{base}?{'&'.join(params)}"
return url
def sanitize_request(self, request_data: dict[str, Any]) -> dict[str, Any]:
    """Return a sanitized deep copy of a serialized request.

    ``headers``, ``url`` and ``content`` are each scrubbed with their
    dedicated method when present; other keys are copied through untouched.
    The input dictionary is not modified.
    """
    sanitized = deepcopy(request_data)

    scrubbers = (
        ("headers", self.sanitize_headers),
        ("url", self.sanitize_url),
        ("content", self.sanitize_value),
    )
    for field, scrub in scrubbers:
        if field in sanitized:
            sanitized[field] = scrub(sanitized[field])

    return sanitized
def sanitize_response(self, response_data: dict[str, Any]) -> dict[str, Any]:
    """Return a sanitized deep copy of a serialized response.

    Headers are scrubbed with ``sanitize_headers``. For base64-encoded
    content the ``data`` payload is deliberately left untouched (it is not
    decoded or re-encoded) and only the sibling metadata fields are
    sanitized; any other content is sanitized recursively.
    """
    sanitized = deepcopy(response_data)

    if "headers" in sanitized:
        sanitized["headers"] = self.sanitize_headers(sanitized["headers"])

    if "content" not in sanitized:
        return sanitized

    body = sanitized["content"]
    is_base64 = isinstance(body, dict) and body.get("encoding") == "base64"
    if not is_base64:
        sanitized["content"] = self.sanitize_value(body)
    elif "data" in body:
        # Keep the encoded body as-is; scrub only the metadata around it.
        for field in [name for name in body if name != "data"]:
            body[field] = self.sanitize_value(body[field])

    return sanitized
# Global instance for convenience
default_sanitizer = PIISanitizer()
default_sanitizer = PIISanitizer()

View File

@@ -10,10 +10,10 @@ This script will:
"""
import json
import sys
from pathlib import Path
import shutil
import sys
from datetime import datetime
from pathlib import Path
# Add tests directory to path to import our modules
sys.path.insert(0, str(Path(__file__).parent))
@@ -24,54 +24,55 @@ from pii_sanitizer import PIISanitizer
def sanitize_cassette(cassette_path: Path, backup: bool = True) -> bool:
    """Sanitize one cassette file in place.

    Args:
        cassette_path: JSON cassette to rewrite.
        backup: When true, copy the original to a timestamped
            ``.backup-<timestamp>.json`` sibling before overwriting it.

    Returns:
        True when the cassette was rewritten; False when the file is missing
        or any error occurred (the error is printed with its traceback).
    """
    print(f"\n🔍 Processing: {cassette_path}")

    if not cassette_path.exists():
        print(f"❌ File not found: {cassette_path}")
        return False

    try:
        # Load cassette
        with open(cassette_path) as f:
            cassette_data = json.load(f)

        # Create backup if requested
        if backup:
            stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
            backup_path = cassette_path.with_suffix(f".backup-{stamp}.json")
            shutil.copy2(cassette_path, backup_path)
            print(f"📦 Backup created: {backup_path}")

        # Initialize sanitizer
        sanitizer = PIISanitizer()

        # Sanitize interactions; only the request/response parts are carried over.
        if "interactions" in cassette_data:
            handlers = (
                ("request", sanitizer.sanitize_request),
                ("response", sanitizer.sanitize_response),
            )
            cassette_data["interactions"] = [
                {part: scrub(interaction[part]) for part, scrub in handlers if part in interaction}
                for interaction in cassette_data["interactions"]
            ]

        # Save sanitized cassette
        with open(cassette_path, "w") as f:
            json.dump(cassette_data, f, indent=2, sort_keys=True)

        print(f"✅ Sanitized: {cassette_path}")
        return True

    except Exception as e:
        print(f"❌ Error processing {cassette_path}: {e}")
        import traceback

        traceback.print_exc()
        return False
@@ -79,31 +80,31 @@ def sanitize_cassette(cassette_path: Path, backup: bool = True) -> bool:
def main():
    """Sanitize all cassettes in the openai_cassettes directory.

    Exits with status 1 when the directory is missing, when it contains no
    cassettes, or when at least one cassette failed to sanitize.
    """
    cassettes_dir = Path(__file__).parent / "openai_cassettes"

    if not cassettes_dir.exists():
        print(f"❌ Directory not found: {cassettes_dir}")
        sys.exit(1)

    # Find all JSON cassettes. sanitize_cassette() writes its backups next to
    # the originals as "<stem>.backup-<timestamp>.json"; those must be skipped
    # or a re-run would sanitize the backups and back them up again.
    cassette_files = sorted(
        path for path in cassettes_dir.glob("*.json") if ".backup-" not in path.name
    )

    if not cassette_files:
        print(f"❌ No cassette files found in {cassettes_dir}")
        sys.exit(1)

    print(f"🎬 Found {len(cassette_files)} cassette(s) to sanitize")

    # Process each cassette
    success_count = 0
    for cassette_path in cassette_files:
        if sanitize_cassette(cassette_path):
            success_count += 1

    print(f"\n✨ Sanitization complete: {success_count}/{len(cassette_files)} cassettes processed successfully")

    if success_count < len(cassette_files):
        sys.exit(1)
if __name__ == "__main__":
main()
main()

View File

@@ -18,11 +18,11 @@ from pathlib import Path
import pytest
from dotenv import load_dotenv
from tools.chat import ChatTool
from providers import ModelProviderRegistry
from providers.base import ProviderType
from providers.openai_provider import OpenAIModelProvider
from tests.http_transport_recorder import TransportFactory
from tools.chat import ChatTool
# Load environment variables from .env file
load_dotenv()
@@ -32,54 +32,87 @@ cassette_dir = Path(__file__).parent / "openai_cassettes"
cassette_dir.mkdir(exist_ok=True)
@pytest.fixture
def allow_all_models(monkeypatch):
    """Allow all models by resetting the restriction service singleton.

    For the duration of the test this clears the cached restriction service,
    sets ``ALLOWED_MODELS`` to an empty string, provides a placeholder
    ``OPENAI_API_KEY``, and clears the provider registry cache. The singleton
    and the registry cache are reset again after the test.
    """
    from providers.registry import ModelProviderRegistry

    # Reset the singleton so it will re-read env vars inside this fixture
    monkeypatch.setattr("utils.model_restrictions._restriction_service", None)
    monkeypatch.setenv("ALLOWED_MODELS", "")  # empty string = no restrictions
    monkeypatch.setenv("OPENAI_API_KEY", "dummy-key-for-replay")  # transport layer expects a key

    # Also clear the provider registry cache to ensure clean state
    ModelProviderRegistry.clear_cache()

    yield

    # Clean up: reset singleton again so other tests don't see the unrestricted version
    monkeypatch.setattr("utils.model_restrictions._restriction_service", None)
    # Clear registry cache again for other tests
    ModelProviderRegistry.clear_cache()
@pytest.mark.no_mock_provider # Disable provider mocking for this test
class TestO3ProOutputTextFix(unittest.IsolatedAsyncioTestCase):
"""Test o3-pro response parsing fix using a custom httpx transport for HTTP recording/replay."""
def setUp(self):
    """Start each test with an empty provider cache and the OpenAI provider registered."""
    registry = ModelProviderRegistry
    # Drop anything cached by earlier tests before registering.
    registry.clear_cache()
    # Register explicitly so the provider is available regardless of test order.
    registry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
@pytest.mark.usefixtures("allow_all_models")
async def test_o3_pro_uses_output_text_field(self):
    """Test that o3-pro parsing uses the output_text convenience field via ChatTool.

    A cassette-backed transport is injected into the provider as
    ``_test_transport`` for the duration of the test and removed (or the
    previous value restored) afterwards, even if the test body fails.
    """
    cassette_path = cassette_dir / "o3_pro_basic_math.json"
    # Captured before the transport is created so the log line below can
    # tell recording from replaying.
    cassette_existed = cassette_path.exists()

    # Skip if no API key available and cassette doesn't exist
    if not cassette_existed and not os.getenv("OPENAI_API_KEY"):
        pytest.skip("Set real OPENAI_API_KEY to record cassettes")

    # Create transport (automatically selects record vs replay mode)
    transport = TransportFactory.create_transport(str(cassette_path))

    # Get provider and inject custom transport
    provider = ModelProviderRegistry.get_provider_for_model("o3-pro")
    if not provider:
        self.fail("OpenAI provider not available for o3-pro model")

    # A sentinel distinguishes "attribute was absent" from "attribute was
    # set to a falsy value", so the finally block restores the exact state.
    missing = object()
    original_transport = getattr(provider, "_test_transport", missing)
    provider._test_transport = transport

    try:
        # Execute ChatTool test with custom transport
        result = await self._execute_chat_tool_test()

        # Verify the response works correctly
        self._verify_chat_tool_response(result)

        # Verify cassette was created/used
        if not cassette_path.exists():
            self.fail(f"Cassette should exist at {cassette_path}")

        # The previous isinstance() check against the transport's own base
        # class was always true and so always reported "recorded".
        # NOTE(review): presumably the factory replays when the cassette
        # already exists and records otherwise — confirm against TransportFactory.
        mode = "replayed" if cassette_existed else "recorded"
        print(f"✅ HTTP transport {mode} o3-pro interaction")
    finally:
        # Restore original transport (if any)
        if original_transport is not missing:
            provider._test_transport = original_transport
        elif hasattr(provider, "_test_transport"):
            delattr(provider, "_test_transport")
async def _execute_chat_tool_test(self):
"""Execute the ChatTool with o3-pro and return the result."""

View File

@@ -230,10 +230,8 @@ class TestOpenAIProvider:
mock_openai_class.return_value = mock_client
mock_response = MagicMock()
mock_response.output = MagicMock()
mock_response.output.content = [MagicMock()]
mock_response.output.content[0].type = "output_text"
mock_response.output.content[0].text = "4"
# New o3-pro format: direct output_text field
mock_response.output_text = "4"
mock_response.model = "o3-pro-2025-06-10"
mock_response.id = "test-id"
mock_response.created_at = 1234567890

View File

@@ -2,64 +2,59 @@
"""Test cases for PII sanitizer."""
import unittest
from tests.pii_sanitizer import PIISanitizer, PIIPattern
from tests.pii_sanitizer import PIIPattern, PIISanitizer
class TestPIISanitizer(unittest.TestCase):
"""Test PII sanitization functionality."""
def setUp(self):
    """Create a new PIISanitizer for each test method."""
    self.sanitizer = PIISanitizer()
def test_api_key_sanitization(self):
    """Check that each API key / token format is replaced by its expected placeholder."""
    # Maps raw credential -> expected sanitized output.
    expected_by_input = {
        # OpenAI keys
        "sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12": "sk-proj-SANITIZED",
        "sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN": "sk-SANITIZED",
        # Anthropic keys
        "sk-ant-abcd1234567890ABCD1234567890abcd1234567890ABCD12": "sk-ant-SANITIZED",
        # Google keys
        "AIzaSyD-1234567890abcdefghijklmnopqrstuv": "AIza-SANITIZED",
        # GitHub tokens
        "ghp_1234567890abcdefghijklmnopqrstuvwxyz": "gh_SANITIZED",
        "ghs_1234567890abcdefghijklmnopqrstuvwxyz": "gh_SANITIZED",
    }

    for original, expected in expected_by_input.items():
        # subTest keeps going after a failure and reports the offending input.
        with self.subTest(original=original):
            self.assertEqual(self.sanitizer.sanitize_string(original), expected)
def test_personal_info_sanitization(self):
    """Check that emails, phone numbers, SSNs and card numbers map to their placeholders."""
    # Maps raw personal data -> expected sanitized output.
    expected_by_input = {
        # Email addresses
        "john.doe@example.com": "user@example.com",
        "test123@company.org": "user@example.com",
        # Phone numbers (all now use the same pattern)
        "(555) 123-4567": "(XXX) XXX-XXXX",
        "555-123-4567": "(XXX) XXX-XXXX",
        "+1-555-123-4567": "(XXX) XXX-XXXX",
        # SSN
        "123-45-6789": "XXX-XX-XXXX",
        # Credit card
        "1234 5678 9012 3456": "XXXX-XXXX-XXXX-XXXX",
        "1234-5678-9012-3456": "XXXX-XXXX-XXXX-XXXX",
    }

    for original, expected in expected_by_input.items():
        # One sub-test per input so a single failure names the exact value.
        with self.subTest(original=original):
            self.assertEqual(self.sanitizer.sanitize_string(original), expected)
def test_header_sanitization(self):
"""Test HTTP header sanitization."""
headers = {
@@ -67,84 +62,82 @@ class TestPIISanitizer(unittest.TestCase):
"API-Key": "sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN",
"Content-Type": "application/json",
"User-Agent": "MyApp/1.0",
"Cookie": "session=abc123; user=john.doe@example.com"
"Cookie": "session=abc123; user=john.doe@example.com",
}
sanitized = self.sanitizer.sanitize_headers(headers)
self.assertEqual(sanitized["Authorization"], "Bearer SANITIZED")
self.assertEqual(sanitized["API-Key"], "sk-SANITIZED")
self.assertEqual(sanitized["Content-Type"], "application/json")
self.assertEqual(sanitized["User-Agent"], "MyApp/1.0")
self.assertIn("user@example.com", sanitized["Cookie"])
def test_nested_structure_sanitization(self):
    """Test sanitization of nested data structures.

    Passes a dict containing nested dicts and a list through
    ``sanitize_value`` and checks every leaf string individually.
    """
    data = {
        "user": {
            "email": "john.doe@example.com",
            "api_key": "sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12",
        },
        "tokens": [
            "ghp_1234567890abcdefghijklmnopqrstuvwxyz",
            "Bearer sk-ant-abcd1234567890ABCD1234567890abcd1234567890ABCD12",
        ],
        "metadata": {"ip": "192.168.1.100", "phone": "(555) 123-4567"},
    }

    sanitized = self.sanitizer.sanitize_value(data)

    # Nested dict values
    self.assertEqual(sanitized["user"]["email"], "user@example.com")
    self.assertEqual(sanitized["user"]["api_key"], "sk-proj-SANITIZED")
    # List elements, including a key embedded after a "Bearer " prefix
    self.assertEqual(sanitized["tokens"][0], "gh_SANITIZED")
    self.assertEqual(sanitized["tokens"][1], "Bearer sk-ant-SANITIZED")
    # IP address and phone number placeholders
    self.assertEqual(sanitized["metadata"]["ip"], "0.0.0.0")
    self.assertEqual(sanitized["metadata"]["phone"], "(XXX) XXX-XXXX")
def test_url_sanitization(self):
    """Test URL parameter sanitization.

    Each case pairs a URL carrying a secret in its query string with the
    expected URL after ``sanitize_url``.
    """
    urls = [
        (
            "https://api.example.com/v1/users?api_key=sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN",
            "https://api.example.com/v1/users?api_key=SANITIZED",
        ),
        (
            # The non-secret "user" parameter is expected to be left as-is.
            "https://example.com/login?token=ghp_1234567890abcdefghijklmnopqrstuvwxyz&user=test",
            "https://example.com/login?token=SANITIZED&user=test",
        ),
    ]

    for original, expected in urls:
        with self.subTest(url=original):
            result = self.sanitizer.sanitize_url(original)
            self.assertEqual(result, expected)
def test_disable_sanitization(self):
    """Check that a string is returned unchanged once ``sanitize_enabled`` is False."""
    self.sanitizer.sanitize_enabled = False

    secret = "sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12"

    # Should return original when disabled
    self.assertEqual(self.sanitizer.sanitize_string(secret), secret)
def test_custom_pattern(self):
    """Test adding custom PII patterns.

    Registers an extra pattern on the sanitizer and checks that matching
    text is replaced while the rest of the sentence is left alone.
    """
    # Add custom pattern for internal employee IDs ("EMP" followed by six digits)
    custom_pattern = PIIPattern.create(
        name="employee_id", pattern=r"EMP\d{6}", replacement="EMP-REDACTED", description="Internal employee IDs"
    )

    self.sanitizer.add_pattern(custom_pattern)

    text = "Employee EMP123456 has access to the system"
    result = self.sanitizer.sanitize_string(text)

    self.assertEqual(result, "Employee EMP-REDACTED has access to the system")
# Allow running this test module directly; the runner is invoked exactly once.
if __name__ == "__main__":
    unittest.main()