diff --git a/tests/test_o3_pro_output_text_fix.py b/tests/test_o3_pro_output_text_fix.py
index 687bc61..9d1aa5c 100644
--- a/tests/test_o3_pro_output_text_fix.py
+++ b/tests/test_o3_pro_output_text_fix.py
@@ -10,7 +10,6 @@ the OpenAI SDK to create real response objects that we can test.
 
 RECORDING: To record new responses, delete the cassette file and run with real API keys.
 """
-import os
 import unittest
 from pathlib import Path
 
@@ -20,7 +19,7 @@ from dotenv import load_dotenv
 from providers import ModelProviderRegistry
 from providers.base import ProviderType
 from providers.openai_provider import OpenAIModelProvider
-from tests.http_transport_recorder import TransportFactory
+from tests.transport_helpers import inject_transport
 from tools.chat import ChatTool
 
 # Load environment variables from .env file
@@ -31,29 +30,6 @@ cassette_dir = Path(__file__).parent / "openai_cassettes"
 cassette_dir.mkdir(exist_ok=True)
 
 
-@pytest.fixture
-def allow_all_models(monkeypatch):
-    """Allow all models by resetting the restriction service singleton."""
-    # Import here to avoid circular imports
-
-    # Reset the singleton so it will re-read env vars inside this fixture
-    monkeypatch.setattr("utils.model_restrictions._restriction_service", None)
-    monkeypatch.setenv("ALLOWED_MODELS", "")  # empty string = no restrictions
-    monkeypatch.setenv("OPENAI_API_KEY", "dummy-key-for-replay")  # transport layer expects a key
-
-    # Also clear the provider registry cache to ensure clean state
-    from providers.registry import ModelProviderRegistry
-
-    ModelProviderRegistry.clear_cache()
-
-    yield
-
-    # Clean up: reset singleton again so other tests don't see the unrestricted version
-    monkeypatch.setattr("utils.model_restrictions._restriction_service", None)
-    # Clear registry cache again for other tests
-    ModelProviderRegistry.clear_cache()
-
-
 @pytest.mark.asyncio
 class TestO3ProOutputTextFix:
     """Test o3-pro response parsing fix using respx for HTTP recording/replay."""
@@ -82,36 +58,19 @@ class TestO3ProOutputTextFix:
         ModelProviderRegistry._providers = {}
 
     @pytest.mark.no_mock_provider  # Disable provider mocking for this test
-    @pytest.mark.usefixtures("allow_all_models")
     async def test_o3_pro_uses_output_text_field(self, monkeypatch):
        """Test that o3-pro parsing uses the output_text convenience field via ChatTool."""
+        # Set API key inline - helper will handle provider registration
+        monkeypatch.setenv("OPENAI_API_KEY", "dummy-key-for-replay")
+
         cassette_path = cassette_dir / "o3_pro_basic_math.json"
 
-        # Skip if cassette doesn't exist (for test suite runs)
+        # Require a pre-recorded cassette for this test
         if not cassette_path.exists():
-            if os.getenv("OPENAI_API_KEY"):
-                print(f"Recording new cassette at {cassette_path}")
-            else:
-                pytest.skip("Cassette not found and no OPENAI_API_KEY to record new one")
+            pytest.skip("Cassette file required - record with real OPENAI_API_KEY")
 
-        # Create transport (automatically selects record vs replay mode)
-        transport = TransportFactory.create_transport(str(cassette_path))
-
-        # Monkey-patch OpenAICompatibleProvider's client property to always use our transport
-        from providers.openai_compatible import OpenAICompatibleProvider
-
-        original_client_property = OpenAICompatibleProvider.client
-
-        def patched_client_getter(self):
-            # If no client exists yet, create it with our transport
-            if self._client is None:
-                # Set the test transport before creating client
-                self._test_transport = transport
-            # Call original property getter
-            return original_client_property.fget(self)
-
-        # Replace the client property with our patched version
-        monkeypatch.setattr(OpenAICompatibleProvider, "client", property(patched_client_getter))
+        # Simplified transport injection - just one line!
+        inject_transport(monkeypatch, cassette_path)
 
         # Execute ChatTool test with custom transport
         result = await self._execute_chat_tool_test()
@@ -119,12 +78,8 @@ class TestO3ProOutputTextFix:
         # Verify the response works correctly
         self._verify_chat_tool_response(result)
 
-        # Verify cassette was created/used
-        assert cassette_path.exists(), f"Cassette should exist at {cassette_path}"
-
-        print(
-            f"✅ HTTP transport {'recorded' if isinstance(transport, type(transport).__bases__[0]) else 'replayed'} o3-pro interaction"
-        )
+        # Verify cassette exists
+        assert cassette_path.exists()
 
     async def _execute_chat_tool_test(self):
         """Execute the ChatTool with o3-pro and return the result."""
@@ -135,71 +90,25 @@
 
     def _verify_chat_tool_response(self, result):
         """Verify the ChatTool response contains expected data."""
-        # Verify we got a valid response
-        assert result is not None, "Should get response from ChatTool"
+        # Basic response validation
+        assert result is not None
+        assert isinstance(result, list)
+        assert len(result) > 0
+        assert result[0].type == "text"
 
-        # Parse the result content (ChatTool returns MCP TextContent format)
-        assert isinstance(result, list), "ChatTool should return list of content"
-        assert len(result) > 0, "Should have at least one content item"
-
-        # Get the text content (result is a list of TextContent objects)
-        content_item = result[0]
-        assert content_item.type == "text", "First item should be text content"
-
-        text_content = content_item.text
-        assert len(text_content) > 0, "Should have text content"
-
-        # Parse the JSON response to verify metadata
+        # Parse JSON response
         import json
 
-        response_data = json.loads(text_content)
+        response_data = json.loads(result[0].text)
 
-        # Verify response structure
-        assert "status" in response_data, "Response should have status field"
-        assert "content" in response_data, "Response should have content field"
-        assert "metadata" in response_data, "Response should have metadata field"
+        # Verify response structure and the expected answer
+        assert response_data["status"] in ["success", "continuation_available"]
+        assert "4" in response_data["content"]
 
-        # Check if this is an error response (which may happen if cassette doesn't exist)
-        if response_data["status"] == "error":
-            # Skip metadata verification for error responses
-            print(f"⚠️ Got error response: {response_data['content']}")
-            print("⚠️ Skipping model metadata verification for error case")
-            return
-
-        # The key verification: The response should contain "4" as the answer
-        # This is what proves o3-pro is working correctly with the output_text field
-        content = response_data["content"]
-        assert "4" in content, f"Response content should contain the answer '4', got: {content[:200]}..."
-
-        # CRITICAL: Verify that o3-pro was actually used (not just requested)
+        # Verify o3-pro was actually used
         metadata = response_data["metadata"]
-        assert "model_used" in metadata, "Metadata should contain model_used field"
-        # Note: model_used shows the alias "o3-pro" not the full model ID "o3-pro-2025-06-10"
-        assert metadata["model_used"] == "o3-pro", f"Should have used o3-pro, but got: {metadata.get('model_used')}"
-
-        # Verify provider information
-        assert "provider_used" in metadata, "Metadata should contain provider_used field"
-        assert (
-            metadata["provider_used"] == "openai"
-        ), f"Should have used openai provider, but got: {metadata.get('provider_used')}"
-
-        # Additional verification that the response parsing worked correctly
-        assert response_data["status"] in [
-            "success",
-            "continuation_available",
-        ], f"Unexpected status: {response_data['status']}"
-
-        # ADDITIONAL VERIFICATION: Check that the response actually came from o3-pro by verifying:
-        # 1. The response uses the /v1/responses endpoint (specific to o3 models)
-        # 2. The response contains "4" which proves output_text parsing worked
-        # 3. The metadata confirms openai provider was used
-        # Together these prove o3-pro was used and response parsing is correct
-
-        print(f"✅ o3-pro successfully returned: {content[:100]}...")
-        print(f"✅ Verified model used: {metadata['model_used']} (alias for o3-pro-2025-06-10)")
-        print(f"✅ Verified provider: {metadata['provider_used']}")
-        print("✅ Response parsing uses output_text field correctly")
-        print("✅ Cassette confirms /v1/responses endpoint was used (o3-specific)")
+        assert metadata["model_used"] == "o3-pro"
+        assert metadata["provider_used"] == "openai"
 
 
 if __name__ == "__main__":
diff --git a/tests/transport_helpers.py b/tests/transport_helpers.py
new file mode 100644
index 0000000..7a68f8e
--- /dev/null
+++ b/tests/transport_helpers.py
@@ -0,0 +1,47 @@
+"""Helper functions for HTTP transport injection in tests."""
+
+from tests.http_transport_recorder import TransportFactory
+
+
+def inject_transport(monkeypatch, cassette_path: str):
+    """Inject HTTP transport into OpenAICompatibleProvider for testing.
+
+    This helper simplifies the monkey-patching pattern used across tests
+    to inject custom HTTP transports for recording/replaying API calls.
+
+    Also ensures the OpenAI provider is properly registered for tests that need it.
+
+    Args:
+        monkeypatch: pytest monkeypatch fixture
+        cassette_path: Path to cassette file for recording/replay
+
+    Returns:
+        The created transport instance
+
+    Example:
+        transport = inject_transport(monkeypatch, "path/to/cassette.json")
+    """
+    # Ensure OpenAI provider is registered if API key is available
+    import os
+    if os.getenv("OPENAI_API_KEY"):
+        from providers.registry import ModelProviderRegistry
+        from providers.base import ProviderType
+        from providers.openai_provider import OpenAIModelProvider
+        ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+
+    # Create transport
+    transport = TransportFactory.create_transport(str(cassette_path))
+
+    # Inject transport using the established pattern
+    from providers.openai_compatible import OpenAICompatibleProvider
+
+    original_client_property = OpenAICompatibleProvider.client
+
+    def patched_client_getter(self):
+        if self._client is None:
+            self._test_transport = transport
+        return original_client_property.fget(self)
+
+    monkeypatch.setattr(OpenAICompatibleProvider, "client", property(patched_client_getter))
+
+    return transport
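
Usage note (not part of the patch): with tests/transport_helpers.py in place, a cassette-based test can swap the multi-line client monkey-patching removed above for a single inject_transport call. A minimal sketch of an adopting test, reusing the pytest markers this suite already employs; the test name and cassette filename are illustrative only:

    import pytest
    from pathlib import Path

    from tests.transport_helpers import inject_transport


    @pytest.mark.asyncio
    @pytest.mark.no_mock_provider  # same marker the o3-pro test uses to disable provider mocking
    async def test_replay_example(monkeypatch):
        # A dummy key satisfies the transport layer during replay and lets the
        # helper register the OpenAI provider
        monkeypatch.setenv("OPENAI_API_KEY", "dummy-key-for-replay")

        cassette = Path(__file__).parent / "openai_cassettes" / "example.json"  # hypothetical cassette
        if not cassette.exists():
            pytest.skip("Cassette file required - record with real OPENAI_API_KEY")

        # One call wires the record/replay transport into OpenAICompatibleProvider
        inject_transport(monkeypatch, cassette)

        # ... invoke the tool under test and assert on the replayed response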