From 3db49413ff29cf30b311715076bfad2c242f0fec Mon Sep 17 00:00:00 2001
From: Josh Vera
Date: Sat, 12 Jul 2025 20:24:34 -0600
Subject: [PATCH] fix: Resolve o3-pro response parsing and test execution issues
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Fix lint errors: trailing whitespace and deprecated typing imports
- Update test mock for o3-pro response format (output.content[] → output_text)
- Implement robust test isolation with monkeypatch fixture
- Clear provider registry cache to prevent test interference
- Ensure o3-pro tests pass in both individual and full suite execution

šŸ¤– Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude
---
 providers/openai_compatible.py       |   6 +-
 tests/conftest.py                    |   2 +-
 tests/http_transport_recorder.py     | 229 +++++++++++++--------------
 tests/pii_sanitizer.py               | 202 ++++++++++++-----------
 tests/sanitize_cassettes.py          |  65 ++++----
 tests/test_o3_pro_output_text_fix.py |  59 +++++--
 tests/test_openai_provider.py        |   6 +-
 tests/test_pii_sanitizer.py          |  79 +++++----
 8 files changed, 328 insertions(+), 320 deletions(-)

diff --git a/providers/openai_compatible.py b/providers/openai_compatible.py index d718264..6e564cc 100644 --- a/providers/openai_compatible.py +++ b/providers/openai_compatible.py @@ -221,7 +221,7 @@ class OpenAICompatibleProvider(ModelProvider): # Create httpx client with minimal config to avoid proxy conflicts # Note: proxies parameter was removed in httpx 0.28.0 # Check for test transport injection - if hasattr(self, '_test_transport'): + if hasattr(self, "_test_transport"): # Use custom transport for testing (HTTP recording/replay) http_client = httpx.Client( transport=self._test_transport, @@ -318,13 +318,13 @@ class OpenAICompatibleProvider(ModelProvider): """ logging.debug(f"Response object type: {type(response)}") logging.debug(f"Response attributes: {dir(response)}") - + if not hasattr(response, "output_text"): raise ValueError(f"o3-pro response missing output_text field. Response type: {type(response).__name__}") content = response.output_text logging.debug(f"Extracted output_text: '{content}' (type: {type(content)})") - + if content is None: raise ValueError("o3-pro returned None for output_text")
diff --git a/tests/conftest.py b/tests/conftest.py index d7014a7..0c4775a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -93,7 +93,7 @@ def pytest_collection_modifyitems(session, config, items): if item.get_closest_marker("no_mock_provider"): config._needs_dummy_keys = False break - + # Set dummy keys only if no test needs real keys if config._needs_dummy_keys: _set_dummy_keys_if_missing()
diff --git a/tests/http_transport_recorder.py b/tests/http_transport_recorder.py index bde3ab8..d98b813 100644 --- a/tests/http_transport_recorder.py +++ b/tests/http_transport_recorder.py @@ -2,7 +2,7 @@ """ HTTP Transport Recorder for O3-Pro Testing -Custom httpx transport solution that replaces respx for recording/replaying +Custom httpx transport solution that replaces respx for recording/replaying HTTP interactions. Provides full control over the recording process without respx limitations.
@@ -13,40 +13,40 @@ Key Features: - JSON cassette format with data sanitization """ -import json -import hashlib -import copy import base64 +import copy +import hashlib +import json from pathlib import Path -from typing import Dict, Any, Optional -import httpx -from io import BytesIO -from .pii_sanitizer import PIISanitizer +from typing import Any, Optional +import httpx + +from .pii_sanitizer import PIISanitizer class RecordingTransport(httpx.HTTPTransport): """Transport that wraps default httpx transport and records all interactions.""" - + def __init__(self, cassette_path: str, capture_content: bool = True, sanitize: bool = True): super().__init__() self.cassette_path = Path(cassette_path) self.recorded_interactions = [] self.capture_content = capture_content self.sanitizer = PIISanitizer() if sanitize else None - + def handle_request(self, request: httpx.Request) -> httpx.Response: """Handle request by recording interaction and delegating to real transport.""" print(f"šŸŽ¬ RecordingTransport: Making request to {request.method} {request.url}") - + # Record request BEFORE making the call request_data = self._serialize_request(request) - + # Make real HTTP call using parent transport response = super().handle_request(request) - + print(f"šŸŽ¬ RecordingTransport: Got response {response.status_code}") - + # Post-response content capture (proper approach) if self.capture_content: try: @@ -55,19 +55,20 @@ class RecordingTransport(httpx.HTTPTransport): content_bytes = response.read() response.close() # Close the original stream print(f"šŸŽ¬ RecordingTransport: Captured {len(content_bytes)} bytes of decompressed content") - + # Serialize response with captured content response_data = self._serialize_response_with_content(response, content_bytes) - + # Create a new response with the same metadata but buffered content # If the original response was gzipped, we need to re-compress response_content = content_bytes - if response.headers.get('content-encoding') == 'gzip': + if response.headers.get("content-encoding") == "gzip": import gzip + print(f"šŸ—œļø Re-compressing {len(content_bytes)} bytes with gzip...") response_content = gzip.compress(content_bytes) print(f"šŸ—œļø Compressed to {len(response_content)} bytes") - + new_response = httpx.Response( status_code=response.status_code, headers=response.headers, # Keep original headers intact @@ -76,15 +77,16 @@ class RecordingTransport(httpx.HTTPTransport): extensions=response.extensions, history=response.history, ) - + # Record the interaction self._record_interaction(request_data, response_data) - + return new_response - + except Exception as e: print(f"āš ļø Content capture failed: {e}, falling back to stub") import traceback + print(f"āš ļø Full exception traceback:\n{traceback.format_exc()}") response_data = self._serialize_response(response) self._record_interaction(request_data, response_data) @@ -94,105 +96,99 @@ class RecordingTransport(httpx.HTTPTransport): response_data = self._serialize_response(response) self._record_interaction(request_data, response_data) return response - - def _record_interaction(self, request_data: Dict[str, Any], response_data: Dict[str, Any]): + + def _record_interaction(self, request_data: dict[str, Any], response_data: dict[str, Any]): """Helper method to record interaction and save cassette.""" - interaction = { - "request": request_data, - "response": response_data - } + interaction = {"request": request_data, "response": response_data} self.recorded_interactions.append(interaction) self._save_cassette() 
print(f"šŸŽ¬ RecordingTransport: Saved cassette to {self.cassette_path}") - - def _serialize_request(self, request: httpx.Request) -> Dict[str, Any]: + + def _serialize_request(self, request: httpx.Request) -> dict[str, Any]: """Serialize httpx.Request to JSON-compatible format.""" # For requests, we can safely read the content since it's already been prepared # httpx.Request.content is safe to access multiple times content = request.content - + # Convert bytes to string for JSON serialization if isinstance(content, bytes): try: - content_str = content.decode('utf-8') + content_str = content.decode("utf-8") except UnicodeDecodeError: # Handle binary content (shouldn't happen for o3-pro API) content_str = content.hex() else: content_str = str(content) if content else "" - + request_data = { "method": request.method, "url": str(request.url), "path": request.url.path, "headers": dict(request.headers), - "content": self._sanitize_request_content(content_str) + "content": self._sanitize_request_content(content_str), } - + # Apply PII sanitization if enabled if self.sanitizer: request_data = self.sanitizer.sanitize_request(request_data) - + return request_data - - def _serialize_response(self, response: httpx.Response) -> Dict[str, Any]: + + def _serialize_response(self, response: httpx.Response) -> dict[str, Any]: """Serialize httpx.Response to JSON-compatible format (legacy method without content).""" # Legacy method for backward compatibility when content capture is disabled return { "status_code": response.status_code, "headers": dict(response.headers), "content": {"note": "Response content not recorded to avoid httpx.ResponseNotRead exception"}, - "reason_phrase": response.reason_phrase + "reason_phrase": response.reason_phrase, } - - def _serialize_response_with_content(self, response: httpx.Response, content_bytes: bytes) -> Dict[str, Any]: + + def _serialize_response_with_content(self, response: httpx.Response, content_bytes: bytes) -> dict[str, Any]: """Serialize httpx.Response with captured content.""" try: # Debug: check what we got print(f"šŸ” Content type: {type(content_bytes)}, size: {len(content_bytes)}") print(f"šŸ” First 100 chars: {content_bytes[:100]}") - + # Ensure we have bytes for base64 encoding if not isinstance(content_bytes, bytes): print(f"āš ļø Content is not bytes, converting from {type(content_bytes)}") if isinstance(content_bytes, str): - content_bytes = content_bytes.encode('utf-8') + content_bytes = content_bytes.encode("utf-8") else: - content_bytes = str(content_bytes).encode('utf-8') - + content_bytes = str(content_bytes).encode("utf-8") + # Encode content as base64 for JSON storage print(f"šŸ” Base64 encoding {len(content_bytes)} bytes...") - content_b64 = base64.b64encode(content_bytes).decode('utf-8') + content_b64 = base64.b64encode(content_bytes).decode("utf-8") print(f"āœ… Base64 encoded successfully, result length: {len(content_b64)}") - + response_data = { "status_code": response.status_code, "headers": dict(response.headers), - "content": { - "data": content_b64, - "encoding": "base64", - "size": len(content_bytes) - }, - "reason_phrase": response.reason_phrase + "content": {"data": content_b64, "encoding": "base64", "size": len(content_bytes)}, + "reason_phrase": response.reason_phrase, } - + # Apply PII sanitization if enabled if self.sanitizer: response_data = self.sanitizer.sanitize_response(response_data) - + return response_data except Exception as e: print(f"šŸ” Error in _serialize_response_with_content: {e}") import traceback + 
print(f"šŸ” Full traceback: {traceback.format_exc()}") # Fall back to minimal info return { "status_code": response.status_code, "headers": dict(response.headers), "content": {"error": f"Failed to serialize content: {e}"}, - "reason_phrase": response.reason_phrase + "reason_phrase": response.reason_phrase, } - + def _sanitize_request_content(self, content: str) -> Any: """Sanitize request content to remove sensitive data.""" try: @@ -203,14 +199,14 @@ class RecordingTransport(httpx.HTTPTransport): except json.JSONDecodeError: pass return content - + def _sanitize_response_content(self, data: Any) -> Any: """Sanitize response content to remove sensitive data.""" if not isinstance(data, dict): return data - + sanitized = copy.deepcopy(data) - + # Sensitive fields to sanitize sensitive_fields = { "id": "resp_SANITIZED", @@ -218,7 +214,7 @@ class RecordingTransport(httpx.HTTPTransport): "created_at": 0, "system_fingerprint": "fp_SANITIZED", } - + def sanitize_dict(obj): if isinstance(obj, dict): for key, value in obj.items(): @@ -230,82 +226,76 @@ class RecordingTransport(httpx.HTTPTransport): for item in obj: if isinstance(item, (dict, list)): sanitize_dict(item) - + sanitize_dict(sanitized) return sanitized - + def _save_cassette(self): """Save recorded interactions to cassette file.""" # Ensure directory exists self.cassette_path.parent.mkdir(parents=True, exist_ok=True) - + # Save cassette - cassette_data = { - "interactions": self.recorded_interactions - } - - self.cassette_path.write_text( - json.dumps(cassette_data, indent=2, sort_keys=True) - ) + cassette_data = {"interactions": self.recorded_interactions} + + self.cassette_path.write_text(json.dumps(cassette_data, indent=2, sort_keys=True)) class ReplayTransport(httpx.MockTransport): """Transport that replays saved HTTP interactions from cassettes.""" - + def __init__(self, cassette_path: str): self.cassette_path = Path(cassette_path) self.interactions = self._load_cassette() super().__init__(self._handle_request) - + def _load_cassette(self) -> list: """Load interactions from cassette file.""" if not self.cassette_path.exists(): raise FileNotFoundError(f"Cassette file not found: {self.cassette_path}") - + try: cassette_data = json.loads(self.cassette_path.read_text()) return cassette_data.get("interactions", []) except json.JSONDecodeError as e: raise ValueError(f"Invalid cassette file format: {e}") - + def _handle_request(self, request: httpx.Request) -> httpx.Response: """Handle request by finding matching interaction and returning saved response.""" print(f"šŸ” ReplayTransport: Looking for {request.method} {request.url}") - + # Debug: show what we're trying to match request_signature = self._get_request_signature(request) print(f"šŸ” Request signature: {request_signature}") - + # Debug: show actual request content content = request.content - if hasattr(content, 'read'): + if hasattr(content, "read"): content = content.read() if isinstance(content, bytes): - content_str = content.decode('utf-8', errors='ignore') + content_str = content.decode("utf-8", errors="ignore") else: content_str = str(content) if content else "" print(f"šŸ” Actual request content: {content_str}") - + # Debug: show available signatures for i, interaction in enumerate(self.interactions): saved_signature = self._get_saved_request_signature(interaction["request"]) saved_content = interaction["request"].get("content", {}) print(f"šŸ” Available signature {i}: {saved_signature}") print(f"šŸ” Saved content {i}: {saved_content}") - + # Find matching interaction 
interaction = self._find_matching_interaction(request) if not interaction: print("🚨 MYSTERY SOLVED: No matching interaction found! This should fail...") - raise ValueError( - f"No matching interaction found for {request.method} {request.url}" - ) - - print(f"āœ… Found matching interaction from cassette!") - + raise ValueError(f"No matching interaction found for {request.method} {request.url}") + + print("āœ… Found matching interaction from cassette!") + # Build response from saved data response_data = interaction["response"] - + # Convert content back to appropriate format content = response_data.get("content", {}) if isinstance(content, dict): @@ -317,55 +307,56 @@ class ReplayTransport(httpx.MockTransport): print(f"šŸŽ¬ ReplayTransport: Decoded {len(content_bytes)} bytes from base64") except Exception as e: print(f"āš ļø Failed to decode base64 content: {e}") - content_bytes = json.dumps(content).encode('utf-8') + content_bytes = json.dumps(content).encode("utf-8") else: # Legacy format or stub content - content_bytes = json.dumps(content).encode('utf-8') + content_bytes = json.dumps(content).encode("utf-8") else: - content_bytes = str(content).encode('utf-8') - + content_bytes = str(content).encode("utf-8") + # Check if response expects gzipped content headers = response_data.get("headers", {}) - if headers.get('content-encoding') == 'gzip': + if headers.get("content-encoding") == "gzip": # Re-compress the content for httpx import gzip + print(f"šŸ—œļø ReplayTransport: Re-compressing {len(content_bytes)} bytes with gzip...") content_bytes = gzip.compress(content_bytes) print(f"šŸ—œļø ReplayTransport: Compressed to {len(content_bytes)} bytes") - + print(f"šŸŽ¬ ReplayTransport: Returning cassette response with content: {content_bytes[:100]}...") - + # Create httpx.Response return httpx.Response( status_code=response_data["status_code"], headers=response_data.get("headers", {}), content=content_bytes, - request=request + request=request, ) - - def _find_matching_interaction(self, request: httpx.Request) -> Optional[Dict[str, Any]]: + + def _find_matching_interaction(self, request: httpx.Request) -> Optional[dict[str, Any]]: """Find interaction that matches the request.""" request_signature = self._get_request_signature(request) - + for interaction in self.interactions: saved_signature = self._get_saved_request_signature(interaction["request"]) if request_signature == saved_signature: return interaction - + return None - + def _get_request_signature(self, request: httpx.Request) -> str: """Generate signature for request matching.""" # Use method, path, and content hash for matching content = request.content - if hasattr(content, 'read'): + if hasattr(content, "read"): content = content.read() - + if isinstance(content, bytes): - content_str = content.decode('utf-8', errors='ignore') + content_str = content.decode("utf-8", errors="ignore") else: content_str = str(content) if content else "" - + # Parse JSON and re-serialize with sorted keys for consistent hashing try: if content_str.strip(): @@ -374,37 +365,37 @@ class ReplayTransport(httpx.MockTransport): except json.JSONDecodeError: # Not JSON, use as-is pass - + # Create hash of content for stable matching content_hash = hashlib.md5(content_str.encode()).hexdigest() - + return f"{request.method}:{request.url.path}:{content_hash}" - - def _get_saved_request_signature(self, saved_request: Dict[str, Any]) -> str: + + def _get_saved_request_signature(self, saved_request: dict[str, Any]) -> str: """Generate signature for saved request.""" 
method = saved_request["method"] path = saved_request["path"] - + # Hash the saved content content = saved_request.get("content", "") if isinstance(content, dict): content_str = json.dumps(content, sort_keys=True) else: content_str = str(content) - + content_hash = hashlib.md5(content_str.encode()).hexdigest() - + return f"{method}:{path}:{content_hash}" class TransportFactory: """Factory for creating appropriate transport based on cassette availability.""" - + @staticmethod def create_transport(cassette_path: str) -> httpx.HTTPTransport: """Create transport based on cassette existence and API key availability.""" cassette_file = Path(cassette_path) - + # Check if we should record or replay if cassette_file.exists(): # Cassette exists - use replay mode @@ -413,15 +404,15 @@ class TransportFactory: # No cassette - use recording mode # Note: We'll check for API key in the test itself return RecordingTransport(cassette_path) - + @staticmethod def should_record(cassette_path: str, api_key: Optional[str] = None) -> bool: """Determine if we should record based on cassette and API key availability.""" cassette_file = Path(cassette_path) - + # Record if cassette doesn't exist AND we have API key return not cassette_file.exists() and bool(api_key) - + @staticmethod def should_replay(cassette_path: str) -> bool: """Determine if we should replay based on cassette availability.""" @@ -434,8 +425,8 @@ class TransportFactory: # # In test setup: # cassette_path = "tests/cassettes/o3_pro_basic_math.json" # transport = TransportFactory.create_transport(cassette_path) -# +# # # Inject into OpenAI client: # provider._test_transport = transport -# -# # The provider's client property will detect _test_transport and use it \ No newline at end of file +# +# # The provider's client property will detect _test_transport and use it diff --git a/tests/pii_sanitizer.py b/tests/pii_sanitizer.py index ca2c6be..160492f 100644 --- a/tests/pii_sanitizer.py +++ b/tests/pii_sanitizer.py @@ -7,11 +7,12 @@ request/response recordings to prevent accidental exposure of API keys, tokens, personal information, and other sensitive data. 
""" -import re -from typing import Any, Dict, List, Optional, Pattern -from dataclasses import dataclass -from copy import deepcopy import logging +import re +from copy import deepcopy +from dataclasses import dataclass +from re import Pattern +from typing import Any, Optional logger = logging.getLogger(__name__) @@ -19,178 +20,170 @@ logger = logging.getLogger(__name__) @dataclass class PIIPattern: """Defines a pattern for detecting and sanitizing PII.""" + name: str pattern: Pattern[str] replacement: str description: str - + @classmethod - def create(cls, name: str, pattern: str, replacement: str, description: str) -> 'PIIPattern': + def create(cls, name: str, pattern: str, replacement: str, description: str) -> "PIIPattern": """Create a PIIPattern with compiled regex.""" - return cls( - name=name, - pattern=re.compile(pattern), - replacement=replacement, - description=description - ) + return cls(name=name, pattern=re.compile(pattern), replacement=replacement, description=description) class PIISanitizer: """Sanitizes PII from various data structures while preserving format.""" - - def __init__(self, patterns: Optional[List[PIIPattern]] = None): + + def __init__(self, patterns: Optional[list[PIIPattern]] = None): """Initialize with optional custom patterns.""" - self.patterns: List[PIIPattern] = patterns or [] + self.patterns: list[PIIPattern] = patterns or [] self.sanitize_enabled = True - + # Add default patterns if none provided if not patterns: self._add_default_patterns() - + def _add_default_patterns(self): """Add comprehensive default PII patterns.""" default_patterns = [ # API Keys - Core patterns (Bearer tokens handled in sanitize_headers) PIIPattern.create( name="openai_api_key_proj", - pattern=r'sk-proj-[A-Za-z0-9\-_]{48,}', + pattern=r"sk-proj-[A-Za-z0-9\-_]{48,}", replacement="sk-proj-SANITIZED", - description="OpenAI project API keys" + description="OpenAI project API keys", ), PIIPattern.create( name="openai_api_key", - pattern=r'sk-[A-Za-z0-9]{48,}', + pattern=r"sk-[A-Za-z0-9]{48,}", replacement="sk-SANITIZED", - description="OpenAI API keys" + description="OpenAI API keys", ), PIIPattern.create( name="anthropic_api_key", - pattern=r'sk-ant-[A-Za-z0-9\-_]{48,}', + pattern=r"sk-ant-[A-Za-z0-9\-_]{48,}", replacement="sk-ant-SANITIZED", - description="Anthropic API keys" + description="Anthropic API keys", ), PIIPattern.create( name="google_api_key", - pattern=r'AIza[A-Za-z0-9\-_]{35,}', + pattern=r"AIza[A-Za-z0-9\-_]{35,}", replacement="AIza-SANITIZED", - description="Google API keys" + description="Google API keys", ), PIIPattern.create( name="github_tokens", - pattern=r'gh[psr]_[A-Za-z0-9]{36}', + pattern=r"gh[psr]_[A-Za-z0-9]{36}", replacement="gh_SANITIZED", - description="GitHub tokens (all types)" + description="GitHub tokens (all types)", ), - # JWT tokens PIIPattern.create( name="jwt_token", - pattern=r'eyJ[A-Za-z0-9\-_]+\.eyJ[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+', + pattern=r"eyJ[A-Za-z0-9\-_]+\.eyJ[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+", replacement="eyJ-SANITIZED", - description="JSON Web Tokens" + description="JSON Web Tokens", ), - # Personal Information PIIPattern.create( name="email_address", - pattern=r'[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}', + pattern=r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}", replacement="user@example.com", - description="Email addresses" + description="Email addresses", ), PIIPattern.create( name="ipv4_address", - pattern=r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b', + 
pattern=r"\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b", replacement="0.0.0.0", - description="IPv4 addresses" + description="IPv4 addresses", ), PIIPattern.create( name="ssn", - pattern=r'\b\d{3}-\d{2}-\d{4}\b', + pattern=r"\b\d{3}-\d{2}-\d{4}\b", replacement="XXX-XX-XXXX", - description="Social Security Numbers" + description="Social Security Numbers", ), PIIPattern.create( name="credit_card", - pattern=r'\b\d{4}[\s\-]?\d{4}[\s\-]?\d{4}[\s\-]?\d{4}\b', + pattern=r"\b\d{4}[\s\-]?\d{4}[\s\-]?\d{4}[\s\-]?\d{4}\b", replacement="XXXX-XXXX-XXXX-XXXX", - description="Credit card numbers" + description="Credit card numbers", ), PIIPattern.create( name="phone_number", - pattern=r'(?:\+\d{1,3}[\s\-]?)?\(?\d{3}\)?[\s\-]?\d{3}[\s\-]?\d{4}', + pattern=r"(?:\+\d{1,3}[\s\-]?)?\(?\d{3}\)?[\s\-]?\d{3}[\s\-]?\d{4}", replacement="(XXX) XXX-XXXX", - description="Phone numbers (all formats)" + description="Phone numbers (all formats)", ), - # AWS PIIPattern.create( name="aws_access_key", - pattern=r'AKIA[0-9A-Z]{16}', + pattern=r"AKIA[0-9A-Z]{16}", replacement="AKIA-SANITIZED", - description="AWS access keys" + description="AWS access keys", ), - # Other common patterns PIIPattern.create( name="slack_token", - pattern=r'xox[baprs]-[0-9]{10,13}-[0-9]{10,13}-[a-zA-Z0-9]{24,34}', + pattern=r"xox[baprs]-[0-9]{10,13}-[0-9]{10,13}-[a-zA-Z0-9]{24,34}", replacement="xox-SANITIZED", - description="Slack tokens" + description="Slack tokens", ), PIIPattern.create( name="stripe_key", - pattern=r'(?:sk|pk)_(?:test|live)_[0-9a-zA-Z]{24,99}', + pattern=r"(?:sk|pk)_(?:test|live)_[0-9a-zA-Z]{24,99}", replacement="sk_SANITIZED", - description="Stripe API keys" + description="Stripe API keys", ), ] - + self.patterns.extend(default_patterns) - + def add_pattern(self, pattern: PIIPattern): """Add a custom PII pattern.""" self.patterns.append(pattern) logger.info(f"Added PII pattern: {pattern.name}") - + def sanitize_string(self, text: str) -> str: """Apply all patterns to sanitize a string.""" if not self.sanitize_enabled or not isinstance(text, str): return text - + sanitized = text for pattern in self.patterns: if pattern.pattern.search(sanitized): sanitized = pattern.pattern.sub(pattern.replacement, sanitized) logger.debug(f"Applied {pattern.name} sanitization") - + return sanitized - - def sanitize_headers(self, headers: Dict[str, str]) -> Dict[str, str]: + + def sanitize_headers(self, headers: dict[str, str]) -> dict[str, str]: """Special handling for HTTP headers.""" if not self.sanitize_enabled: return headers - + sanitized_headers = {} - + for key, value in headers.items(): # Special case for Authorization headers to preserve auth type - if key.lower() == 'authorization' and ' ' in value: - auth_type = value.split(' ', 1)[0] - if auth_type in ('Bearer', 'Basic'): - sanitized_headers[key] = f'{auth_type} SANITIZED' + if key.lower() == "authorization" and " " in value: + auth_type = value.split(" ", 1)[0] + if auth_type in ("Bearer", "Basic"): + sanitized_headers[key] = f"{auth_type} SANITIZED" else: sanitized_headers[key] = self.sanitize_string(value) else: # Apply standard sanitization to all other headers sanitized_headers[key] = self.sanitize_string(value) - + return sanitized_headers - + def sanitize_value(self, value: Any) -> Any: """Recursively sanitize any value (string, dict, list, etc).""" if not self.sanitize_enabled: return value - + if isinstance(value, str): return self.sanitize_string(value) elif isinstance(value, dict): @@ -202,25 +195,25 @@ class 
PIISanitizer: else: # For other types (int, float, bool, None), return as-is return value - + def sanitize_url(self, url: str) -> str: """Sanitize sensitive data from URLs (query params, etc).""" if not self.sanitize_enabled: return url - + # First apply general string sanitization url = self.sanitize_string(url) - + # Parse and sanitize query parameters - if '?' in url: - base, query = url.split('?', 1) + if "?" in url: + base, query = url.split("?", 1) params = [] - - for param in query.split('&'): - if '=' in param: - key, value = param.split('=', 1) + + for param in query.split("&"): + if "=" in param: + key, value = param.split("=", 1) # Sanitize common sensitive parameter names - sensitive_params = {'key', 'token', 'api_key', 'secret', 'password'} + sensitive_params = {"key", "token", "api_key", "secret", "password"} if key.lower() in sensitive_params: params.append(f"{key}=SANITIZED") else: @@ -228,54 +221,53 @@ class PIISanitizer: params.append(f"{key}={self.sanitize_string(value)}") else: params.append(param) - + return f"{base}?{'&'.join(params)}" - + return url - - - def sanitize_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + + def sanitize_request(self, request_data: dict[str, Any]) -> dict[str, Any]: """Sanitize a complete request dictionary.""" sanitized = deepcopy(request_data) - + # Sanitize headers - if 'headers' in sanitized: - sanitized['headers'] = self.sanitize_headers(sanitized['headers']) - + if "headers" in sanitized: + sanitized["headers"] = self.sanitize_headers(sanitized["headers"]) + # Sanitize URL - if 'url' in sanitized: - sanitized['url'] = self.sanitize_url(sanitized['url']) - + if "url" in sanitized: + sanitized["url"] = self.sanitize_url(sanitized["url"]) + # Sanitize content - if 'content' in sanitized: - sanitized['content'] = self.sanitize_value(sanitized['content']) - + if "content" in sanitized: + sanitized["content"] = self.sanitize_value(sanitized["content"]) + return sanitized - - def sanitize_response(self, response_data: Dict[str, Any]) -> Dict[str, Any]: + + def sanitize_response(self, response_data: dict[str, Any]) -> dict[str, Any]: """Sanitize a complete response dictionary.""" sanitized = deepcopy(response_data) - + # Sanitize headers - if 'headers' in sanitized: - sanitized['headers'] = self.sanitize_headers(sanitized['headers']) - + if "headers" in sanitized: + sanitized["headers"] = self.sanitize_headers(sanitized["headers"]) + # Sanitize content - if 'content' in sanitized: + if "content" in sanitized: # Handle base64 encoded content specially - if isinstance(sanitized['content'], dict) and sanitized['content'].get('encoding') == 'base64': + if isinstance(sanitized["content"], dict) and sanitized["content"].get("encoding") == "base64": # Don't decode/re-encode the actual response body # but sanitize any metadata - if 'data' in sanitized['content']: + if "data" in sanitized["content"]: # Keep the data as-is but sanitize other fields - for key, value in sanitized['content'].items(): - if key != 'data': - sanitized['content'][key] = self.sanitize_value(value) + for key, value in sanitized["content"].items(): + if key != "data": + sanitized["content"][key] = self.sanitize_value(value) else: - sanitized['content'] = self.sanitize_value(sanitized['content']) - + sanitized["content"] = self.sanitize_value(sanitized["content"]) + return sanitized # Global instance for convenience -default_sanitizer = PIISanitizer() \ No newline at end of file +default_sanitizer = PIISanitizer() diff --git a/tests/sanitize_cassettes.py 
b/tests/sanitize_cassettes.py index 814b420..123cdbd 100755 --- a/tests/sanitize_cassettes.py +++ b/tests/sanitize_cassettes.py @@ -10,10 +10,10 @@ This script will: """ import json -import sys -from pathlib import Path import shutil +import sys from datetime import datetime +from pathlib import Path # Add tests directory to path to import our modules sys.path.insert(0, str(Path(__file__).parent)) @@ -24,54 +24,55 @@ from pii_sanitizer import PIISanitizer def sanitize_cassette(cassette_path: Path, backup: bool = True) -> bool: """Sanitize a single cassette file.""" print(f"\nšŸ” Processing: {cassette_path}") - + if not cassette_path.exists(): print(f"āŒ File not found: {cassette_path}") return False - + try: # Load cassette - with open(cassette_path, 'r') as f: + with open(cassette_path) as f: cassette_data = json.load(f) - + # Create backup if requested if backup: backup_path = cassette_path.with_suffix(f'.backup-{datetime.now().strftime("%Y%m%d-%H%M%S")}.json') shutil.copy2(cassette_path, backup_path) print(f"šŸ“¦ Backup created: {backup_path}") - + # Initialize sanitizer sanitizer = PIISanitizer() - + # Sanitize interactions - if 'interactions' in cassette_data: + if "interactions" in cassette_data: sanitized_interactions = [] - - for interaction in cassette_data['interactions']: + + for interaction in cassette_data["interactions"]: sanitized_interaction = {} - + # Sanitize request - if 'request' in interaction: - sanitized_interaction['request'] = sanitizer.sanitize_request(interaction['request']) - + if "request" in interaction: + sanitized_interaction["request"] = sanitizer.sanitize_request(interaction["request"]) + # Sanitize response - if 'response' in interaction: - sanitized_interaction['response'] = sanitizer.sanitize_response(interaction['response']) - + if "response" in interaction: + sanitized_interaction["response"] = sanitizer.sanitize_response(interaction["response"]) + sanitized_interactions.append(sanitized_interaction) - - cassette_data['interactions'] = sanitized_interactions - + + cassette_data["interactions"] = sanitized_interactions + # Save sanitized cassette - with open(cassette_path, 'w') as f: + with open(cassette_path, "w") as f: json.dump(cassette_data, f, indent=2, sort_keys=True) - + print(f"āœ… Sanitized: {cassette_path}") return True - + except Exception as e: print(f"āŒ Error processing {cassette_path}: {e}") import traceback + traceback.print_exc() return False @@ -79,31 +80,31 @@ def sanitize_cassette(cassette_path: Path, backup: bool = True) -> bool: def main(): """Sanitize all cassettes in the openai_cassettes directory.""" cassettes_dir = Path(__file__).parent / "openai_cassettes" - + if not cassettes_dir.exists(): print(f"āŒ Directory not found: {cassettes_dir}") sys.exit(1) - + # Find all JSON cassettes cassette_files = list(cassettes_dir.glob("*.json")) - + if not cassette_files: print(f"āŒ No cassette files found in {cassettes_dir}") sys.exit(1) - + print(f"šŸŽ¬ Found {len(cassette_files)} cassette(s) to sanitize") - + # Process each cassette success_count = 0 for cassette_path in cassette_files: if sanitize_cassette(cassette_path): success_count += 1 - + print(f"\n✨ Sanitization complete: {success_count}/{len(cassette_files)} cassettes processed successfully") - + if success_count < len(cassette_files): sys.exit(1) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/test_o3_pro_output_text_fix.py b/tests/test_o3_pro_output_text_fix.py index f1258eb..7c4bed8 100644 --- 
a/tests/test_o3_pro_output_text_fix.py +++ b/tests/test_o3_pro_output_text_fix.py @@ -18,11 +18,11 @@ from pathlib import Path import pytest from dotenv import load_dotenv -from tools.chat import ChatTool from providers import ModelProviderRegistry from providers.base import ProviderType from providers.openai_provider import OpenAIModelProvider from tests.http_transport_recorder import TransportFactory +from tools.chat import ChatTool # Load environment variables from .env file load_dotenv() @@ -32,54 +32,87 @@ cassette_dir = Path(__file__).parent / "openai_cassettes" cassette_dir.mkdir(exist_ok=True) +@pytest.fixture +def allow_all_models(monkeypatch): + """Allow all models by resetting the restriction service singleton.""" + # Import here to avoid circular imports + from utils.model_restrictions import _restriction_service + + # Store original state + original_service = _restriction_service + original_allowed_models = os.getenv("ALLOWED_MODELS") + original_openai_key = os.getenv("OPENAI_API_KEY") + + # Reset the singleton so it will re-read env vars inside this fixture + monkeypatch.setattr("utils.model_restrictions._restriction_service", None) + monkeypatch.setenv("ALLOWED_MODELS", "") # empty string = no restrictions + monkeypatch.setenv("OPENAI_API_KEY", "dummy-key-for-replay") # transport layer expects a key + + # Also clear the provider registry cache to ensure clean state + from providers.registry import ModelProviderRegistry + ModelProviderRegistry.clear_cache() + + yield + + # Clean up: reset singleton again so other tests don't see the unrestricted version + monkeypatch.setattr("utils.model_restrictions._restriction_service", None) + # Clear registry cache again for other tests + ModelProviderRegistry.clear_cache() + + @pytest.mark.no_mock_provider # Disable provider mocking for this test class TestO3ProOutputTextFix(unittest.IsolatedAsyncioTestCase): """Test o3-pro response parsing fix using respx for HTTP recording/replay.""" def setUp(self): """Set up the test by ensuring OpenAI provider is registered.""" + # Clear any cached providers to ensure clean state + ModelProviderRegistry.clear_cache() # Manually register the OpenAI provider to ensure it's available ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider) + @pytest.mark.usefixtures("allow_all_models") async def test_o3_pro_uses_output_text_field(self): """Test that o3-pro parsing uses the output_text convenience field via ChatTool.""" cassette_path = cassette_dir / "o3_pro_basic_math.json" - + # Skip if no API key available and cassette doesn't exist if not cassette_path.exists() and not os.getenv("OPENAI_API_KEY"): pytest.skip("Set real OPENAI_API_KEY to record cassettes") # Create transport (automatically selects record vs replay mode) transport = TransportFactory.create_transport(str(cassette_path)) - + # Get provider and inject custom transport provider = ModelProviderRegistry.get_provider_for_model("o3-pro") if not provider: self.fail("OpenAI provider not available for o3-pro model") - + # Inject transport for this test - original_transport = getattr(provider, '_test_transport', None) + original_transport = getattr(provider, "_test_transport", None) provider._test_transport = transport - + try: # Execute ChatTool test with custom transport result = await self._execute_chat_tool_test() - + # Verify the response works correctly self._verify_chat_tool_response(result) - + # Verify cassette was created/used if not cassette_path.exists(): self.fail(f"Cassette should exist at {cassette_path}") 
- - print(f"āœ… HTTP transport {'recorded' if isinstance(transport, type(transport).__bases__[0]) else 'replayed'} o3-pro interaction") - + + print( + f"āœ… HTTP transport {'recorded' if isinstance(transport, type(transport).__bases__[0]) else 'replayed'} o3-pro interaction" + ) + finally: # Restore original transport (if any) if original_transport: provider._test_transport = original_transport - elif hasattr(provider, '_test_transport'): - delattr(provider, '_test_transport') + elif hasattr(provider, "_test_transport"): + delattr(provider, "_test_transport") async def _execute_chat_tool_test(self): """Execute the ChatTool with o3-pro and return the result.""" diff --git a/tests/test_openai_provider.py b/tests/test_openai_provider.py index 3429be9..d077da5 100644 --- a/tests/test_openai_provider.py +++ b/tests/test_openai_provider.py @@ -230,10 +230,8 @@ class TestOpenAIProvider: mock_openai_class.return_value = mock_client mock_response = MagicMock() - mock_response.output = MagicMock() - mock_response.output.content = [MagicMock()] - mock_response.output.content[0].type = "output_text" - mock_response.output.content[0].text = "4" + # New o3-pro format: direct output_text field + mock_response.output_text = "4" mock_response.model = "o3-pro-2025-06-10" mock_response.id = "test-id" mock_response.created_at = 1234567890 diff --git a/tests/test_pii_sanitizer.py b/tests/test_pii_sanitizer.py index a72e059..46cfc9f 100644 --- a/tests/test_pii_sanitizer.py +++ b/tests/test_pii_sanitizer.py @@ -2,64 +2,59 @@ """Test cases for PII sanitizer.""" import unittest -from tests.pii_sanitizer import PIISanitizer, PIIPattern + +from tests.pii_sanitizer import PIIPattern, PIISanitizer class TestPIISanitizer(unittest.TestCase): """Test PII sanitization functionality.""" - + def setUp(self): """Set up test sanitizer.""" self.sanitizer = PIISanitizer() - + def test_api_key_sanitization(self): """Test various API key formats are sanitized.""" test_cases = [ # OpenAI keys ("sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12", "sk-proj-SANITIZED"), ("sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN", "sk-SANITIZED"), - # Anthropic keys ("sk-ant-abcd1234567890ABCD1234567890abcd1234567890ABCD12", "sk-ant-SANITIZED"), - # Google keys ("AIzaSyD-1234567890abcdefghijklmnopqrstuv", "AIza-SANITIZED"), - # GitHub tokens ("ghp_1234567890abcdefghijklmnopqrstuvwxyz", "gh_SANITIZED"), ("ghs_1234567890abcdefghijklmnopqrstuvwxyz", "gh_SANITIZED"), ] - + for original, expected in test_cases: with self.subTest(original=original): result = self.sanitizer.sanitize_string(original) self.assertEqual(result, expected) - + def test_personal_info_sanitization(self): """Test personal information is sanitized.""" test_cases = [ # Email addresses ("john.doe@example.com", "user@example.com"), ("test123@company.org", "user@example.com"), - # Phone numbers (all now use the same pattern) ("(555) 123-4567", "(XXX) XXX-XXXX"), ("555-123-4567", "(XXX) XXX-XXXX"), ("+1-555-123-4567", "(XXX) XXX-XXXX"), - # SSN ("123-45-6789", "XXX-XX-XXXX"), - # Credit card ("1234 5678 9012 3456", "XXXX-XXXX-XXXX-XXXX"), ("1234-5678-9012-3456", "XXXX-XXXX-XXXX-XXXX"), ] - + for original, expected in test_cases: with self.subTest(original=original): result = self.sanitizer.sanitize_string(original) self.assertEqual(result, expected) - + def test_header_sanitization(self): """Test HTTP header sanitization.""" headers = { @@ -67,84 +62,82 @@ class TestPIISanitizer(unittest.TestCase): "API-Key": "sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN", 
"Content-Type": "application/json", "User-Agent": "MyApp/1.0", - "Cookie": "session=abc123; user=john.doe@example.com" + "Cookie": "session=abc123; user=john.doe@example.com", } - + sanitized = self.sanitizer.sanitize_headers(headers) - + self.assertEqual(sanitized["Authorization"], "Bearer SANITIZED") self.assertEqual(sanitized["API-Key"], "sk-SANITIZED") self.assertEqual(sanitized["Content-Type"], "application/json") self.assertEqual(sanitized["User-Agent"], "MyApp/1.0") self.assertIn("user@example.com", sanitized["Cookie"]) - + def test_nested_structure_sanitization(self): """Test sanitization of nested data structures.""" data = { "user": { "email": "john.doe@example.com", - "api_key": "sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12" + "api_key": "sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12", }, "tokens": [ "ghp_1234567890abcdefghijklmnopqrstuvwxyz", - "Bearer sk-ant-abcd1234567890ABCD1234567890abcd1234567890ABCD12" + "Bearer sk-ant-abcd1234567890ABCD1234567890abcd1234567890ABCD12", ], - "metadata": { - "ip": "192.168.1.100", - "phone": "(555) 123-4567" - } + "metadata": {"ip": "192.168.1.100", "phone": "(555) 123-4567"}, } - + sanitized = self.sanitizer.sanitize_value(data) - + self.assertEqual(sanitized["user"]["email"], "user@example.com") self.assertEqual(sanitized["user"]["api_key"], "sk-proj-SANITIZED") self.assertEqual(sanitized["tokens"][0], "gh_SANITIZED") self.assertEqual(sanitized["tokens"][1], "Bearer sk-ant-SANITIZED") self.assertEqual(sanitized["metadata"]["ip"], "0.0.0.0") self.assertEqual(sanitized["metadata"]["phone"], "(XXX) XXX-XXXX") - + def test_url_sanitization(self): """Test URL parameter sanitization.""" urls = [ - ("https://api.example.com/v1/users?api_key=sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN", - "https://api.example.com/v1/users?api_key=SANITIZED"), - ("https://example.com/login?token=ghp_1234567890abcdefghijklmnopqrstuvwxyz&user=test", - "https://example.com/login?token=SANITIZED&user=test"), + ( + "https://api.example.com/v1/users?api_key=sk-1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN", + "https://api.example.com/v1/users?api_key=SANITIZED", + ), + ( + "https://example.com/login?token=ghp_1234567890abcdefghijklmnopqrstuvwxyz&user=test", + "https://example.com/login?token=SANITIZED&user=test", + ), ] - + for original, expected in urls: with self.subTest(url=original): result = self.sanitizer.sanitize_url(original) self.assertEqual(result, expected) - + def test_disable_sanitization(self): """Test that sanitization can be disabled.""" self.sanitizer.sanitize_enabled = False - + sensitive_data = "sk-proj-abcd1234567890ABCD1234567890abcd1234567890ABCD12" result = self.sanitizer.sanitize_string(sensitive_data) - + # Should return original when disabled self.assertEqual(result, sensitive_data) - + def test_custom_pattern(self): """Test adding custom PII patterns.""" # Add custom pattern for internal employee IDs custom_pattern = PIIPattern.create( - name="employee_id", - pattern=r'EMP\d{6}', - replacement="EMP-REDACTED", - description="Internal employee IDs" + name="employee_id", pattern=r"EMP\d{6}", replacement="EMP-REDACTED", description="Internal employee IDs" ) - + self.sanitizer.add_pattern(custom_pattern) - + text = "Employee EMP123456 has access to the system" result = self.sanitizer.sanitize_string(text) - + self.assertEqual(result, "Employee EMP-REDACTED has access to the system") if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main()