Resolve merge conflicts in o3-pro response parsing fix

- Use new output_text field format for o3-pro responses
- Update test expectations to use resolved model name o3-pro-2025-06-10
- Keep HTTP transport recorder and PII sanitization improvements
- Preserve both bug fix and recent GPT-5 updates

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
@@ -115,6 +115,14 @@ Test isolated components and functions:
 - **File handling**: Path validation, token limits, deduplication
 - **Auto mode**: Model selection logic and fallback behavior
 
+### HTTP Recording/Replay Tests (HTTP Transport Recorder)
+
+Tests for expensive API calls (like o3-pro) use custom recording/replay:
+
+- **Real API validation**: Tests against actual provider responses
+- **Cost efficiency**: Record once, replay forever
+- **Provider compatibility**: Validates fixes against real APIs
+- Uses HTTP Transport Recorder for httpx-based API calls
+- See [HTTP Recording/Replay Testing Guide](./vcr-testing.md) for details
+
 ### Simulator Tests
 
 Validate real-world usage scenarios by simulating actual Claude prompts:
 
 - **Basic conversations**: Multi-turn chat functionality with real prompts
docs/vcr-testing.md · 128 lines · new file
@@ -0,0 +1,128 @@
# HTTP Transport Recorder for Testing

A custom HTTP recorder for testing expensive API calls (like o3-pro) with real responses.

## Overview

The HTTP Transport Recorder captures and replays HTTP interactions at the transport layer, enabling:

- Cost-efficient testing of expensive APIs (record once, replay forever)
- Deterministic tests with real API responses
- Seamless integration with httpx and the OpenAI SDK
- Automatic PII sanitization for secure recordings

## Quick Start

```python
from tests.transport_helpers import inject_transport

# Simple one-line setup with automatic transport injection
async def test_expensive_api_call(monkeypatch):
    inject_transport(monkeypatch, "tests/openai_cassettes/my_test.json")

    # Make API calls - automatically recorded/replayed with PII sanitization
    result = await chat_tool.execute({"prompt": "2+2?", "model": "o3-pro"})
```

## How It Works

1. **First run** (cassette doesn't exist): Records real API calls
2. **Subsequent runs** (cassette exists): Replays saved responses
3. **Re-record**: Delete the cassette file and run again

## Usage in Tests

The `transport_helpers.inject_transport()` function simplifies test setup:

```python
from tests.transport_helpers import inject_transport

async def test_with_recording(monkeypatch):
    # One-line setup - handles all transport injection complexity
    inject_transport(monkeypatch, "tests/openai_cassettes/my_test.json")

    # Use API normally - recording/replay happens transparently
    result = await chat_tool.execute({"prompt": "2+2?", "model": "o3-pro"})
```

For manual setup, see `test_o3_pro_output_text_fix.py`.

## Automatic PII Sanitization

All recordings are automatically sanitized to remove sensitive data:

- **API Keys & Tokens**: Bearer tokens, API keys, and auth headers
- **Personal Data**: Email addresses, IP addresses, phone numbers
- **URLs**: Sensitive query parameters and paths
- **Custom Patterns**: Add your own sanitization rules

Sanitization is enabled by default in `RecordingTransport`. To disable:

```python
transport = TransportFactory.create_transport(cassette_path, sanitize=False)
```

## File Structure

```
tests/
├── openai_cassettes/                 # Recorded API interactions
│   └── *.json                        # Cassette files
├── http_transport_recorder.py        # Transport implementation
├── pii_sanitizer.py                  # Automatic PII sanitization
├── transport_helpers.py              # Simplified transport injection
├── sanitize_cassettes.py             # Batch sanitization script
└── test_o3_pro_output_text_fix.py    # Example usage
```

## Sanitizing Existing Cassettes

Use the `sanitize_cassettes.py` script to clean existing recordings:

```bash
# Sanitize all cassettes (creates backups)
python tests/sanitize_cassettes.py

# Sanitize specific cassette
python tests/sanitize_cassettes.py tests/openai_cassettes/my_test.json

# Skip backup creation
python tests/sanitize_cassettes.py --no-backup
```

The script will:

- Create timestamped backups of original files
- Apply comprehensive PII sanitization
- Preserve JSON structure and functionality

## Cost Management

- **One-time cost**: Initial recording only
- **Zero ongoing cost**: Replays are free
- **CI-friendly**: No API keys needed for replay

## Re-recording

When API changes require new recordings:

```bash
# Delete specific cassette
rm tests/openai_cassettes/my_test.json

# Run test with real API key
python -m pytest tests/test_o3_pro_output_text_fix.py
```

## Implementation Details

- **RecordingTransport**: Captures real HTTP calls with automatic PII sanitization
- **ReplayTransport**: Serves saved responses from cassettes
- **TransportFactory**: Auto-selects mode based on cassette existence
- **PIISanitizer**: Comprehensive sanitization of sensitive data (integrated by default)

**Security Note**: While recordings are automatically sanitized, always review new cassette files before committing. The sanitizer removes known patterns of sensitive data, but domain-specific secrets may need custom rules.

For implementation details, see:

- `tests/http_transport_recorder.py` - Core transport implementation
- `tests/pii_sanitizer.py` - Sanitization patterns and logic
- `tests/transport_helpers.py` - Simplified test integration
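The record-or-replay decision described in "How It Works" can be sketched as a small standalone function. This is a simplified model, not the actual `TransportFactory` code; the `"skip"` outcome is an assumption about what a test would do when neither a cassette nor an API key is available.

```python
from pathlib import Path


def choose_mode(cassette_path: str, api_key: "str | None") -> str:
    """Pick 'replay' when a cassette exists, 'record' when one is absent
    and an API key is available, otherwise 'skip'."""
    if Path(cassette_path).exists():
        # Cassette on disk: serve saved responses, no network or key needed
        return "replay"
    # No cassette: only record if we can actually call the real API
    return "record" if api_key else "skip"
```

Deleting the cassette file flips the same test from `"replay"` back to `"record"` on its next run, which is how re-recording works.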
@@ -1,6 +1,7 @@
 """Base class for OpenAI-compatible API providers."""
 
 import base64
+import copy
 import ipaddress
 import logging
 import os
@@ -220,6 +221,16 @@ class OpenAICompatibleProvider(ModelProvider):
 
         # Create httpx client with minimal config to avoid proxy conflicts
         # Note: proxies parameter was removed in httpx 0.28.0
-        http_client = httpx.Client(
-            timeout=timeout_config,
-            follow_redirects=True,
+        # Check for test transport injection
+        if hasattr(self, "_test_transport"):
+            # Use custom transport for testing (HTTP recording/replay)
+            http_client = httpx.Client(
+                transport=self._test_transport,
+                timeout=timeout_config,
+                follow_redirects=True,
+            )
+        else:
+            # Normal production client
+            http_client = httpx.Client(
+                timeout=timeout_config,
+                follow_redirects=True,
@@ -264,6 +275,63 @@ class OpenAICompatibleProvider(ModelProvider):
 
         return self._client
 
+    def _sanitize_for_logging(self, params: dict) -> dict:
+        """Sanitize sensitive data from parameters before logging.
+
+        Args:
+            params: Dictionary of API parameters
+
+        Returns:
+            dict: Sanitized copy of parameters safe for logging
+        """
+        sanitized = copy.deepcopy(params)
+
+        # Sanitize messages content
+        if "input" in sanitized:
+            for msg in sanitized.get("input", []):
+                if isinstance(msg, dict) and "content" in msg:
+                    for content_item in msg.get("content", []):
+                        if isinstance(content_item, dict) and "text" in content_item:
+                            # Truncate long text and add ellipsis
+                            text = content_item["text"]
+                            if len(text) > 100:
+                                content_item["text"] = text[:100] + "... [truncated]"
+
+        # Remove any API keys that might be in headers/auth
+        sanitized.pop("api_key", None)
+        sanitized.pop("authorization", None)
+
+        return sanitized
+
+    def _safe_extract_output_text(self, response) -> str:
+        """Safely extract output_text from o3-pro response with validation.
+
+        Args:
+            response: Response object from OpenAI SDK
+
+        Returns:
+            str: The output text content
+
+        Raises:
+            ValueError: If output_text is missing, None, or not a string
+        """
+        logging.debug(f"Response object type: {type(response)}")
+        logging.debug(f"Response attributes: {dir(response)}")
+
+        if not hasattr(response, "output_text"):
+            raise ValueError(f"o3-pro response missing output_text field. Response type: {type(response).__name__}")
+
+        content = response.output_text
+        logging.debug(f"Extracted output_text: '{content}' (type: {type(content)})")
+
+        if content is None:
+            raise ValueError("o3-pro returned None for output_text")
+
+        if not isinstance(content, str):
+            raise ValueError(f"o3-pro output_text is not a string. Got type: {type(content).__name__}")
+
+        return content
+
     def _generate_with_responses_endpoint(
         self,
         model_name: str,
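The validation helper added above can be exercised in isolation. The sketch below mirrors its checks as a free function against a minimal fake response object; the names `safe_extract_output_text` and `FakeResponse` are illustrative, not part of the real codebase.

```python
def safe_extract_output_text(response) -> str:
    # Mirrors _safe_extract_output_text: the field must exist, be non-None,
    # and be a string; anything else raises ValueError instead of silently
    # producing an empty result.
    if not hasattr(response, "output_text"):
        raise ValueError(
            f"o3-pro response missing output_text field. Response type: {type(response).__name__}"
        )
    content = response.output_text
    if content is None:
        raise ValueError("o3-pro returned None for output_text")
    if not isinstance(content, str):
        raise ValueError(f"o3-pro output_text is not a string. Got type: {type(content).__name__}")
    return content


class FakeResponse:
    # Minimal stand-in for the OpenAI SDK response object.
    def __init__(self, output_text):
        self.output_text = output_text
```

Raising early like this is what prevents the old failure mode, where a missing field quietly fell through to an empty `content` string.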
@@ -312,29 +380,20 @@ class OpenAICompatibleProvider(ModelProvider):
         actual_attempts = 0
 
         for attempt in range(max_retries):
             actual_attempts = attempt + 1  # Convert from 0-based index to human-readable count
-            try:  # Log the exact payload being sent for debugging
+            try:  # Log sanitized payload for debugging
                 import json
 
+                sanitized_params = self._sanitize_for_logging(completion_params)
                 logging.info(
-                    f"o3-pro API request payload: {json.dumps(completion_params, indent=2, ensure_ascii=False)}"
+                    f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
                 )
 
                 # Use OpenAI client's responses endpoint
                 response = self.client.responses.create(**completion_params)
 
-                # Extract content and usage from responses endpoint format
-                # The response format is different for responses endpoint
-                content = ""
-                if hasattr(response, "output") and response.output:
-                    if hasattr(response.output, "content") and response.output.content:
-                        # Look for output_text in content
-                        for content_item in response.output.content:
-                            if hasattr(content_item, "type") and content_item.type == "output_text":
-                                content = content_item.text
-                                break
-                    elif hasattr(response.output, "text"):
-                        content = response.output.text
+                # Extract content from responses endpoint format
+                # Use validation helper to safely extract output_text
+                content = self._safe_extract_output_text(response)
 
                 # Try to extract usage information
                 usage = None
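The sanitized-logging path above can be illustrated with a standalone version of the truncation and redaction logic. This is a sketch of `_sanitize_for_logging` as a free function, not the provider method itself.

```python
import copy


def sanitize_for_logging(params: dict) -> dict:
    # Deep-copy so the real request payload is never mutated, truncate long
    # message text, and drop credential fields before the payload is logged.
    sanitized = copy.deepcopy(params)
    for msg in sanitized.get("input", []):
        if isinstance(msg, dict):
            for item in msg.get("content", []):
                if isinstance(item, dict) and isinstance(item.get("text"), str) and len(item["text"]) > 100:
                    item["text"] = item["text"][:100] + "... [truncated]"
    for key in ("api_key", "authorization"):
        sanitized.pop(key, None)
    return sanitized
```

The deep copy matters: the sanitized dict is only for the log line, while the untouched `completion_params` is what actually goes to the API.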
@@ -482,7 +541,7 @@ class OpenAICompatibleProvider(ModelProvider):
                 completion_params[key] = value
 
         # Check if this is o3-pro and needs the responses endpoint
-        if resolved_model == "o3-pro":
+        if resolved_model == "o3-pro-2025-06-10":
             # This model requires the /v1/responses endpoint
             # If it fails, we should not fall back to chat/completions
             return self._generate_with_responses_endpoint(
@@ -351,6 +351,17 @@ class ModelProviderRegistry:
         instance = cls()
         instance._initialized_providers.clear()
 
+    @classmethod
+    def reset_for_testing(cls) -> None:
+        """Reset the registry to a clean state for testing.
+
+        This provides a safe, public API for tests to clean up registry state
+        without directly manipulating private attributes.
+        """
+        cls._instance = None
+        if hasattr(cls, "_providers"):
+            cls._providers = {}
+
     @classmethod
     def unregister_provider(cls, provider_type: ProviderType) -> None:
         """Unregister a provider (mainly for testing)."""
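The reset pattern added here is the usual singleton-cleanup idiom for tests. A minimal stand-in (the real `ModelProviderRegistry` also tracks initialized provider instances):

```python
class Registry:
    # Class-level singleton state, as in the real registry
    _instance = None
    _providers: dict = {}

    @classmethod
    def reset_for_testing(cls) -> None:
        # Drop the cached singleton and registered providers so the next
        # test starts from a clean slate.
        cls._instance = None
        cls._providers = {}
```

Exposing this as a public classmethod means test fixtures never have to poke at `_instance` or `_providers` directly.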
@@ -15,13 +15,6 @@ parent_dir = Path(__file__).resolve().parent.parent
 if str(parent_dir) not in sys.path:
     sys.path.insert(0, str(parent_dir))
 
-# Set dummy API keys for tests if not already set or if empty
-if not os.environ.get("GEMINI_API_KEY"):
-    os.environ["GEMINI_API_KEY"] = "dummy-key-for-tests"
-if not os.environ.get("OPENAI_API_KEY"):
-    os.environ["OPENAI_API_KEY"] = "dummy-key-for-tests"
-if not os.environ.get("XAI_API_KEY"):
-    os.environ["XAI_API_KEY"] = "dummy-key-for-tests"
-
 # Set default model to a specific value for tests to avoid auto mode
 # This prevents all tests from failing due to missing model parameter
@@ -77,11 +70,27 @@ def project_path(tmp_path):
     return test_dir
 
 
+def _set_dummy_keys_if_missing():
+    """Set dummy API keys only when they are completely absent."""
+    for var in ("GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY"):
+        if not os.environ.get(var):
+            os.environ[var] = "dummy-key-for-tests"
+
+
 # Pytest configuration
 def pytest_configure(config):
     """Configure pytest with custom markers"""
     config.addinivalue_line("markers", "asyncio: mark test as async")
     config.addinivalue_line("markers", "no_mock_provider: disable automatic provider mocking")
+    # Assume we need dummy keys until we learn otherwise
+    config._needs_dummy_keys = True
+
+
+def pytest_collection_modifyitems(session, config, items):
+    """Hook that runs after test collection to check for no_mock_provider markers."""
+    # Always set dummy keys if real keys are missing
+    # This ensures tests work in CI even with no_mock_provider marker
+    _set_dummy_keys_if_missing()
+
+
 @pytest.fixture(autouse=True)
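The conftest helper's "only when absent or empty" behavior is the key detail: a real key exported in the environment must never be overwritten. A sketch with the environment mapping made injectable for testing (the `env` parameter is an illustrative addition; the real helper writes to `os.environ` directly):

```python
import os


def set_dummy_keys_if_missing(env=None):
    # Fill a key only when it is missing or empty, never clobbering a real
    # key that the developer or CI has exported.
    if env is None:
        env = os.environ
    for var in ("GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY"):
        if not env.get(var):
            env[var] = "dummy-key-for-tests"
```

Running this from a collection hook (rather than at import time) is what lets recording tests see a real key when one exists.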
tests/http_transport_recorder.py · 376 lines · new file
@@ -0,0 +1,376 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
HTTP Transport Recorder for O3-Pro Testing
|
||||||
|
|
||||||
|
Custom httpx transport solution that replaces respx for recording/replaying
|
||||||
|
HTTP interactions. Provides full control over the recording process without
|
||||||
|
respx limitations.
|
||||||
|
|
||||||
|
Key Features:
|
||||||
|
- RecordingTransport: Wraps default transport, captures real HTTP calls
|
||||||
|
- ReplayTransport: Serves saved responses from cassettes
|
||||||
|
- TransportFactory: Auto-selects record vs replay mode
|
||||||
|
- JSON cassette format with data sanitization
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from .pii_sanitizer import PIISanitizer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RecordingTransport(httpx.HTTPTransport):
|
||||||
|
"""Transport that wraps default httpx transport and records all interactions."""
|
||||||
|
|
||||||
|
def __init__(self, cassette_path: str, capture_content: bool = True, sanitize: bool = True):
|
||||||
|
super().__init__()
|
||||||
|
self.cassette_path = Path(cassette_path)
|
||||||
|
self.recorded_interactions = []
|
||||||
|
self.capture_content = capture_content
|
||||||
|
self.sanitizer = PIISanitizer() if sanitize else None
|
||||||
|
|
||||||
|
def handle_request(self, request: httpx.Request) -> httpx.Response:
|
||||||
|
"""Handle request by recording interaction and delegating to real transport."""
|
||||||
|
logger.debug(f"RecordingTransport: Making request to {request.method} {request.url}")
|
||||||
|
|
||||||
|
# Record request BEFORE making the call
|
||||||
|
request_data = self._serialize_request(request)
|
||||||
|
|
||||||
|
# Make real HTTP call using parent transport
|
||||||
|
response = super().handle_request(request)
|
||||||
|
|
||||||
|
logger.debug(f"RecordingTransport: Got response {response.status_code}")
|
||||||
|
|
||||||
|
# Post-response content capture (proper approach)
|
||||||
|
if self.capture_content:
|
||||||
|
try:
|
||||||
|
# Consume the response stream to capture content
|
||||||
|
# Note: httpx automatically handles gzip decompression
|
||||||
|
content_bytes = response.read()
|
||||||
|
response.close() # Close the original stream
|
||||||
|
logger.debug(f"RecordingTransport: Captured {len(content_bytes)} bytes")
|
||||||
|
|
||||||
|
# Serialize response with captured content
|
||||||
|
response_data = self._serialize_response_with_content(response, content_bytes)
|
||||||
|
|
||||||
|
# Create a new response with the same metadata but buffered content
|
||||||
|
# If the original response was gzipped, we need to re-compress
|
||||||
|
response_content = content_bytes
|
||||||
|
if response.headers.get("content-encoding") == "gzip":
|
||||||
|
import gzip
|
||||||
|
|
||||||
|
response_content = gzip.compress(content_bytes)
|
||||||
|
logger.debug(f"Re-compressed content: {len(content_bytes)} → {len(response_content)} bytes")
|
||||||
|
|
||||||
|
new_response = httpx.Response(
|
||||||
|
status_code=response.status_code,
|
||||||
|
headers=response.headers, # Keep original headers intact
|
||||||
|
content=response_content,
|
||||||
|
request=request,
|
||||||
|
extensions=response.extensions,
|
||||||
|
history=response.history,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Record the interaction
|
||||||
|
self._record_interaction(request_data, response_data)
|
||||||
|
|
||||||
|
return new_response
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Content capture failed, falling back to stub", exc_info=True)
|
||||||
|
response_data = self._serialize_response(response)
|
||||||
|
self._record_interaction(request_data, response_data)
|
||||||
|
return response
|
||||||
|
else:
|
||||||
|
# Legacy mode: record with stub content
|
||||||
|
response_data = self._serialize_response(response)
|
||||||
|
self._record_interaction(request_data, response_data)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _record_interaction(self, request_data: dict[str, Any], response_data: dict[str, Any]):
|
||||||
|
"""Helper method to record interaction and save cassette."""
|
||||||
|
interaction = {"request": request_data, "response": response_data}
|
||||||
|
self.recorded_interactions.append(interaction)
|
||||||
|
self._save_cassette()
|
||||||
|
logger.debug(f"Saved cassette to {self.cassette_path}")
|
||||||
|
|
||||||
|
def _serialize_request(self, request: httpx.Request) -> dict[str, Any]:
|
||||||
|
"""Serialize httpx.Request to JSON-compatible format."""
|
||||||
|
# For requests, we can safely read the content since it's already been prepared
|
||||||
|
# httpx.Request.content is safe to access multiple times
|
||||||
|
content = request.content
|
||||||
|
|
||||||
|
# Convert bytes to string for JSON serialization
|
||||||
|
if isinstance(content, bytes):
|
||||||
|
try:
|
||||||
|
content_str = content.decode("utf-8")
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
# Handle binary content (shouldn't happen for o3-pro API)
|
||||||
|
content_str = content.hex()
|
||||||
|
else:
|
||||||
|
content_str = str(content) if content else ""
|
||||||
|
|
||||||
|
request_data = {
|
||||||
|
"method": request.method,
|
||||||
|
"url": str(request.url),
|
||||||
|
"path": request.url.path,
|
||||||
|
"headers": dict(request.headers),
|
||||||
|
"content": self._sanitize_request_content(content_str),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Apply PII sanitization if enabled
|
||||||
|
if self.sanitizer:
|
||||||
|
request_data = self.sanitizer.sanitize_request(request_data)
|
||||||
|
|
||||||
|
return request_data
|
||||||
|
|
||||||
|
def _serialize_response(self, response: httpx.Response) -> dict[str, Any]:
|
||||||
|
"""Serialize httpx.Response to JSON-compatible format (legacy method without content)."""
|
||||||
|
# Legacy method for backward compatibility when content capture is disabled
|
||||||
|
return {
|
||||||
|
"status_code": response.status_code,
|
||||||
|
"headers": dict(response.headers),
|
||||||
|
"content": {"note": "Response content not recorded to avoid httpx.ResponseNotRead exception"},
|
||||||
|
"reason_phrase": response.reason_phrase,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _serialize_response_with_content(self, response: httpx.Response, content_bytes: bytes) -> dict[str, Any]:
|
||||||
|
"""Serialize httpx.Response with captured content."""
|
||||||
|
try:
|
||||||
|
# Debug: check what we got
|
||||||
|
|
||||||
|
# Ensure we have bytes for base64 encoding
|
||||||
|
if not isinstance(content_bytes, bytes):
|
||||||
|
logger.warning(f"Content is not bytes, converting from {type(content_bytes)}")
|
||||||
|
if isinstance(content_bytes, str):
|
||||||
|
content_bytes = content_bytes.encode("utf-8")
|
||||||
|
else:
|
||||||
|
content_bytes = str(content_bytes).encode("utf-8")
|
||||||
|
|
||||||
|
# Encode content as base64 for JSON storage
|
||||||
|
content_b64 = base64.b64encode(content_bytes).decode("utf-8")
|
||||||
|
logger.debug(f"Base64 encoded {len(content_bytes)} bytes → {len(content_b64)} chars")
|
||||||
|
|
||||||
|
response_data = {
|
||||||
|
"status_code": response.status_code,
|
||||||
|
"headers": dict(response.headers),
|
||||||
|
"content": {"data": content_b64, "encoding": "base64", "size": len(content_bytes)},
|
||||||
|
"reason_phrase": response.reason_phrase,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Apply PII sanitization if enabled
|
||||||
|
if self.sanitizer:
|
||||||
|
response_data = self.sanitizer.sanitize_response(response_data)
|
||||||
|
|
||||||
|
return response_data
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error in _serialize_response_with_content")
|
||||||
|
# Fall back to minimal info
|
||||||
|
return {
|
||||||
|
"status_code": response.status_code,
|
||||||
|
"headers": dict(response.headers),
|
||||||
|
"content": {"error": f"Failed to serialize content: {e}"},
|
||||||
|
"reason_phrase": response.reason_phrase,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _sanitize_request_content(self, content: str) -> Any:
|
||||||
|
"""Sanitize request content to remove sensitive data."""
|
||||||
|
try:
|
||||||
|
if content.strip():
|
||||||
|
data = json.loads(content)
|
||||||
|
# Don't sanitize request content for now - it's user input
|
||||||
|
return data
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
return content
|
||||||
|
|
||||||
|
def _save_cassette(self):
|
||||||
|
"""Save recorded interactions to cassette file."""
|
||||||
|
# Ensure directory exists
|
||||||
|
self.cassette_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Save cassette
|
||||||
|
cassette_data = {"interactions": self.recorded_interactions}
|
||||||
|
|
||||||
|
self.cassette_path.write_text(json.dumps(cassette_data, indent=2, sort_keys=True))
|
||||||
|
|
||||||
|
|
||||||
|
class ReplayTransport(httpx.MockTransport):
|
||||||
|
"""Transport that replays saved HTTP interactions from cassettes."""
|
||||||
|
|
||||||
|
def __init__(self, cassette_path: str):
|
||||||
|
self.cassette_path = Path(cassette_path)
|
||||||
|
self.interactions = self._load_cassette()
|
||||||
|
super().__init__(self._handle_request)
|
||||||
|
|
||||||
|
def _load_cassette(self) -> list:
|
||||||
|
"""Load interactions from cassette file."""
|
||||||
|
if not self.cassette_path.exists():
|
||||||
|
raise FileNotFoundError(f"Cassette file not found: {self.cassette_path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
cassette_data = json.loads(self.cassette_path.read_text())
|
||||||
|
return cassette_data.get("interactions", [])
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise ValueError(f"Invalid cassette file format: {e}")
|
||||||
|
|
||||||
|
def _handle_request(self, request: httpx.Request) -> httpx.Response:
|
||||||
|
"""Handle request by finding matching interaction and returning saved response."""
|
||||||
|
logger.debug(f"ReplayTransport: Looking for {request.method} {request.url}")
|
||||||
|
|
||||||
|
# Debug: show what we're trying to match
|
||||||
|
request_signature = self._get_request_signature(request)
|
||||||
|
logger.debug(f"Request signature: {request_signature}")
|
||||||
|
|
||||||
|
# Find matching interaction
|
||||||
|
interaction = self._find_matching_interaction(request)
|
||||||
|
if not interaction:
|
||||||
|
logger.warning("No matching interaction found in cassette")
|
||||||
|
raise ValueError(f"No matching interaction found for {request.method} {request.url}")
|
||||||
|
|
||||||
|
logger.debug("Found matching interaction in cassette")
|
||||||
|
|
||||||
|
# Build response from saved data
|
||||||
|
response_data = interaction["response"]
|
||||||
|
|
||||||
|
# Convert content back to appropriate format
|
||||||
|
content = response_data.get("content", {})
|
||||||
|
if isinstance(content, dict):
|
||||||
|
# Check if this is base64-encoded content
|
||||||
|
if content.get("encoding") == "base64" and "data" in content:
|
||||||
|
# Decode base64 content
|
||||||
|
try:
|
||||||
|
content_bytes = base64.b64decode(content["data"])
|
||||||
|
logger.debug(f"Decoded {len(content_bytes)} bytes from base64")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to decode base64 content: {e}")
|
||||||
|
content_bytes = json.dumps(content).encode("utf-8")
|
||||||
|
else:
|
||||||
|
# Legacy format or stub content
|
||||||
|
content_bytes = json.dumps(content).encode("utf-8")
|
||||||
|
else:
|
||||||
|
content_bytes = str(content).encode("utf-8")
|
||||||
|
|
||||||
|
# Check if response expects gzipped content
|
||||||
|
headers = response_data.get("headers", {})
|
||||||
|
if headers.get("content-encoding") == "gzip":
|
||||||
|
# Re-compress the content for httpx
|
||||||
|
import gzip
|
||||||
|
|
||||||
|
content_bytes = gzip.compress(content_bytes)
|
||||||
|
logger.debug(f"Re-compressed for replay: {len(content_bytes)} bytes")
|
||||||
|
|
||||||
|
logger.debug(f"Returning cassette response ({len(content_bytes)} bytes)")
|
||||||
|
|
||||||
|
# Create httpx.Response
|
||||||
|
return httpx.Response(
|
||||||
|
status_code=response_data["status_code"],
|
||||||
|
headers=response_data.get("headers", {}),
|
||||||
|
content=content_bytes,
|
||||||
|
request=request,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _find_matching_interaction(self, request: httpx.Request) -> Optional[dict[str, Any]]:
|
||||||
|
"""Find interaction that matches the request."""
|
||||||
|
request_signature = self._get_request_signature(request)
|
||||||
|
|
||||||
|
for interaction in self.interactions:
|
||||||
|
saved_signature = self._get_saved_request_signature(interaction["request"])
|
||||||
|
if request_signature == saved_signature:
|
||||||
|
return interaction
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_request_signature(self, request: httpx.Request) -> str:
|
||||||
|
"""Generate signature for request matching."""
|
||||||
|
# Use method, path, and content hash for matching
|
||||||
|
content = request.content
|
||||||
|
if hasattr(content, "read"):
|
||||||
|
content = content.read()
|
||||||
|
|
||||||
|
if isinstance(content, bytes):
|
||||||
|
content_str = content.decode("utf-8", errors="ignore")
|
||||||
|
else:
|
||||||
|
content_str = str(content) if content else ""
|
||||||
|
|
||||||
|
# Parse JSON and re-serialize with sorted keys for consistent hashing
|
||||||
|
try:
|
||||||
|
if content_str.strip():
|
||||||
|
content_dict = json.loads(content_str)
|
||||||
|
content_str = json.dumps(content_dict, sort_keys=True)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# Not JSON, use as-is
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Create hash of content for stable matching
|
||||||
|
content_hash = hashlib.md5(content_str.encode()).hexdigest()
|
||||||
|
|
||||||
|
return f"{request.method}:{request.url.path}:{content_hash}"
|
||||||
|
|
||||||
|
def _get_saved_request_signature(self, saved_request: dict[str, Any]) -> str:
|
||||||
|
"""Generate signature for saved request."""
|
||||||
|
method = saved_request["method"]
|
||||||
|
path = saved_request["path"]
|
||||||
|
|
||||||
|
# Hash the saved content
|
||||||
|
content = saved_request.get("content", "")
|
||||||
|
if isinstance(content, dict):
|
||||||
|
content_str = json.dumps(content, sort_keys=True)
|
||||||
|
else:
|
||||||
|
content_str = str(content)
|
||||||
|
|
||||||
|
content_hash = hashlib.md5(content_str.encode()).hexdigest()
|
||||||
|
|
||||||
|
return f"{method}:{path}:{content_hash}"
|
||||||
|
|
||||||
|
|
||||||
|
class TransportFactory:
|
||||||
|
"""Factory for creating appropriate transport based on cassette availability."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_transport(cassette_path: str) -> httpx.HTTPTransport:
|
||||||
|
"""Create transport based on cassette existence and API key availability."""
|
||||||
|
cassette_file = Path(cassette_path)
|
||||||
|
|
||||||
|
# Check if we should record or replay
|
||||||
|
if cassette_file.exists():
|
||||||
|
# Cassette exists - use replay mode
|
||||||
|
return ReplayTransport(cassette_path)
|
||||||
|
else:
|
||||||
|
# No cassette - use recording mode
|
||||||
|
# Note: We'll check for API key in the test itself
|
||||||
|
return RecordingTransport(cassette_path)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def should_record(cassette_path: str, api_key: Optional[str] = None) -> bool:
|
||||||
|
"""Determine if we should record based on cassette and API key availability."""
|
||||||
|
cassette_file = Path(cassette_path)
|
||||||
|
|
||||||
|
# Record if cassette doesn't exist AND we have API key
|
||||||
|
return not cassette_file.exists() and bool(api_key)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def should_replay(cassette_path: str) -> bool:
|
||||||
|
"""Determine if we should replay based on cassette availability."""
|
||||||
|
cassette_file = Path(cassette_path)
|
||||||
|
return cassette_file.exists()
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage:
|
||||||
|
#
|
||||||
|
# # In test setup:
|
||||||
|
# cassette_path = "tests/cassettes/o3_pro_basic_math.json"
|
||||||
|
# transport = TransportFactory.create_transport(cassette_path)
|
||||||
|
#
|
||||||
|
# # Inject into OpenAI client:
|
||||||
|
# provider._test_transport = transport
|
||||||
|
#
|
||||||
|
# # The provider's client property will detect _test_transport and use it
|
||||||
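The matching scheme above can be sketched standalone: because JSON bodies are re-serialized with sorted keys before hashing, a recorded interaction still matches when the client serializes fields in a different order. The helper below is an illustrative re-implementation on plain strings, not the recorder's actual API (which operates on `httpx.Request` objects); the path and body are made-up examples.

```python
import hashlib
import json


def request_signature(method: str, path: str, body: str) -> str:
    # Normalize JSON bodies by re-serializing with sorted keys,
    # then hash for a stable, order-independent signature.
    try:
        if body.strip():
            body = json.dumps(json.loads(body), sort_keys=True)
    except json.JSONDecodeError:
        pass  # not JSON; hash the raw body as-is
    digest = hashlib.md5(body.encode()).hexdigest()
    return f"{method}:{path}:{digest}"


# Key order does not affect the signature:
a = request_signature("POST", "/v1/responses", '{"model": "o3-pro", "input": "hi"}')
b = request_signature("POST", "/v1/responses", '{"input": "hi", "model": "o3-pro"}')
print(a == b)  # → True
```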
90
tests/openai_cassettes/o3_pro_basic_math.json
Normal file
File diff suppressed because one or more lines are too long
290
tests/pii_sanitizer.py
Normal file
@@ -0,0 +1,290 @@
#!/usr/bin/env python3
"""
PII (Personally Identifiable Information) Sanitizer for HTTP recordings.

This module provides comprehensive sanitization of sensitive data in HTTP
request/response recordings to prevent accidental exposure of API keys,
tokens, personal information, and other sensitive data.
"""

import logging
import re
from copy import deepcopy
from dataclasses import dataclass
from re import Pattern
from typing import Any, Optional

logger = logging.getLogger(__name__)


@dataclass
class PIIPattern:
    """Defines a pattern for detecting and sanitizing PII."""

    name: str
    pattern: Pattern[str]
    replacement: str
    description: str

    @classmethod
    def create(cls, name: str, pattern: str, replacement: str, description: str) -> "PIIPattern":
        """Create a PIIPattern with compiled regex."""
        return cls(name=name, pattern=re.compile(pattern), replacement=replacement, description=description)


class PIISanitizer:
    """Sanitizes PII from various data structures while preserving format."""

    def __init__(self, patterns: Optional[list[PIIPattern]] = None):
        """Initialize with optional custom patterns."""
        self.patterns: list[PIIPattern] = patterns or []
        self.sanitize_enabled = True

        # Add default patterns if none provided
        if not patterns:
            self._add_default_patterns()

    def _add_default_patterns(self):
        """Add comprehensive default PII patterns."""
        default_patterns = [
            # API Keys - Core patterns (Bearer tokens handled in sanitize_headers)
            PIIPattern.create(
                name="openai_api_key_proj",
                pattern=r"sk-proj-[A-Za-z0-9\-_]{48,}",
                replacement="sk-proj-SANITIZED",
                description="OpenAI project API keys",
            ),
            PIIPattern.create(
                name="openai_api_key",
                pattern=r"sk-[A-Za-z0-9]{48,}",
                replacement="sk-SANITIZED",
                description="OpenAI API keys",
            ),
            PIIPattern.create(
                name="anthropic_api_key",
                pattern=r"sk-ant-[A-Za-z0-9\-_]{48,}",
                replacement="sk-ant-SANITIZED",
                description="Anthropic API keys",
            ),
            PIIPattern.create(
                name="google_api_key",
                pattern=r"AIza[A-Za-z0-9\-_]{35,}",
                replacement="AIza-SANITIZED",
                description="Google API keys",
            ),
            PIIPattern.create(
                name="github_tokens",
                pattern=r"gh[psr]_[A-Za-z0-9]{36}",
                replacement="gh_SANITIZED",
                description="GitHub tokens (all types)",
            ),
            # JWT tokens
            PIIPattern.create(
                name="jwt_token",
                pattern=r"eyJ[A-Za-z0-9\-_]+\.eyJ[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+",
                replacement="eyJ-SANITIZED",
                description="JSON Web Tokens",
            ),
            # Personal Information
            PIIPattern.create(
                name="email_address",
                pattern=r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}",
                replacement="user@example.com",
                description="Email addresses",
            ),
            PIIPattern.create(
                name="ipv4_address",
                pattern=r"\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b",
                replacement="0.0.0.0",
                description="IPv4 addresses",
            ),
            PIIPattern.create(
                name="ssn",
                pattern=r"\b\d{3}-\d{2}-\d{4}\b",
                replacement="XXX-XX-XXXX",
                description="Social Security Numbers",
            ),
            PIIPattern.create(
                name="credit_card",
                pattern=r"\b\d{4}[\s\-]?\d{4}[\s\-]?\d{4}[\s\-]?\d{4}\b",
                replacement="XXXX-XXXX-XXXX-XXXX",
                description="Credit card numbers",
            ),
            PIIPattern.create(
                name="phone_number",
                pattern=r"(?:\+\d{1,3}[\s\-]?)?\(?\d{3}\)?[\s\-]?\d{3}[\s\-]?\d{4}\b(?![\d\.\,\]\}])",
                replacement="(XXX) XXX-XXXX",
                description="Phone numbers (all formats)",
            ),
            # AWS
            PIIPattern.create(
                name="aws_access_key",
                pattern=r"AKIA[0-9A-Z]{16}",
                replacement="AKIA-SANITIZED",
                description="AWS access keys",
            ),
            # Other common patterns
            PIIPattern.create(
                name="slack_token",
                pattern=r"xox[baprs]-[0-9]{10,13}-[0-9]{10,13}-[a-zA-Z0-9]{24,34}",
                replacement="xox-SANITIZED",
                description="Slack tokens",
            ),
            PIIPattern.create(
                name="stripe_key",
                pattern=r"(?:sk|pk)_(?:test|live)_[0-9a-zA-Z]{24,99}",
                replacement="sk_SANITIZED",
                description="Stripe API keys",
            ),
        ]

        self.patterns.extend(default_patterns)

    def add_pattern(self, pattern: PIIPattern):
        """Add a custom PII pattern."""
        self.patterns.append(pattern)
        logger.info(f"Added PII pattern: {pattern.name}")

    def sanitize_string(self, text: str) -> str:
        """Apply all patterns to sanitize a string."""
        if not self.sanitize_enabled or not isinstance(text, str):
            return text

        sanitized = text
        for pattern in self.patterns:
            if pattern.pattern.search(sanitized):
                sanitized = pattern.pattern.sub(pattern.replacement, sanitized)
                logger.debug(f"Applied {pattern.name} sanitization")

        return sanitized

    def sanitize_headers(self, headers: dict[str, str]) -> dict[str, str]:
        """Special handling for HTTP headers."""
        if not self.sanitize_enabled:
            return headers

        sanitized_headers = {}

        for key, value in headers.items():
            # Special case for Authorization headers to preserve auth type
            if key.lower() == "authorization" and " " in value:
                auth_type = value.split(" ", 1)[0]
                if auth_type in ("Bearer", "Basic"):
                    sanitized_headers[key] = f"{auth_type} SANITIZED"
                else:
                    sanitized_headers[key] = self.sanitize_string(value)
            else:
                # Apply standard sanitization to all other headers
                sanitized_headers[key] = self.sanitize_string(value)

        return sanitized_headers

    def sanitize_value(self, value: Any) -> Any:
        """Recursively sanitize any value (string, dict, list, etc)."""
        if not self.sanitize_enabled:
            return value

        if isinstance(value, str):
            return self.sanitize_string(value)
        elif isinstance(value, dict):
            return {k: self.sanitize_value(v) for k, v in value.items()}
        elif isinstance(value, list):
            return [self.sanitize_value(item) for item in value]
        elif isinstance(value, tuple):
            return tuple(self.sanitize_value(item) for item in value)
        else:
            # For other types (int, float, bool, None), return as-is
            return value

    def sanitize_url(self, url: str) -> str:
        """Sanitize sensitive data from URLs (query params, etc)."""
        if not self.sanitize_enabled:
            return url

        # First apply general string sanitization
        url = self.sanitize_string(url)

        # Parse and sanitize query parameters
        if "?" in url:
            base, query = url.split("?", 1)
            params = []

            for param in query.split("&"):
                if "=" in param:
                    key, value = param.split("=", 1)
                    # Sanitize common sensitive parameter names
                    sensitive_params = {"key", "token", "api_key", "secret", "password"}
                    if key.lower() in sensitive_params:
                        params.append(f"{key}=SANITIZED")
                    else:
                        # Still sanitize the value for PII
                        params.append(f"{key}={self.sanitize_string(value)}")
                else:
                    params.append(param)

            return f"{base}?{'&'.join(params)}"

        return url

    def sanitize_request(self, request_data: dict[str, Any]) -> dict[str, Any]:
        """Sanitize a complete request dictionary."""
        sanitized = deepcopy(request_data)

        # Sanitize headers
        if "headers" in sanitized:
            sanitized["headers"] = self.sanitize_headers(sanitized["headers"])

        # Sanitize URL
        if "url" in sanitized:
            sanitized["url"] = self.sanitize_url(sanitized["url"])

        # Sanitize content
        if "content" in sanitized:
            sanitized["content"] = self.sanitize_value(sanitized["content"])

        return sanitized

    def sanitize_response(self, response_data: dict[str, Any]) -> dict[str, Any]:
        """Sanitize a complete response dictionary."""
        sanitized = deepcopy(response_data)

        # Sanitize headers
        if "headers" in sanitized:
            sanitized["headers"] = self.sanitize_headers(sanitized["headers"])

        # Sanitize content
        if "content" in sanitized:
            # Handle base64 encoded content specially
            if isinstance(sanitized["content"], dict) and sanitized["content"].get("encoding") == "base64":
                if "data" in sanitized["content"]:
                    import base64

                    try:
                        # Decode, sanitize, and re-encode the actual response body
                        decoded_bytes = base64.b64decode(sanitized["content"]["data"])
                        # Attempt to decode as UTF-8 for sanitization. If it fails, it's likely binary.
                        try:
                            decoded_str = decoded_bytes.decode("utf-8")
                            sanitized_str = self.sanitize_string(decoded_str)
                            sanitized["content"]["data"] = base64.b64encode(sanitized_str.encode("utf-8")).decode(
                                "utf-8"
                            )
                        except UnicodeDecodeError:
                            # Content is not text, leave as is.
                            pass
                    except (base64.binascii.Error, TypeError):
                        # Handle cases where data is not valid base64
                        pass

                # Sanitize other metadata fields
                for key, value in sanitized["content"].items():
                    if key != "data":
                        sanitized["content"][key] = self.sanitize_value(value)
            else:
                sanitized["content"] = self.sanitize_value(sanitized["content"])

        return sanitized


# Global instance for convenience
default_sanitizer = PIISanitizer()
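The core of the module above is the pattern-list approach: each entry pairs a compiled regex with a format-preserving replacement, applied in order. A minimal standalone sketch (patterns shortened for illustration; the real module registers many more and wraps them in `PIIPattern` dataclasses):

```python
import re

# Each entry: (compiled pattern, format-preserving replacement).
PATTERNS = [
    (re.compile(r"sk-[A-Za-z0-9]{48,}"), "sk-SANITIZED"),
    (re.compile(r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}"), "user@example.com"),
]


def sanitize(text: str) -> str:
    # Apply every pattern in order; replacements keep the original
    # shape (an email stays an email, a key stays key-like).
    for pattern, replacement in PATTERNS:
        text = pattern.sub(replacement, text)
    return text


print(sanitize("contact alice@corp.io, key sk-" + "a" * 48))
# → contact user@example.com, key sk-SANITIZED
```

Keeping replacements format-shaped (rather than blanking the value) matters for replay: recorded responses still parse as valid JSON and still look like the payloads the client expects.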
110
tests/sanitize_cassettes.py
Executable file
@@ -0,0 +1,110 @@
#!/usr/bin/env python3
"""
Script to sanitize existing cassettes by applying PII sanitization.

This script will:
1. Load existing cassettes
2. Apply PII sanitization to all interactions
3. Create backups of originals
4. Save sanitized versions
"""

import json
import shutil
import sys
from datetime import datetime
from pathlib import Path

# Add tests directory to path to import our modules
sys.path.insert(0, str(Path(__file__).parent))

from pii_sanitizer import PIISanitizer


def sanitize_cassette(cassette_path: Path, backup: bool = True) -> bool:
    """Sanitize a single cassette file."""
    print(f"\n🔍 Processing: {cassette_path}")

    if not cassette_path.exists():
        print(f"❌ File not found: {cassette_path}")
        return False

    try:
        # Load cassette
        with open(cassette_path) as f:
            cassette_data = json.load(f)

        # Create backup if requested
        if backup:
            backup_path = cassette_path.with_suffix(f'.backup-{datetime.now().strftime("%Y%m%d-%H%M%S")}.json')
            shutil.copy2(cassette_path, backup_path)
            print(f"📦 Backup created: {backup_path}")

        # Initialize sanitizer
        sanitizer = PIISanitizer()

        # Sanitize interactions
        if "interactions" in cassette_data:
            sanitized_interactions = []

            for interaction in cassette_data["interactions"]:
                sanitized_interaction = {}

                # Sanitize request
                if "request" in interaction:
                    sanitized_interaction["request"] = sanitizer.sanitize_request(interaction["request"])

                # Sanitize response
                if "response" in interaction:
                    sanitized_interaction["response"] = sanitizer.sanitize_response(interaction["response"])

                sanitized_interactions.append(sanitized_interaction)

            cassette_data["interactions"] = sanitized_interactions

        # Save sanitized cassette
        with open(cassette_path, "w") as f:
            json.dump(cassette_data, f, indent=2, sort_keys=True)

        print(f"✅ Sanitized: {cassette_path}")
        return True

    except Exception as e:
        print(f"❌ Error processing {cassette_path}: {e}")
        import traceback

        traceback.print_exc()
        return False


def main():
    """Sanitize all cassettes in the openai_cassettes directory."""
    cassettes_dir = Path(__file__).parent / "openai_cassettes"

    if not cassettes_dir.exists():
        print(f"❌ Directory not found: {cassettes_dir}")
        sys.exit(1)

    # Find all JSON cassettes
    cassette_files = list(cassettes_dir.glob("*.json"))

    if not cassette_files:
        print(f"❌ No cassette files found in {cassettes_dir}")
        sys.exit(1)

    print(f"🎬 Found {len(cassette_files)} cassette(s) to sanitize")

    # Process each cassette
    success_count = 0
    for cassette_path in cassette_files:
        if sanitize_cassette(cassette_path):
            success_count += 1

    print(f"\n✨ Sanitization complete: {success_count}/{len(cassette_files)} cassettes processed successfully")

    if success_count < len(cassette_files):
        sys.exit(1)


if __name__ == "__main__":
    main()
@@ -663,9 +663,13 @@ class TestAutoModeWithRestrictions:
         model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
         assert model == "o4-mini"
 
-    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini", "GEMINI_API_KEY": "", "OPENAI_API_KEY": "test-key"})
-    def test_fallback_with_shorthand_restrictions(self):
+    def test_fallback_with_shorthand_restrictions(self, monkeypatch):
         """Test fallback model selection with shorthand restrictions."""
+        # Use monkeypatch to set environment variables with automatic cleanup
+        monkeypatch.setenv("OPENAI_ALLOWED_MODELS", "mini")
+        monkeypatch.setenv("GEMINI_API_KEY", "")
+        monkeypatch.setenv("OPENAI_API_KEY", "test-key")
+
         # Clear caches and reset registry
         import utils.model_restrictions
         from providers.registry import ModelProviderRegistry
124
tests/test_o3_pro_output_text_fix.py
Normal file
@@ -0,0 +1,124 @@
"""
Tests for o3-pro output_text parsing fix using HTTP transport recording.

This test validates the fix that uses `response.output_text` convenience field
instead of manually parsing `response.output.content[].text`.

Uses HTTP transport recorder to record real o3-pro API responses at the HTTP level while allowing
the OpenAI SDK to create real response objects that we can test.

RECORDING: To record new responses, delete the cassette file and run with real API keys.
"""

import logging
import os
from pathlib import Path
from unittest.mock import patch

import pytest
from dotenv import load_dotenv

from providers import ModelProviderRegistry
from tests.transport_helpers import inject_transport
from tools.chat import ChatTool

logger = logging.getLogger(__name__)

# Load environment variables from .env file
load_dotenv()

# Use absolute path for cassette directory
cassette_dir = Path(__file__).parent / "openai_cassettes"
cassette_dir.mkdir(exist_ok=True)


@pytest.mark.asyncio
class TestO3ProOutputTextFix:
    """Test o3-pro response parsing fix using HTTP transport recording/replay."""

    def setup_method(self):
        """Set up the test by ensuring clean registry state."""
        # Use the new public API for registry cleanup
        ModelProviderRegistry.reset_for_testing()
        # Provider registration is now handled by inject_transport helper

        # Clear restriction service to ensure it re-reads environment
        # This is necessary because previous tests may have set restrictions
        # that are cached in the singleton
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    def teardown_method(self):
        """Clean up after test to ensure no state pollution."""
        # Use the new public API for registry cleanup
        ModelProviderRegistry.reset_for_testing()

    @pytest.mark.no_mock_provider  # Disable provider mocking for this test
    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-pro,o3-pro-2025-06-10", "LOCALE": ""})
    async def test_o3_pro_uses_output_text_field(self, monkeypatch):
        """Test that o3-pro parsing uses the output_text convenience field via ChatTool."""
        cassette_path = cassette_dir / "o3_pro_basic_math.json"

        # Check if we need to record or replay
        if not cassette_path.exists():
            # Recording mode - check for real API key
            real_api_key = os.getenv("OPENAI_API_KEY", "").strip()
            if not real_api_key or real_api_key.startswith("dummy"):
                pytest.fail(
                    f"Cassette file not found at {cassette_path}. "
                    "To record: Set OPENAI_API_KEY environment variable to a valid key and run this test. "
                    "Note: Recording will make a real API call to OpenAI."
                )
            # Real API key is available, we'll record the cassette
            logger.debug("🎬 Recording mode: Using real API key to record cassette")
        else:
            # Replay mode - use dummy key
            monkeypatch.setenv("OPENAI_API_KEY", "dummy-key-for-replay")
            logger.debug("📼 Replay mode: Using recorded cassette")

        # Simplified transport injection - just one line!
        inject_transport(monkeypatch, cassette_path)

        # Execute ChatTool test with custom transport
        result = await self._execute_chat_tool_test()

        # Verify the response works correctly
        self._verify_chat_tool_response(result)

        # Verify cassette exists
        assert cassette_path.exists()

    async def _execute_chat_tool_test(self):
        """Execute the ChatTool with o3-pro and return the result."""
        chat_tool = ChatTool()
        arguments = {"prompt": "What is 2 + 2?", "model": "o3-pro", "temperature": 1.0}

        return await chat_tool.execute(arguments)

    def _verify_chat_tool_response(self, result):
        """Verify the ChatTool response contains expected data."""
        # Basic response validation
        assert result is not None
        assert isinstance(result, list)
        assert len(result) > 0
        assert result[0].type == "text"

        # Parse JSON response
        import json

        response_data = json.loads(result[0].text)

        # Debug log the response
        logger.debug(f"Response data: {json.dumps(response_data, indent=2)}")

        # Verify response structure - no cargo culting
        if response_data["status"] == "error":
            pytest.fail(f"Chat tool returned error: {response_data.get('error', 'Unknown error')}")
        assert response_data["status"] in ["success", "continuation_available"]
        assert "4" in response_data["content"]

        # Verify o3-pro was actually used
        metadata = response_data["metadata"]
        assert metadata["model_used"] == "o3-pro"
        assert metadata["provider_used"] == "openai"
@@ -278,11 +278,9 @@ class TestOpenAIProvider:
         mock_openai_class.return_value = mock_client
 
         mock_response = MagicMock()
-        mock_response.output = MagicMock()
-        mock_response.output.content = [MagicMock()]
-        mock_response.output.content[0].type = "output_text"
-        mock_response.output.content[0].text = "4"
-        mock_response.model = "o3-pro"
+        # New o3-pro format: direct output_text field
+        mock_response.output_text = "4"
+        mock_response.model = "o3-pro-2025-06-10"
         mock_response.id = "test-id"
         mock_response.created_at = 1234567890
         mock_response.usage = MagicMock()
@@ -300,13 +298,13 @@ class TestOpenAIProvider:
         # Verify responses.create was called
         mock_client.responses.create.assert_called_once()
         call_args = mock_client.responses.create.call_args[1]
-        assert call_args["model"] == "o3-pro"
+        assert call_args["model"] == "o3-pro-2025-06-10"
         assert call_args["input"][0]["role"] == "user"
         assert "What is 2 + 2?" in call_args["input"][0]["content"][0]["text"]
 
         # Verify the response
         assert result.content == "4"
-        assert result.model_name == "o3-pro"
+        assert result.model_name == "o3-pro-2025-06-10"
         assert result.metadata["endpoint"] == "responses"
 
     @patch("providers.openai_compatible.OpenAI")
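The hunks above switch the mocked response from the old nested `output.content[].text` shape to the new flat `output_text` field. The parsing change they exercise can be sketched as a small helper that prefers the convenience field and falls back to walking the nested structure; this is an illustrative re-implementation using stand-in objects, not the provider's actual code.

```python
from types import SimpleNamespace


def extract_text(response) -> str:
    # Prefer the flat output_text field exposed by newer o3-pro responses;
    # fall back to scanning output.content for older payload shapes.
    text = getattr(response, "output_text", None)
    if text:
        return text
    output = getattr(response, "output", None)
    for item in getattr(output, "content", []) or []:
        if item.type == "output_text":
            return item.text
    return ""


# Stand-ins for the two response shapes the tests cover:
new_style = SimpleNamespace(output_text="4", output=None)
old_style = SimpleNamespace(
    output_text=None,
    output=SimpleNamespace(content=[SimpleNamespace(type="output_text", text="4")]),
)
print(extract_text(new_style), extract_text(old_style))  # → 4 4
```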
143
tests/test_pii_sanitizer.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Test cases for PII sanitizer."""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from .pii_sanitizer import PIIPattern, PIISanitizer
|
||||||
|
|
||||||
|
|
||||||
|
class TestPIISanitizer(unittest.TestCase):
|
||||||
|
"""Test PII sanitization functionality."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test sanitizer."""
|
||||||
|
self.sanitizer = PIISanitizer()
|
||||||
|
|
||||||
|
def test_api_key_sanitization(self):
|
||||||
|
"""Test various API key formats are sanitized."""
|
||||||
|
test_cases = [
|
||||||
|
# OpenAI keys
|
||||||
|
("sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12", "sk-proj-SANITIZED"),
|
||||||
|
("sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN", "sk-SANITIZED"),
|
||||||
|
# Anthropic keys
|
||||||
|
("sk-ant-abcd1234567890ABCD1234567890abcd1234567890ABCD12", "sk-ant-SANITIZED"),
|
||||||
|
# Google keys
|
||||||
|
("AIzaSyD-1234567890abcdefghijklmnopqrstuv", "AIza-SANITIZED"),
|
||||||
|
# GitHub tokens
|
||||||
|
("ghp_1234567890abcdefghijklmnopqrstuvwxyz", "gh_SANITIZED"),
|
||||||
|
("ghs_1234567890abcdefghijklmnopqrstuvwxyz", "gh_SANITIZED"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for original, expected in test_cases:
|
||||||
|
with self.subTest(original=original):
|
||||||
|
result = self.sanitizer.sanitize_string(original)
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
def test_personal_info_sanitization(self):
|
||||||
|
"""Test personal information is sanitized."""
|
||||||
|
test_cases = [
|
||||||
|
# Email addresses
|
||||||
|
("john.doe@example.com", "user@example.com"),
|
||||||
|
("test123@company.org", "user@example.com"),
|
||||||
|
# Phone numbers (all now use the same pattern)
|
||||||
|
("(555) 123-4567", "(XXX) XXX-XXXX"),
|
||||||
|
("555-123-4567", "(XXX) XXX-XXXX"),
|
||||||
|
("+1-555-123-4567", "(XXX) XXX-XXXX"),
|
||||||
|
# SSN
|
||||||
|
("123-45-6789", "XXX-XX-XXXX"),
|
||||||
|
# Credit card
|
||||||
|
("1234 5678 9012 3456", "XXXX-XXXX-XXXX-XXXX"),
|
||||||
|
("1234-5678-9012-3456", "XXXX-XXXX-XXXX-XXXX"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for original, expected in test_cases:
|
||||||
|
with self.subTest(original=original):
|
||||||
|
result = self.sanitizer.sanitize_string(original)
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
def test_header_sanitization(self):
|
||||||
|
"""Test HTTP header sanitization."""
|
||||||
|
headers = {
|
||||||
|
"Authorization": "Bearer sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12",
|
||||||
|
"API-Key": "sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"User-Agent": "MyApp/1.0",
|
||||||
|
"Cookie": "session=abc123; user=john.doe@example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
sanitized = self.sanitizer.sanitize_headers(headers)
|
||||||
|
|
||||||
|
self.assertEqual(sanitized["Authorization"], "Bearer SANITIZED")
|
||||||
|
self.assertEqual(sanitized["API-Key"], "sk-SANITIZED")
|
||||||
|
self.assertEqual(sanitized["Content-Type"], "application/json")
|
||||||
|
self.assertEqual(sanitized["User-Agent"], "MyApp/1.0")
|
||||||
|
        self.assertIn("user@example.com", sanitized["Cookie"])

    def test_nested_structure_sanitization(self):
        """Test sanitization of nested data structures."""
        data = {
            "user": {
                "email": "john.doe@example.com",
                "api_key": "sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12",
            },
            "tokens": [
                "ghp_1234567890abcdefghijklmnopqrstuvwxyz",
                "Bearer sk-ant-abcd1234567890ABCD1234567890abcd1234567890ABCD12",
            ],
            "metadata": {"ip": "192.168.1.100", "phone": "(555) 123-4567"},
        }

        sanitized = self.sanitizer.sanitize_value(data)

        self.assertEqual(sanitized["user"]["email"], "user@example.com")
        self.assertEqual(sanitized["user"]["api_key"], "sk-proj-SANITIZED")
        self.assertEqual(sanitized["tokens"][0], "gh_SANITIZED")
        self.assertEqual(sanitized["tokens"][1], "Bearer sk-ant-SANITIZED")
        self.assertEqual(sanitized["metadata"]["ip"], "0.0.0.0")
        self.assertEqual(sanitized["metadata"]["phone"], "(XXX) XXX-XXXX")

    def test_url_sanitization(self):
        """Test URL parameter sanitization."""
        urls = [
            (
                "https://api.example.com/v1/users?api_key=sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN",
                "https://api.example.com/v1/users?api_key=SANITIZED",
            ),
            (
                "https://example.com/login?token=ghp_1234567890abcdefghijklmnopqrstuvwxyz&user=test",
                "https://example.com/login?token=SANITIZED&user=test",
            ),
        ]

        for original, expected in urls:
            with self.subTest(url=original):
                result = self.sanitizer.sanitize_url(original)
                self.assertEqual(result, expected)

    def test_disable_sanitization(self):
        """Test that sanitization can be disabled."""
        self.sanitizer.sanitize_enabled = False

        sensitive_data = "sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12"
        result = self.sanitizer.sanitize_string(sensitive_data)

        # Should return original when disabled
        self.assertEqual(result, sensitive_data)

    def test_custom_pattern(self):
        """Test adding custom PII patterns."""
        # Add custom pattern for internal employee IDs
        custom_pattern = PIIPattern.create(
            name="employee_id", pattern=r"EMP\d{6}", replacement="EMP-REDACTED", description="Internal employee IDs"
        )

        self.sanitizer.add_pattern(custom_pattern)

        text = "Employee EMP123456 has access to the system"
        result = self.sanitizer.sanitize_string(text)

        self.assertEqual(result, "Employee EMP-REDACTED has access to the system")


if __name__ == "__main__":
    unittest.main()
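The custom-pattern test above relies on `PIIPattern.create` and `PIISanitizer.add_pattern`. A minimal, self-contained sketch of that mechanism is shown below; these stand-in classes only illustrate the register-and-substitute idea and are not the repo's actual sanitizer implementation:

```python
import re
from dataclasses import dataclass, field


@dataclass
class PIIPattern:
    """A named regex pattern with its replacement text (illustrative stand-in)."""

    name: str
    pattern: re.Pattern
    replacement: str
    description: str = ""

    @classmethod
    def create(cls, name: str, pattern: str, replacement: str, description: str = ""):
        # Compile the raw regex once at registration time
        return cls(name, re.compile(pattern), replacement, description)


@dataclass
class PIISanitizer:
    """Applies every registered pattern to input strings (illustrative stand-in)."""

    patterns: list = field(default_factory=list)
    sanitize_enabled: bool = True

    def add_pattern(self, pattern: PIIPattern) -> None:
        self.patterns.append(pattern)

    def sanitize_string(self, text: str) -> str:
        if not self.sanitize_enabled:
            return text  # pass-through when disabled
        for p in self.patterns:
            text = p.pattern.sub(p.replacement, text)
        return text


sanitizer = PIISanitizer()
sanitizer.add_pattern(
    PIIPattern.create(name="employee_id", pattern=r"EMP\d{6}", replacement="EMP-REDACTED")
)
print(sanitizer.sanitize_string("Employee EMP123456 has access"))
# → Employee EMP-REDACTED has access
```

Registering patterns as data rather than hard-coding substitutions is what lets tests like `test_custom_pattern` extend the sanitizer without touching its core loop.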
47
tests/transport_helpers.py
Normal file
@@ -0,0 +1,47 @@
"""Helper functions for HTTP transport injection in tests."""

from tests.http_transport_recorder import TransportFactory


def inject_transport(monkeypatch, cassette_path: str):
    """Inject HTTP transport into OpenAICompatibleProvider for testing.

    This helper simplifies the monkey patching pattern used across tests
    to inject custom HTTP transports for recording/replaying API calls.

    Also ensures OpenAI provider is properly registered for tests that need it.

    Args:
        monkeypatch: pytest monkeypatch fixture
        cassette_path: Path to cassette file for recording/replay

    Returns:
        The created transport instance

    Example:
        transport = inject_transport(monkeypatch, "path/to/cassette.json")
    """
    # Ensure OpenAI provider is registered - always needed for transport injection
    from providers.base import ProviderType
    from providers.openai_provider import OpenAIModelProvider
    from providers.registry import ModelProviderRegistry

    # Always register OpenAI provider for transport tests (API key might be dummy)
    ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

    # Create transport
    transport = TransportFactory.create_transport(str(cassette_path))

    # Inject transport using the established pattern
    from providers.openai_compatible import OpenAICompatibleProvider

    original_client_property = OpenAICompatibleProvider.client

    def patched_client_getter(self):
        if self._client is None:
            self._test_transport = transport
        return original_client_property.fget(self)

    monkeypatch.setattr(OpenAICompatibleProvider, "client", property(patched_client_getter))

    return transport