fix: Resolve o3-pro response parsing and test execution issues

- Fix lint errors: trailing whitespace and deprecated typing imports
- Update test mock for o3-pro response format (output.content[] → output_text)
- Implement robust test isolation with monkeypatch fixture
- Clear provider registry cache to prevent test interference
- Ensure o3-pro tests pass in both individual and full suite execution

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Josh Vera
2025-07-12 20:24:34 -06:00
parent ae5e43b792
commit 3db49413ff
8 changed files with 328 additions and 320 deletions

View File

@@ -221,7 +221,7 @@ class OpenAICompatibleProvider(ModelProvider):
# Create httpx client with minimal config to avoid proxy conflicts
# Note: proxies parameter was removed in httpx 0.28.0
# Check for test transport injection
if hasattr(self, '_test_transport'):
if hasattr(self, "_test_transport"):
# Use custom transport for testing (HTTP recording/replay)
http_client = httpx.Client(
transport=self._test_transport,

View File

@@ -13,16 +13,16 @@ Key Features:
- JSON cassette format with data sanitization
"""
import json
import hashlib
import copy
import base64
import copy
import hashlib
import json
from pathlib import Path
from typing import Dict, Any, Optional
import httpx
from io import BytesIO
from .pii_sanitizer import PIISanitizer
from typing import Any, Optional
import httpx
from .pii_sanitizer import PIISanitizer
class RecordingTransport(httpx.HTTPTransport):
@@ -62,8 +62,9 @@ class RecordingTransport(httpx.HTTPTransport):
# Create a new response with the same metadata but buffered content
# If the original response was gzipped, we need to re-compress
response_content = content_bytes
if response.headers.get('content-encoding') == 'gzip':
if response.headers.get("content-encoding") == "gzip":
import gzip
print(f"🗜️ Re-compressing {len(content_bytes)} bytes with gzip...")
response_content = gzip.compress(content_bytes)
print(f"🗜️ Compressed to {len(response_content)} bytes")
@@ -85,6 +86,7 @@ class RecordingTransport(httpx.HTTPTransport):
except Exception as e:
print(f"⚠️ Content capture failed: {e}, falling back to stub")
import traceback
print(f"⚠️ Full exception traceback:\n{traceback.format_exc()}")
response_data = self._serialize_response(response)
self._record_interaction(request_data, response_data)
@@ -95,17 +97,14 @@ class RecordingTransport(httpx.HTTPTransport):
self._record_interaction(request_data, response_data)
return response
def _record_interaction(self, request_data: Dict[str, Any], response_data: Dict[str, Any]):
def _record_interaction(self, request_data: dict[str, Any], response_data: dict[str, Any]):
"""Helper method to record interaction and save cassette."""
interaction = {
"request": request_data,
"response": response_data
}
interaction = {"request": request_data, "response": response_data}
self.recorded_interactions.append(interaction)
self._save_cassette()
print(f"🎬 RecordingTransport: Saved cassette to {self.cassette_path}")
def _serialize_request(self, request: httpx.Request) -> Dict[str, Any]:
def _serialize_request(self, request: httpx.Request) -> dict[str, Any]:
"""Serialize httpx.Request to JSON-compatible format."""
# For requests, we can safely read the content since it's already been prepared
# httpx.Request.content is safe to access multiple times
@@ -114,7 +113,7 @@ class RecordingTransport(httpx.HTTPTransport):
# Convert bytes to string for JSON serialization
if isinstance(content, bytes):
try:
content_str = content.decode('utf-8')
content_str = content.decode("utf-8")
except UnicodeDecodeError:
# Handle binary content (shouldn't happen for o3-pro API)
content_str = content.hex()
@@ -126,7 +125,7 @@ class RecordingTransport(httpx.HTTPTransport):
"url": str(request.url),
"path": request.url.path,
"headers": dict(request.headers),
"content": self._sanitize_request_content(content_str)
"content": self._sanitize_request_content(content_str),
}
# Apply PII sanitization if enabled
@@ -135,17 +134,17 @@ class RecordingTransport(httpx.HTTPTransport):
return request_data
def _serialize_response(self, response: httpx.Response) -> Dict[str, Any]:
def _serialize_response(self, response: httpx.Response) -> dict[str, Any]:
"""Serialize httpx.Response to JSON-compatible format (legacy method without content)."""
# Legacy method for backward compatibility when content capture is disabled
return {
"status_code": response.status_code,
"headers": dict(response.headers),
"content": {"note": "Response content not recorded to avoid httpx.ResponseNotRead exception"},
"reason_phrase": response.reason_phrase
"reason_phrase": response.reason_phrase,
}
def _serialize_response_with_content(self, response: httpx.Response, content_bytes: bytes) -> Dict[str, Any]:
def _serialize_response_with_content(self, response: httpx.Response, content_bytes: bytes) -> dict[str, Any]:
"""Serialize httpx.Response with captured content."""
try:
# Debug: check what we got
@@ -156,24 +155,20 @@ class RecordingTransport(httpx.HTTPTransport):
if not isinstance(content_bytes, bytes):
print(f"⚠️ Content is not bytes, converting from {type(content_bytes)}")
if isinstance(content_bytes, str):
content_bytes = content_bytes.encode('utf-8')
content_bytes = content_bytes.encode("utf-8")
else:
content_bytes = str(content_bytes).encode('utf-8')
content_bytes = str(content_bytes).encode("utf-8")
# Encode content as base64 for JSON storage
print(f"🔍 Base64 encoding {len(content_bytes)} bytes...")
content_b64 = base64.b64encode(content_bytes).decode('utf-8')
content_b64 = base64.b64encode(content_bytes).decode("utf-8")
print(f"✅ Base64 encoded successfully, result length: {len(content_b64)}")
response_data = {
"status_code": response.status_code,
"headers": dict(response.headers),
"content": {
"data": content_b64,
"encoding": "base64",
"size": len(content_bytes)
},
"reason_phrase": response.reason_phrase
"content": {"data": content_b64, "encoding": "base64", "size": len(content_bytes)},
"reason_phrase": response.reason_phrase,
}
# Apply PII sanitization if enabled
@@ -184,13 +179,14 @@ class RecordingTransport(httpx.HTTPTransport):
except Exception as e:
print(f"🔍 Error in _serialize_response_with_content: {e}")
import traceback
print(f"🔍 Full traceback: {traceback.format_exc()}")
# Fall back to minimal info
return {
"status_code": response.status_code,
"headers": dict(response.headers),
"content": {"error": f"Failed to serialize content: {e}"},
"reason_phrase": response.reason_phrase
"reason_phrase": response.reason_phrase,
}
def _sanitize_request_content(self, content: str) -> Any:
@@ -240,13 +236,9 @@ class RecordingTransport(httpx.HTTPTransport):
self.cassette_path.parent.mkdir(parents=True, exist_ok=True)
# Save cassette
cassette_data = {
"interactions": self.recorded_interactions
}
cassette_data = {"interactions": self.recorded_interactions}
self.cassette_path.write_text(
json.dumps(cassette_data, indent=2, sort_keys=True)
)
self.cassette_path.write_text(json.dumps(cassette_data, indent=2, sort_keys=True))
class ReplayTransport(httpx.MockTransport):
@@ -278,10 +270,10 @@ class ReplayTransport(httpx.MockTransport):
# Debug: show actual request content
content = request.content
if hasattr(content, 'read'):
if hasattr(content, "read"):
content = content.read()
if isinstance(content, bytes):
content_str = content.decode('utf-8', errors='ignore')
content_str = content.decode("utf-8", errors="ignore")
else:
content_str = str(content) if content else ""
print(f"🔍 Actual request content: {content_str}")
@@ -297,11 +289,9 @@ class ReplayTransport(httpx.MockTransport):
interaction = self._find_matching_interaction(request)
if not interaction:
print("🚨 MYSTERY SOLVED: No matching interaction found! This should fail...")
raise ValueError(
f"No matching interaction found for {request.method} {request.url}"
)
raise ValueError(f"No matching interaction found for {request.method} {request.url}")
print(f"✅ Found matching interaction from cassette!")
print("✅ Found matching interaction from cassette!")
# Build response from saved data
response_data = interaction["response"]
@@ -317,18 +307,19 @@ class ReplayTransport(httpx.MockTransport):
print(f"🎬 ReplayTransport: Decoded {len(content_bytes)} bytes from base64")
except Exception as e:
print(f"⚠️ Failed to decode base64 content: {e}")
content_bytes = json.dumps(content).encode('utf-8')
content_bytes = json.dumps(content).encode("utf-8")
else:
# Legacy format or stub content
content_bytes = json.dumps(content).encode('utf-8')
content_bytes = json.dumps(content).encode("utf-8")
else:
content_bytes = str(content).encode('utf-8')
content_bytes = str(content).encode("utf-8")
# Check if response expects gzipped content
headers = response_data.get("headers", {})
if headers.get('content-encoding') == 'gzip':
if headers.get("content-encoding") == "gzip":
# Re-compress the content for httpx
import gzip
print(f"🗜️ ReplayTransport: Re-compressing {len(content_bytes)} bytes with gzip...")
content_bytes = gzip.compress(content_bytes)
print(f"🗜️ ReplayTransport: Compressed to {len(content_bytes)} bytes")
@@ -340,10 +331,10 @@ class ReplayTransport(httpx.MockTransport):
status_code=response_data["status_code"],
headers=response_data.get("headers", {}),
content=content_bytes,
request=request
request=request,
)
def _find_matching_interaction(self, request: httpx.Request) -> Optional[Dict[str, Any]]:
def _find_matching_interaction(self, request: httpx.Request) -> Optional[dict[str, Any]]:
"""Find interaction that matches the request."""
request_signature = self._get_request_signature(request)
@@ -358,11 +349,11 @@ class ReplayTransport(httpx.MockTransport):
"""Generate signature for request matching."""
# Use method, path, and content hash for matching
content = request.content
if hasattr(content, 'read'):
if hasattr(content, "read"):
content = content.read()
if isinstance(content, bytes):
content_str = content.decode('utf-8', errors='ignore')
content_str = content.decode("utf-8", errors="ignore")
else:
content_str = str(content) if content else ""
@@ -380,7 +371,7 @@ class ReplayTransport(httpx.MockTransport):
return f"{request.method}:{request.url.path}:{content_hash}"
def _get_saved_request_signature(self, saved_request: Dict[str, Any]) -> str:
def _get_saved_request_signature(self, saved_request: dict[str, Any]) -> str:
"""Generate signature for saved request."""
method = saved_request["method"]
path = saved_request["path"]

View File

@@ -7,11 +7,12 @@ request/response recordings to prevent accidental exposure of API keys,
tokens, personal information, and other sensitive data.
"""
import re
from typing import Any, Dict, List, Optional, Pattern
from dataclasses import dataclass
from copy import deepcopy
import logging
import re
from copy import deepcopy
from dataclasses import dataclass
from re import Pattern
from typing import Any, Optional
logger = logging.getLogger(__name__)
@@ -19,28 +20,24 @@ logger = logging.getLogger(__name__)
@dataclass
class PIIPattern:
"""Defines a pattern for detecting and sanitizing PII."""
name: str
pattern: Pattern[str]
replacement: str
description: str
@classmethod
def create(cls, name: str, pattern: str, replacement: str, description: str) -> 'PIIPattern':
def create(cls, name: str, pattern: str, replacement: str, description: str) -> "PIIPattern":
"""Create a PIIPattern with compiled regex."""
return cls(
name=name,
pattern=re.compile(pattern),
replacement=replacement,
description=description
)
return cls(name=name, pattern=re.compile(pattern), replacement=replacement, description=description)
class PIISanitizer:
"""Sanitizes PII from various data structures while preserving format."""
def __init__(self, patterns: Optional[List[PIIPattern]] = None):
def __init__(self, patterns: Optional[list[PIIPattern]] = None):
"""Initialize with optional custom patterns."""
self.patterns: List[PIIPattern] = patterns or []
self.patterns: list[PIIPattern] = patterns or []
self.sanitize_enabled = True
# Add default patterns if none provided
@@ -53,95 +50,91 @@ class PIISanitizer:
# API Keys - Core patterns (Bearer tokens handled in sanitize_headers)
PIIPattern.create(
name="openai_api_key_proj",
pattern=r'sk-proj-[A-Za-z0-9\-_]{48,}',
pattern=r"sk-proj-[A-Za-z0-9\-_]{48,}",
replacement="sk-proj-SANITIZED",
description="OpenAI project API keys"
description="OpenAI project API keys",
),
PIIPattern.create(
name="openai_api_key",
pattern=r'sk-[A-Za-z0-9]{48,}',
pattern=r"sk-[A-Za-z0-9]{48,}",
replacement="sk-SANITIZED",
description="OpenAI API keys"
description="OpenAI API keys",
),
PIIPattern.create(
name="anthropic_api_key",
pattern=r'sk-ant-[A-Za-z0-9\-_]{48,}',
pattern=r"sk-ant-[A-Za-z0-9\-_]{48,}",
replacement="sk-ant-SANITIZED",
description="Anthropic API keys"
description="Anthropic API keys",
),
PIIPattern.create(
name="google_api_key",
pattern=r'AIza[A-Za-z0-9\-_]{35,}',
pattern=r"AIza[A-Za-z0-9\-_]{35,}",
replacement="AIza-SANITIZED",
description="Google API keys"
description="Google API keys",
),
PIIPattern.create(
name="github_tokens",
pattern=r'gh[psr]_[A-Za-z0-9]{36}',
pattern=r"gh[psr]_[A-Za-z0-9]{36}",
replacement="gh_SANITIZED",
description="GitHub tokens (all types)"
description="GitHub tokens (all types)",
),
# JWT tokens
PIIPattern.create(
name="jwt_token",
pattern=r'eyJ[A-Za-z0-9\-_]+\.eyJ[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+',
pattern=r"eyJ[A-Za-z0-9\-_]+\.eyJ[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+",
replacement="eyJ-SANITIZED",
description="JSON Web Tokens"
description="JSON Web Tokens",
),
# Personal Information
PIIPattern.create(
name="email_address",
pattern=r'[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}',
pattern=r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}",
replacement="user@example.com",
description="Email addresses"
description="Email addresses",
),
PIIPattern.create(
name="ipv4_address",
pattern=r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b',
pattern=r"\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b",
replacement="0.0.0.0",
description="IPv4 addresses"
description="IPv4 addresses",
),
PIIPattern.create(
name="ssn",
pattern=r'\b\d{3}-\d{2}-\d{4}\b',
pattern=r"\b\d{3}-\d{2}-\d{4}\b",
replacement="XXX-XX-XXXX",
description="Social Security Numbers"
description="Social Security Numbers",
),
PIIPattern.create(
name="credit_card",
pattern=r'\b\d{4}[\s\-]?\d{4}[\s\-]?\d{4}[\s\-]?\d{4}\b',
pattern=r"\b\d{4}[\s\-]?\d{4}[\s\-]?\d{4}[\s\-]?\d{4}\b",
replacement="XXXX-XXXX-XXXX-XXXX",
description="Credit card numbers"
description="Credit card numbers",
),
PIIPattern.create(
name="phone_number",
pattern=r'(?:\+\d{1,3}[\s\-]?)?\(?\d{3}\)?[\s\-]?\d{3}[\s\-]?\d{4}',
pattern=r"(?:\+\d{1,3}[\s\-]?)?\(?\d{3}\)?[\s\-]?\d{3}[\s\-]?\d{4}",
replacement="(XXX) XXX-XXXX",
description="Phone numbers (all formats)"
description="Phone numbers (all formats)",
),
# AWS
PIIPattern.create(
name="aws_access_key",
pattern=r'AKIA[0-9A-Z]{16}',
pattern=r"AKIA[0-9A-Z]{16}",
replacement="AKIA-SANITIZED",
description="AWS access keys"
description="AWS access keys",
),
# Other common patterns
PIIPattern.create(
name="slack_token",
pattern=r'xox[baprs]-[0-9]{10,13}-[0-9]{10,13}-[a-zA-Z0-9]{24,34}',
pattern=r"xox[baprs]-[0-9]{10,13}-[0-9]{10,13}-[a-zA-Z0-9]{24,34}",
replacement="xox-SANITIZED",
description="Slack tokens"
description="Slack tokens",
),
PIIPattern.create(
name="stripe_key",
pattern=r'(?:sk|pk)_(?:test|live)_[0-9a-zA-Z]{24,99}',
pattern=r"(?:sk|pk)_(?:test|live)_[0-9a-zA-Z]{24,99}",
replacement="sk_SANITIZED",
description="Stripe API keys"
description="Stripe API keys",
),
]
@@ -165,7 +158,7 @@ class PIISanitizer:
return sanitized
def sanitize_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
def sanitize_headers(self, headers: dict[str, str]) -> dict[str, str]:
"""Special handling for HTTP headers."""
if not self.sanitize_enabled:
return headers
@@ -174,10 +167,10 @@ class PIISanitizer:
for key, value in headers.items():
# Special case for Authorization headers to preserve auth type
if key.lower() == 'authorization' and ' ' in value:
auth_type = value.split(' ', 1)[0]
if auth_type in ('Bearer', 'Basic'):
sanitized_headers[key] = f'{auth_type} SANITIZED'
if key.lower() == "authorization" and " " in value:
auth_type = value.split(" ", 1)[0]
if auth_type in ("Bearer", "Basic"):
sanitized_headers[key] = f"{auth_type} SANITIZED"
else:
sanitized_headers[key] = self.sanitize_string(value)
else:
@@ -212,15 +205,15 @@ class PIISanitizer:
url = self.sanitize_string(url)
# Parse and sanitize query parameters
if '?' in url:
base, query = url.split('?', 1)
if "?" in url:
base, query = url.split("?", 1)
params = []
for param in query.split('&'):
if '=' in param:
key, value = param.split('=', 1)
for param in query.split("&"):
if "=" in param:
key, value = param.split("=", 1)
# Sanitize common sensitive parameter names
sensitive_params = {'key', 'token', 'api_key', 'secret', 'password'}
sensitive_params = {"key", "token", "api_key", "secret", "password"}
if key.lower() in sensitive_params:
params.append(f"{key}=SANITIZED")
else:
@@ -233,46 +226,45 @@ class PIISanitizer:
return url
def sanitize_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
def sanitize_request(self, request_data: dict[str, Any]) -> dict[str, Any]:
"""Sanitize a complete request dictionary."""
sanitized = deepcopy(request_data)
# Sanitize headers
if 'headers' in sanitized:
sanitized['headers'] = self.sanitize_headers(sanitized['headers'])
if "headers" in sanitized:
sanitized["headers"] = self.sanitize_headers(sanitized["headers"])
# Sanitize URL
if 'url' in sanitized:
sanitized['url'] = self.sanitize_url(sanitized['url'])
if "url" in sanitized:
sanitized["url"] = self.sanitize_url(sanitized["url"])
# Sanitize content
if 'content' in sanitized:
sanitized['content'] = self.sanitize_value(sanitized['content'])
if "content" in sanitized:
sanitized["content"] = self.sanitize_value(sanitized["content"])
return sanitized
def sanitize_response(self, response_data: Dict[str, Any]) -> Dict[str, Any]:
def sanitize_response(self, response_data: dict[str, Any]) -> dict[str, Any]:
"""Sanitize a complete response dictionary."""
sanitized = deepcopy(response_data)
# Sanitize headers
if 'headers' in sanitized:
sanitized['headers'] = self.sanitize_headers(sanitized['headers'])
if "headers" in sanitized:
sanitized["headers"] = self.sanitize_headers(sanitized["headers"])
# Sanitize content
if 'content' in sanitized:
if "content" in sanitized:
# Handle base64 encoded content specially
if isinstance(sanitized['content'], dict) and sanitized['content'].get('encoding') == 'base64':
if isinstance(sanitized["content"], dict) and sanitized["content"].get("encoding") == "base64":
# Don't decode/re-encode the actual response body
# but sanitize any metadata
if 'data' in sanitized['content']:
if "data" in sanitized["content"]:
# Keep the data as-is but sanitize other fields
for key, value in sanitized['content'].items():
if key != 'data':
sanitized['content'][key] = self.sanitize_value(value)
for key, value in sanitized["content"].items():
if key != "data":
sanitized["content"][key] = self.sanitize_value(value)
else:
sanitized['content'] = self.sanitize_value(sanitized['content'])
sanitized["content"] = self.sanitize_value(sanitized["content"])
return sanitized

View File

@@ -10,10 +10,10 @@ This script will:
"""
import json
import sys
from pathlib import Path
import shutil
import sys
from datetime import datetime
from pathlib import Path
# Add tests directory to path to import our modules
sys.path.insert(0, str(Path(__file__).parent))
@@ -31,7 +31,7 @@ def sanitize_cassette(cassette_path: Path, backup: bool = True) -> bool:
try:
# Load cassette
with open(cassette_path, 'r') as f:
with open(cassette_path) as f:
cassette_data = json.load(f)
# Create backup if requested
@@ -44,26 +44,26 @@ def sanitize_cassette(cassette_path: Path, backup: bool = True) -> bool:
sanitizer = PIISanitizer()
# Sanitize interactions
if 'interactions' in cassette_data:
if "interactions" in cassette_data:
sanitized_interactions = []
for interaction in cassette_data['interactions']:
for interaction in cassette_data["interactions"]:
sanitized_interaction = {}
# Sanitize request
if 'request' in interaction:
sanitized_interaction['request'] = sanitizer.sanitize_request(interaction['request'])
if "request" in interaction:
sanitized_interaction["request"] = sanitizer.sanitize_request(interaction["request"])
# Sanitize response
if 'response' in interaction:
sanitized_interaction['response'] = sanitizer.sanitize_response(interaction['response'])
if "response" in interaction:
sanitized_interaction["response"] = sanitizer.sanitize_response(interaction["response"])
sanitized_interactions.append(sanitized_interaction)
cassette_data['interactions'] = sanitized_interactions
cassette_data["interactions"] = sanitized_interactions
# Save sanitized cassette
with open(cassette_path, 'w') as f:
with open(cassette_path, "w") as f:
json.dump(cassette_data, f, indent=2, sort_keys=True)
print(f"✅ Sanitized: {cassette_path}")
@@ -72,6 +72,7 @@ def sanitize_cassette(cassette_path: Path, backup: bool = True) -> bool:
except Exception as e:
print(f"❌ Error processing {cassette_path}: {e}")
import traceback
traceback.print_exc()
return False

View File

@@ -18,11 +18,11 @@ from pathlib import Path
import pytest
from dotenv import load_dotenv
from tools.chat import ChatTool
from providers import ModelProviderRegistry
from providers.base import ProviderType
from providers.openai_provider import OpenAIModelProvider
from tests.http_transport_recorder import TransportFactory
from tools.chat import ChatTool
# Load environment variables from .env file
load_dotenv()
@@ -32,15 +32,46 @@ cassette_dir = Path(__file__).parent / "openai_cassettes"
cassette_dir.mkdir(exist_ok=True)
@pytest.fixture
def allow_all_models(monkeypatch):
    """Lift model restrictions and start from a clean provider cache for one test.

    Sets the module-level ``utils.model_restrictions._restriction_service``
    to ``None``, sets ``ALLOWED_MODELS`` to an empty string and
    ``OPENAI_API_KEY`` to a placeholder value, and clears the
    ``ModelProviderRegistry`` cache both before and after the test body runs.

    Every ``monkeypatch`` change is reverted by pytest when the fixture is
    torn down, so the previous singleton and environment values do not need
    to be captured by hand here.
    """
    # Imported locally, as in the original fixture, rather than relying on a
    # module-level name.
    from providers.registry import ModelProviderRegistry

    # Reset the singleton so it will re-read env vars inside this fixture
    monkeypatch.setattr("utils.model_restrictions._restriction_service", None)
    monkeypatch.setenv("ALLOWED_MODELS", "")  # empty string = no restrictions
    monkeypatch.setenv("OPENAI_API_KEY", "dummy-key-for-replay")  # transport layer expects a key

    # Also clear the provider registry cache to ensure clean state
    ModelProviderRegistry.clear_cache()

    yield

    # Clean up: drop whatever value the test left in the singleton slot so
    # other tests don't see the unrestricted version.
    # NOTE(review): monkeypatch's own undo runs after this and should put the
    # pre-test value back, so this line is likely redundant - confirm before
    # removing.
    monkeypatch.setattr("utils.model_restrictions._restriction_service", None)

    # Clear registry cache again for other tests; this is not managed by
    # monkeypatch, so it has to be done explicitly.
    ModelProviderRegistry.clear_cache()
@pytest.mark.no_mock_provider # Disable provider mocking for this test
class TestO3ProOutputTextFix(unittest.IsolatedAsyncioTestCase):
"""Test o3-pro response parsing fix using a custom httpx transport for HTTP recording/replay."""
def setUp(self):
"""Set up the test by ensuring OpenAI provider is registered."""
# Clear any cached providers to ensure clean state
ModelProviderRegistry.clear_cache()
# Manually register the OpenAI provider to ensure it's available
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
@pytest.mark.usefixtures("allow_all_models")
async def test_o3_pro_uses_output_text_field(self):
"""Test that o3-pro parsing uses the output_text convenience field via ChatTool."""
cassette_path = cassette_dir / "o3_pro_basic_math.json"
@@ -58,7 +89,7 @@ class TestO3ProOutputTextFix(unittest.IsolatedAsyncioTestCase):
self.fail("OpenAI provider not available for o3-pro model")
# Inject transport for this test
original_transport = getattr(provider, '_test_transport', None)
original_transport = getattr(provider, "_test_transport", None)
provider._test_transport = transport
try:
@@ -72,14 +103,16 @@ class TestO3ProOutputTextFix(unittest.IsolatedAsyncioTestCase):
if not cassette_path.exists():
self.fail(f"Cassette should exist at {cassette_path}")
print(f"✅ HTTP transport {'recorded' if isinstance(transport, type(transport).__bases__[0]) else 'replayed'} o3-pro interaction")
print(
f"✅ HTTP transport {'recorded' if isinstance(transport, type(transport).__bases__[0]) else 'replayed'} o3-pro interaction"
)
finally:
# Restore original transport (if any)
if original_transport:
provider._test_transport = original_transport
elif hasattr(provider, '_test_transport'):
delattr(provider, '_test_transport')
elif hasattr(provider, "_test_transport"):
delattr(provider, "_test_transport")
async def _execute_chat_tool_test(self):
"""Execute the ChatTool with o3-pro and return the result."""

View File

@@ -230,10 +230,8 @@ class TestOpenAIProvider:
mock_openai_class.return_value = mock_client
mock_response = MagicMock()
mock_response.output = MagicMock()
mock_response.output.content = [MagicMock()]
mock_response.output.content[0].type = "output_text"
mock_response.output.content[0].text = "4"
# New o3-pro format: direct output_text field
mock_response.output_text = "4"
mock_response.model = "o3-pro-2025-06-10"
mock_response.id = "test-id"
mock_response.created_at = 1234567890

View File

@@ -2,7 +2,8 @@
"""Test cases for PII sanitizer."""
import unittest
from tests.pii_sanitizer import PIISanitizer, PIIPattern
from tests.pii_sanitizer import PIIPattern, PIISanitizer
class TestPIISanitizer(unittest.TestCase):
@@ -18,13 +19,10 @@ class TestPIISanitizer(unittest.TestCase):
# OpenAI keys
("sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12", "sk-proj-SANITIZED"),
("sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN", "sk-SANITIZED"),
# Anthropic keys
("sk-ant-abcd1234567890ABCD1234567890abcd1234567890ABCD12", "sk-ant-SANITIZED"),
# Google keys
("AIzaSyD-1234567890abcdefghijklmnopqrstuv", "AIza-SANITIZED"),
# GitHub tokens
("ghp_1234567890abcdefghijklmnopqrstuvwxyz", "gh_SANITIZED"),
("ghs_1234567890abcdefghijklmnopqrstuvwxyz", "gh_SANITIZED"),
@@ -41,15 +39,12 @@ class TestPIISanitizer(unittest.TestCase):
# Email addresses
("john.doe@example.com", "user@example.com"),
("test123@company.org", "user@example.com"),
# Phone numbers (all now use the same pattern)
("(555) 123-4567", "(XXX) XXX-XXXX"),
("555-123-4567", "(XXX) XXX-XXXX"),
("+1-555-123-4567", "(XXX) XXX-XXXX"),
# SSN
("123-45-6789", "XXX-XX-XXXX"),
# Credit card
("1234 5678 9012 3456", "XXXX-XXXX-XXXX-XXXX"),
("1234-5678-9012-3456", "XXXX-XXXX-XXXX-XXXX"),
@@ -67,7 +62,7 @@ class TestPIISanitizer(unittest.TestCase):
"API-Key": "sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN",
"Content-Type": "application/json",
"User-Agent": "MyApp/1.0",
"Cookie": "session=abc123; user=john.doe@example.com"
"Cookie": "session=abc123; user=john.doe@example.com",
}
sanitized = self.sanitizer.sanitize_headers(headers)
@@ -83,16 +78,13 @@ class TestPIISanitizer(unittest.TestCase):
data = {
"user": {
"email": "john.doe@example.com",
"api_key": "sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12"
"api_key": "sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12",
},
"tokens": [
"ghp_1234567890abcdefghijklmnopqrstuvwxyz",
"Bearer sk-ant-abcd1234567890ABCD1234567890abcd1234567890ABCD12"
"Bearer sk-ant-abcd1234567890ABCD1234567890abcd1234567890ABCD12",
],
"metadata": {
"ip": "192.168.1.100",
"phone": "(555) 123-4567"
}
"metadata": {"ip": "192.168.1.100", "phone": "(555) 123-4567"},
}
sanitized = self.sanitizer.sanitize_value(data)
@@ -107,10 +99,14 @@ class TestPIISanitizer(unittest.TestCase):
def test_url_sanitization(self):
"""Test URL parameter sanitization."""
urls = [
("https://api.example.com/v1/users?api_key=sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN",
"https://api.example.com/v1/users?api_key=SANITIZED"),
("https://example.com/login?token=ghp_1234567890abcdefghijklmnopqrstuvwxyz&user=test",
"https://example.com/login?token=SANITIZED&user=test"),
(
"https://api.example.com/v1/users?api_key=sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN",
"https://api.example.com/v1/users?api_key=SANITIZED",
),
(
"https://example.com/login?token=ghp_1234567890abcdefghijklmnopqrstuvwxyz&user=test",
"https://example.com/login?token=SANITIZED&user=test",
),
]
for original, expected in urls:
@@ -132,10 +128,7 @@ class TestPIISanitizer(unittest.TestCase):
"""Test adding custom PII patterns."""
# Add custom pattern for internal employee IDs
custom_pattern = PIIPattern.create(
name="employee_id",
pattern=r'EMP\d{6}',
replacement="EMP-REDACTED",
description="Internal employee IDs"
name="employee_id", pattern=r"EMP\d{6}", replacement="EMP-REDACTED", description="Internal employee IDs"
)
self.sanitizer.add_pattern(custom_pattern)