Fixes bug pointed out by @dsaluja (https://github.com/dsaluja)
Fixes the other providers not covered by https://github.com/BeehiveInnovations/zen-mcp-server/pull/66. Adds new regression tests for token usage extraction.
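For context, a minimal sketch of the failure mode this commit guards against (illustrative values only, not code from the diff): a provider can report a token count of None, and adding that to an int raises TypeError.

# Illustrative sketch only; the values below are assumptions for demonstration.
input_tokens = None  # e.g. the SDK reported no prompt token count
output_tokens = 50

# Old-style arithmetic would raise:
#   TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'
# The fix checks for None (Gemini path) or coerces to 0 (OpenAI-compatible path):
if input_tokens is not None and output_tokens is not None:
    total_tokens = input_tokens + output_tokens
else:
    total_tokens = None  # leave the total unset rather than crash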
@@ -14,9 +14,9 @@ import os
 # These values are used in server responses and for tracking releases
 # IMPORTANT: This is the single source of truth for version and author info
 # Semantic versioning: MAJOR.MINOR.PATCH
-__version__ = "4.9.1"
+__version__ = "4.9.2"
 # Last update date in ISO format
-__updated__ = "2025-06-16"
+__updated__ = "2025-06-17"
 # Primary maintainer
 __author__ = "Fahad Gilani"
@@ -303,12 +303,26 @@ class GeminiModelProvider(ModelProvider):
         # Note: The actual structure depends on the SDK version and response format
         if hasattr(response, "usage_metadata"):
             metadata = response.usage_metadata

+            # Extract token counts with explicit None checks
+            input_tokens = None
+            output_tokens = None
+
             if hasattr(metadata, "prompt_token_count"):
-                usage["input_tokens"] = metadata.prompt_token_count
+                value = metadata.prompt_token_count
+                if value is not None:
+                    input_tokens = value
+                    usage["input_tokens"] = value
+
             if hasattr(metadata, "candidates_token_count"):
-                usage["output_tokens"] = metadata.candidates_token_count
-            if "input_tokens" in usage and "output_tokens" in usage:
-                usage["total_tokens"] = usage["input_tokens"] + usage["output_tokens"]
+                value = metadata.candidates_token_count
+                if value is not None:
+                    output_tokens = value
+                    usage["output_tokens"] = value
+
+            # Calculate total only if both values are available and valid
+            if input_tokens is not None and output_tokens is not None:
+                usage["total_tokens"] = input_tokens + output_tokens

         return usage
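A minimal sketch of how the reworked extraction behaves after this hunk; it mirrors the regression tests added below, and the mocked response shape is an assumption for illustration:

# Sketch: exercising the None-safe Gemini path (imports as used by the new tests)
from unittest.mock import Mock

from providers.gemini import GeminiModelProvider

provider = GeminiModelProvider("test-key")
response = Mock()
response.usage_metadata = Mock()
response.usage_metadata.prompt_token_count = None  # previously triggered a TypeError
response.usage_metadata.candidates_token_count = 50

usage = provider._extract_usage(response)
assert "input_tokens" not in usage and "total_tokens" not in usage  # None values are skipped
assert usage["output_tokens"] == 50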
@@ -300,10 +300,13 @@ class OpenAICompatibleProvider(ModelProvider):
         if hasattr(response, "usage"):
             usage = self._extract_usage(response)
         elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"):
+            # Safely extract token counts with None handling
+            input_tokens = getattr(response, "input_tokens", 0) or 0
+            output_tokens = getattr(response, "output_tokens", 0) or 0
             usage = {
-                "input_tokens": getattr(response, "input_tokens", 0),
-                "output_tokens": getattr(response, "output_tokens", 0),
-                "total_tokens": getattr(response, "input_tokens", 0) + getattr(response, "output_tokens", 0),
+                "input_tokens": input_tokens,
+                "output_tokens": output_tokens,
+                "total_tokens": input_tokens + output_tokens,
             }

         return ModelResponse(
@@ -607,9 +610,10 @@ class OpenAICompatibleProvider(ModelProvider):
         usage = {}

         if hasattr(response, "usage") and response.usage:
-            usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0)
-            usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0)
-            usage["total_tokens"] = getattr(response.usage, "total_tokens", 0)
+            # Safely extract token counts with None handling
+            usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0) or 0
+            usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0) or 0
+            usage["total_tokens"] = getattr(response.usage, "total_tokens", 0) or 0

         return usage
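The OpenAI-compatible paths take a different tack: rather than omitting keys, they coerce None (or a missing attribute) to 0 with the `getattr(..., 0) or 0` idiom. A minimal sketch of that idiom, mirroring the alternative-format regression test below:

# Sketch of the coercion idiom (SimpleNamespace stands in for a real response object)
from types import SimpleNamespace

response = SimpleNamespace(input_tokens=None, output_tokens=50)

input_tokens = getattr(response, "input_tokens", 0) or 0    # None -> 0
output_tokens = getattr(response, "output_tokens", 0) or 0  # 50 stays 50
total_tokens = input_tokens + output_tokens                 # 50, no TypeError

The trade-off is that a None count is reported as 0, so total_tokens may understate real usage, whereas the Gemini path above simply leaves the affected keys out.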
@@ -1,138 +0,0 @@
#!/usr/bin/env python3
"""
Test script for the enhanced consensus tool with ModelConfig objects
"""

import asyncio
import json
import sys

from tools.consensus import ConsensusTool


async def test_enhanced_consensus():
    """Test the enhanced consensus tool with custom stance prompts"""

    print("🧪 Testing Enhanced Consensus Tool")
    print("=" * 50)

    # Test all stance synonyms work
    print("📝 Testing stance synonym normalization...")
    tool = ConsensusTool()

    test_synonyms = [
        ("support", "for"),
        ("favor", "for"),
        ("oppose", "against"),
        ("critical", "against"),
        ("neutral", "neutral"),
        ("for", "for"),
        ("against", "against"),
        # Test unknown stances default to neutral
        ("maybe", "neutral"),
        ("supportive", "neutral"),
        ("random", "neutral"),
    ]

    for input_stance, expected in test_synonyms:
        normalized = tool._normalize_stance(input_stance)
        status = "✅" if normalized == expected else "❌"
        print(f"{status} '{input_stance}' → '{normalized}' (expected: '{expected}')")

    print()

    # Create consensus tool instance
    tool = ConsensusTool()

    # Test arguments with new ModelConfig format
    test_arguments = {
        "prompt": "Should we add a pizza ordering button to our enterprise software?",
        "models": [
            {
                "model": "flash",
                "stance": "support",  # Test synonym
                "stance_prompt": "You are a user experience advocate. Focus on how this feature could improve user engagement and satisfaction. Consider the human elements - how might this bring joy to users' workday? Think about unexpected benefits and creative use cases.",
            },
            {
                "model": "flash",
                "stance": "oppose",  # Test synonym
                "stance_prompt": "You are a software architecture specialist. Focus on technical concerns: code maintainability, security implications, scope creep, and system complexity. Consider long-term costs and potential maintenance burden.",
            },
        ],
        "focus_areas": ["user experience", "technical complexity", "business value"],
        "temperature": 0.3,
    }

    try:
        print("📝 Test Arguments:")
        print(json.dumps(test_arguments, indent=2))
        print()

        print("🚀 Executing consensus tool...")

        # Execute the tool
        result = await tool.execute(test_arguments)

        print("✅ Consensus tool execution completed!")
        print()

        # Parse and display results
        if result and len(result) > 0:
            response_text = result[0].text
            try:
                response_data = json.loads(response_text)
                print("📊 Consensus Results:")
                print(f"Status: {response_data.get('status', 'unknown')}")

                if response_data.get("status") == "consensus_success":
                    models_used = response_data.get("models_used", [])
                    print(f"Models used: {', '.join(models_used)}")

                    responses = response_data.get("responses", [])
                    print(f"\n🎭 Individual Model Responses ({len(responses)} total):")

                    for i, resp in enumerate(responses, 1):
                        model = resp.get("model", "unknown")
                        stance = resp.get("stance", "neutral")
                        status = resp.get("status", "unknown")

                        print(f"\n{i}. {model.upper()} ({stance} stance) - {status}")

                        if status == "success":
                            verdict = resp.get("verdict", "No verdict")
                            custom_prompt = resp.get("metadata", {}).get("custom_stance_prompt", False)
                            print(f" Custom prompt used: {'Yes' if custom_prompt else 'No'}")
                            print(f" Verdict preview: {verdict[:200]}...")

                            # Show stance normalization worked
                            if stance in ["support", "oppose"]:
                                expected = "for" if stance == "support" else "against"
                                print(f" ✅ Stance '{stance}' normalized correctly")
                        else:
                            error = resp.get("error", "Unknown error")
                            print(f" Error: {error}")

                else:
                    print(f"❌ Consensus failed: {response_data.get('error', 'Unknown error')}")

            except json.JSONDecodeError:
                print("📄 Raw response (not JSON):")
                print(response_text[:500] + "..." if len(response_text) > 500 else response_text)
        else:
            print("❌ No response received from consensus tool")

    except Exception as e:
        print(f"❌ Test failed with exception: {str(e)}")
        import traceback

        traceback.print_exc()
        return False

    print("\n🎉 Enhanced consensus tool test completed!")
    return True


if __name__ == "__main__":
    # Run the test
    success = asyncio.run(test_enhanced_consensus())
    sys.exit(0 if success else 1)
@@ -1,72 +0,0 @@
#!/usr/bin/env python3
"""
Test script to verify line number accuracy in the MCP server
"""

import asyncio
import json

from tools.analyze import AnalyzeTool
from tools.chat import ChatTool


async def test_line_number_reporting():
    """Test if tools report accurate line numbers when analyzing code"""

    print("=== Testing Line Number Accuracy ===\n")

    # Test 1: Analyze tool with line numbers
    analyze_tool = AnalyzeTool()

    # Create a request that asks about specific line numbers
    analyze_request = {
        "files": ["/Users/fahad/Developer/gemini-mcp-server/test_line_numbers.py"],
        "prompt": "Find all the lines where 'ignore_patterns' is assigned a list value. Report the exact line numbers.",
        "model": "flash",  # Use a real model
    }

    print("1. Testing Analyze tool:")
    print(f" Prompt: {analyze_request['prompt']}")

    try:
        response = await analyze_tool.execute(analyze_request)
        result = json.loads(response[0].text)

        if result["status"] == "success":
            print(f" Response excerpt: {result['content'][:200]}...")
        else:
            print(f" Error: {result}")
    except Exception as e:
        print(f" Exception: {e}")

    print("\n" + "=" * 50 + "\n")

    # Test 2: Chat tool to simulate the user's scenario
    chat_tool = ChatTool()

    chat_request = {
        "files": ["/Users/fahad/Developer/loganalyzer/main.py"],
        "prompt": "Tell me the exact line number where 'ignore_patterns' is assigned a list in the file. Be precise about the line number.",
        "model": "flash",
    }

    print("2. Testing Chat tool with user's actual file:")
    print(f" File: {chat_request['files'][0]}")
    print(f" Prompt: {chat_request['prompt']}")

    try:
        response = await chat_tool.execute(chat_request)
        result = json.loads(response[0].text)

        if result["status"] == "success":
            print(f" Response excerpt: {result['content'][:300]}...")
        else:
            print(f" Error: {result}")
    except Exception as e:
        print(f" Exception: {e}")

    print("\n=== Test Complete ===")


if __name__ == "__main__":
    asyncio.run(test_line_number_reporting())
@@ -1,36 +0,0 @@
#!/usr/bin/env python3
"""Test file to verify line number accuracy"""


# Line 4: Empty line above
def example_function():
    """Line 6: Docstring"""
    # Line 7: Comment
    pass  # Line 8


# Line 10: Another comment
class TestClass:
    """Line 12: Class docstring"""

    def __init__(self):
        """Line 15: Init docstring"""
        # Line 16: This is where we'll test
        self.test_variable = "Line 17"

    def method_one(self):
        """Line 20: Method docstring"""
        # Line 21: Important assignment below
        ignore_patterns = ["pattern1", "pattern2", "pattern3"]  # Line 22: This is our test line
        return ignore_patterns


# Line 25: More code below
def another_function():
    """Line 27: Another docstring"""
    # Line 28: Another assignment
    ignore_patterns = ["different", "patterns"]  # Line 29: Second occurrence
    return ignore_patterns


# Line 32: End of file marker
tests/test_gemini_token_usage.py (new file, 105 lines)
@@ -0,0 +1,105 @@
"""Tests for Gemini provider token usage extraction."""

import unittest
from unittest.mock import Mock

from providers.gemini import GeminiModelProvider


class TestGeminiTokenUsage(unittest.TestCase):
    """Test Gemini provider token usage handling."""

    def setUp(self):
        """Set up test fixtures."""
        self.provider = GeminiModelProvider("test-key")

    def test_extract_usage_with_valid_tokens(self):
        """Test token extraction with valid token counts."""
        response = Mock()
        response.usage_metadata = Mock()
        response.usage_metadata.prompt_token_count = 100
        response.usage_metadata.candidates_token_count = 50

        usage = self.provider._extract_usage(response)

        self.assertEqual(usage["input_tokens"], 100)
        self.assertEqual(usage["output_tokens"], 50)
        self.assertEqual(usage["total_tokens"], 150)

    def test_extract_usage_with_none_input_tokens(self):
        """Test token extraction when input_tokens is None (regression test for bug)."""
        response = Mock()
        response.usage_metadata = Mock()
        response.usage_metadata.prompt_token_count = None  # This was causing crashes
        response.usage_metadata.candidates_token_count = 50

        usage = self.provider._extract_usage(response)

        # Should not include input_tokens when None
        self.assertNotIn("input_tokens", usage)
        self.assertEqual(usage["output_tokens"], 50)
        # Should not calculate total_tokens when input is None
        self.assertNotIn("total_tokens", usage)

    def test_extract_usage_with_none_output_tokens(self):
        """Test token extraction when output_tokens is None (regression test for bug)."""
        response = Mock()
        response.usage_metadata = Mock()
        response.usage_metadata.prompt_token_count = 100
        response.usage_metadata.candidates_token_count = None  # This was causing crashes

        usage = self.provider._extract_usage(response)

        self.assertEqual(usage["input_tokens"], 100)
        # Should not include output_tokens when None
        self.assertNotIn("output_tokens", usage)
        # Should not calculate total_tokens when output is None
        self.assertNotIn("total_tokens", usage)

    def test_extract_usage_with_both_none_tokens(self):
        """Test token extraction when both token counts are None."""
        response = Mock()
        response.usage_metadata = Mock()
        response.usage_metadata.prompt_token_count = None
        response.usage_metadata.candidates_token_count = None

        usage = self.provider._extract_usage(response)

        # Should return empty dict when all tokens are None
        self.assertEqual(usage, {})

    def test_extract_usage_without_usage_metadata(self):
        """Test token extraction when response has no usage_metadata."""
        response = Mock(spec=[])

        usage = self.provider._extract_usage(response)

        # Should return empty dict
        self.assertEqual(usage, {})

    def test_extract_usage_with_zero_tokens(self):
        """Test token extraction with zero token counts."""
        response = Mock()
        response.usage_metadata = Mock()
        response.usage_metadata.prompt_token_count = 0
        response.usage_metadata.candidates_token_count = 0

        usage = self.provider._extract_usage(response)

        self.assertEqual(usage["input_tokens"], 0)
        self.assertEqual(usage["output_tokens"], 0)
        self.assertEqual(usage["total_tokens"], 0)

    def test_extract_usage_missing_attributes(self):
        """Test token extraction when metadata lacks token count attributes."""
        response = Mock()
        response.usage_metadata = Mock(spec=[])

        usage = self.provider._extract_usage(response)

        # Should return empty dict when attributes are missing
        self.assertEqual(usage, {})


if __name__ == "__main__":
    unittest.main()
tests/test_openai_compatible_token_usage.py (new file, 135 lines)
@@ -0,0 +1,135 @@
"""Tests for OpenAI-compatible provider token usage extraction."""

import unittest
from unittest.mock import Mock

from providers.openai_compatible import OpenAICompatibleProvider


class TestOpenAICompatibleTokenUsage(unittest.TestCase):
    """Test OpenAI-compatible provider token usage handling."""

    def setUp(self):
        """Set up test fixtures."""

        # Create a concrete implementation for testing
        class TestProvider(OpenAICompatibleProvider):
            FRIENDLY_NAME = "Test"
            SUPPORTED_MODELS = {"test-model": {"context_window": 4096}}

            def get_capabilities(self, model_name):
                return Mock()

            def get_provider_type(self):
                return Mock()

            def validate_model_name(self, model_name):
                return True

        self.provider = TestProvider("test-key")

    def test_extract_usage_with_valid_tokens(self):
        """Test token extraction with valid token counts."""
        response = Mock()
        response.usage = Mock()
        response.usage.prompt_tokens = 100
        response.usage.completion_tokens = 50
        response.usage.total_tokens = 150

        usage = self.provider._extract_usage(response)

        self.assertEqual(usage["input_tokens"], 100)
        self.assertEqual(usage["output_tokens"], 50)
        self.assertEqual(usage["total_tokens"], 150)

    def test_extract_usage_with_none_prompt_tokens(self):
        """Test token extraction when prompt_tokens is None (regression test for bug)."""
        response = Mock()
        response.usage = Mock()
        response.usage.prompt_tokens = None  # This was causing crashes
        response.usage.completion_tokens = 50
        response.usage.total_tokens = None

        usage = self.provider._extract_usage(response)

        # Should default to 0 when None
        self.assertEqual(usage["input_tokens"], 0)
        self.assertEqual(usage["output_tokens"], 50)
        self.assertEqual(usage["total_tokens"], 0)

    def test_extract_usage_with_none_completion_tokens(self):
        """Test token extraction when completion_tokens is None (regression test for bug)."""
        response = Mock()
        response.usage = Mock()
        response.usage.prompt_tokens = 100
        response.usage.completion_tokens = None  # This was causing crashes
        response.usage.total_tokens = None

        usage = self.provider._extract_usage(response)

        self.assertEqual(usage["input_tokens"], 100)
        # Should default to 0 when None
        self.assertEqual(usage["output_tokens"], 0)
        self.assertEqual(usage["total_tokens"], 0)

    def test_extract_usage_with_all_none_tokens(self):
        """Test token extraction when all token counts are None."""
        response = Mock()
        response.usage = Mock()
        response.usage.prompt_tokens = None
        response.usage.completion_tokens = None
        response.usage.total_tokens = None

        usage = self.provider._extract_usage(response)

        # Should default to 0 for all when None
        self.assertEqual(usage["input_tokens"], 0)
        self.assertEqual(usage["output_tokens"], 0)
        self.assertEqual(usage["total_tokens"], 0)

    def test_extract_usage_without_usage(self):
        """Test token extraction when response has no usage."""
        response = Mock(spec=[])  # No usage attribute

        usage = self.provider._extract_usage(response)

        # Should return empty dict
        self.assertEqual(usage, {})

    def test_extract_usage_with_zero_tokens(self):
        """Test token extraction with zero token counts."""
        response = Mock()
        response.usage = Mock()
        response.usage.prompt_tokens = 0
        response.usage.completion_tokens = 0
        response.usage.total_tokens = 0

        usage = self.provider._extract_usage(response)

        self.assertEqual(usage["input_tokens"], 0)
        self.assertEqual(usage["output_tokens"], 0)
        self.assertEqual(usage["total_tokens"], 0)

    def test_alternative_token_format_with_none(self):
        """Test alternative token format (input_tokens/output_tokens) with None values."""
        # This tests the other code path in generate_content_openai_responses
        # Simulate a response with input_tokens/output_tokens attributes that could be None
        response = Mock()
        response.input_tokens = None  # This was causing crashes
        response.output_tokens = 50

        # Test the pattern: getattr(response, "input_tokens", 0) or 0
        input_tokens = getattr(response, "input_tokens", 0) or 0
        output_tokens = getattr(response, "output_tokens", 0) or 0

        # Should not crash and should handle None gracefully
        self.assertEqual(input_tokens, 0)
        self.assertEqual(output_tokens, 50)

        # Test that addition works
        total = input_tokens + output_tokens
        self.assertEqual(total, 50)


if __name__ == "__main__":
    unittest.main()
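Both new test modules are plain unittest cases. A minimal way to run just these regression tests programmatically (assuming the tests/ package layout shown above; pytest would also discover them if that is the project's runner):

# Run only the new regression tests via the standard library test runner.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromNames(
    [
        "tests.test_gemini_token_usage",
        "tests.test_openai_compatible_token_usage",
    ]
)
unittest.TextTestRunner(verbosity=2).run(suite)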