Fixes bug pointed out by @dsaluja (https://github.com/dsaluja)
Fixes other providers not fixed by https://github.com/BeehiveInnovations/zen-mcp-server/pull/66 New regression tests
This commit is contained in:
105
tests/test_gemini_token_usage.py
Normal file
105
tests/test_gemini_token_usage.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Tests for Gemini provider token usage extraction."""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import Mock
|
||||
|
||||
from providers.gemini import GeminiModelProvider
|
||||
|
||||
|
||||
class TestGeminiTokenUsage(unittest.TestCase):
|
||||
"""Test Gemini provider token usage handling."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.provider = GeminiModelProvider("test-key")
|
||||
|
||||
def test_extract_usage_with_valid_tokens(self):
|
||||
"""Test token extraction with valid token counts."""
|
||||
response = Mock()
|
||||
response.usage_metadata = Mock()
|
||||
response.usage_metadata.prompt_token_count = 100
|
||||
response.usage_metadata.candidates_token_count = 50
|
||||
|
||||
usage = self.provider._extract_usage(response)
|
||||
|
||||
self.assertEqual(usage["input_tokens"], 100)
|
||||
self.assertEqual(usage["output_tokens"], 50)
|
||||
self.assertEqual(usage["total_tokens"], 150)
|
||||
|
||||
def test_extract_usage_with_none_input_tokens(self):
|
||||
"""Test token extraction when input_tokens is None (regression test for bug)."""
|
||||
response = Mock()
|
||||
response.usage_metadata = Mock()
|
||||
response.usage_metadata.prompt_token_count = None # This was causing crashes
|
||||
response.usage_metadata.candidates_token_count = 50
|
||||
|
||||
usage = self.provider._extract_usage(response)
|
||||
|
||||
# Should not include input_tokens when None
|
||||
self.assertNotIn("input_tokens", usage)
|
||||
self.assertEqual(usage["output_tokens"], 50)
|
||||
# Should not calculate total_tokens when input is None
|
||||
self.assertNotIn("total_tokens", usage)
|
||||
|
||||
def test_extract_usage_with_none_output_tokens(self):
|
||||
"""Test token extraction when output_tokens is None (regression test for bug)."""
|
||||
response = Mock()
|
||||
response.usage_metadata = Mock()
|
||||
response.usage_metadata.prompt_token_count = 100
|
||||
response.usage_metadata.candidates_token_count = None # This was causing crashes
|
||||
|
||||
usage = self.provider._extract_usage(response)
|
||||
|
||||
self.assertEqual(usage["input_tokens"], 100)
|
||||
# Should not include output_tokens when None
|
||||
self.assertNotIn("output_tokens", usage)
|
||||
# Should not calculate total_tokens when output is None
|
||||
self.assertNotIn("total_tokens", usage)
|
||||
|
||||
def test_extract_usage_with_both_none_tokens(self):
|
||||
"""Test token extraction when both token counts are None."""
|
||||
response = Mock()
|
||||
response.usage_metadata = Mock()
|
||||
response.usage_metadata.prompt_token_count = None
|
||||
response.usage_metadata.candidates_token_count = None
|
||||
|
||||
usage = self.provider._extract_usage(response)
|
||||
|
||||
# Should return empty dict when all tokens are None
|
||||
self.assertEqual(usage, {})
|
||||
|
||||
def test_extract_usage_without_usage_metadata(self):
|
||||
"""Test token extraction when response has no usage_metadata."""
|
||||
response = Mock(spec=[])
|
||||
|
||||
usage = self.provider._extract_usage(response)
|
||||
|
||||
# Should return empty dict
|
||||
self.assertEqual(usage, {})
|
||||
|
||||
def test_extract_usage_with_zero_tokens(self):
|
||||
"""Test token extraction with zero token counts."""
|
||||
response = Mock()
|
||||
response.usage_metadata = Mock()
|
||||
response.usage_metadata.prompt_token_count = 0
|
||||
response.usage_metadata.candidates_token_count = 0
|
||||
|
||||
usage = self.provider._extract_usage(response)
|
||||
|
||||
self.assertEqual(usage["input_tokens"], 0)
|
||||
self.assertEqual(usage["output_tokens"], 0)
|
||||
self.assertEqual(usage["total_tokens"], 0)
|
||||
|
||||
def test_extract_usage_missing_attributes(self):
|
||||
"""Test token extraction when metadata lacks token count attributes."""
|
||||
response = Mock()
|
||||
response.usage_metadata = Mock(spec=[])
|
||||
|
||||
usage = self.provider._extract_usage(response)
|
||||
|
||||
# Should return empty dict when attributes are missing
|
||||
self.assertEqual(usage, {})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user