Fixes the other providers that were not covered by https://github.com/BeehiveInnovations/zen-mcp-server/pull/66, and adds new regression tests.
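For context, the behaviour these regression tests pin down is a None-safe `_extract_usage`: token counts that come back as `None` are simply omitted from the usage dict, and `total_tokens` is only computed when both counts are present. The sketch below is a minimal reconstruction inferred from the test expectations, not the provider's shipped code; the real method is `GeminiModelProvider._extract_usage` in `providers/gemini.py`.

```python
# Minimal sketch of the guard these tests exercise, inferred from the test
# expectations below; not the actual implementation in providers/gemini.py.
def extract_usage(response) -> dict:
    usage: dict = {}

    metadata = getattr(response, "usage_metadata", None)
    if metadata is None:
        # No usage metadata on the response: return an empty dict.
        return usage

    input_tokens = getattr(metadata, "prompt_token_count", None)
    output_tokens = getattr(metadata, "candidates_token_count", None)

    # Record only counts that are actually present. A None slipping into
    # arithmetic (None + int) is the crash the regression tests guard against.
    if input_tokens is not None:
        usage["input_tokens"] = input_tokens
    if output_tokens is not None:
        usage["output_tokens"] = output_tokens

    # total_tokens is only computed when both counts are known.
    if input_tokens is not None and output_tokens is not None:
        usage["total_tokens"] = input_tokens + output_tokens

    return usage
```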
"""Tests for Gemini provider token usage extraction."""
|
|
|
|
import unittest
|
|
from unittest.mock import Mock
|
|
|
|
from providers.gemini import GeminiModelProvider
|
|
|
|
|
|
class TestGeminiTokenUsage(unittest.TestCase):
|
|
"""Test Gemini provider token usage handling."""
|
|
|
|
def setUp(self):
|
|
"""Set up test fixtures."""
|
|
self.provider = GeminiModelProvider("test-key")
|
|
|
|
def test_extract_usage_with_valid_tokens(self):
|
|
"""Test token extraction with valid token counts."""
|
|
response = Mock()
|
|
response.usage_metadata = Mock()
|
|
response.usage_metadata.prompt_token_count = 100
|
|
response.usage_metadata.candidates_token_count = 50
|
|
|
|
usage = self.provider._extract_usage(response)
|
|
|
|
self.assertEqual(usage["input_tokens"], 100)
|
|
self.assertEqual(usage["output_tokens"], 50)
|
|
self.assertEqual(usage["total_tokens"], 150)
|
|
|
|
def test_extract_usage_with_none_input_tokens(self):
|
|
"""Test token extraction when input_tokens is None (regression test for bug)."""
|
|
response = Mock()
|
|
response.usage_metadata = Mock()
|
|
response.usage_metadata.prompt_token_count = None # This was causing crashes
|
|
response.usage_metadata.candidates_token_count = 50
|
|
|
|
usage = self.provider._extract_usage(response)
|
|
|
|
# Should not include input_tokens when None
|
|
self.assertNotIn("input_tokens", usage)
|
|
self.assertEqual(usage["output_tokens"], 50)
|
|
# Should not calculate total_tokens when input is None
|
|
self.assertNotIn("total_tokens", usage)
|
|
|
|
def test_extract_usage_with_none_output_tokens(self):
|
|
"""Test token extraction when output_tokens is None (regression test for bug)."""
|
|
response = Mock()
|
|
response.usage_metadata = Mock()
|
|
response.usage_metadata.prompt_token_count = 100
|
|
response.usage_metadata.candidates_token_count = None # This was causing crashes
|
|
|
|
usage = self.provider._extract_usage(response)
|
|
|
|
self.assertEqual(usage["input_tokens"], 100)
|
|
# Should not include output_tokens when None
|
|
self.assertNotIn("output_tokens", usage)
|
|
# Should not calculate total_tokens when output is None
|
|
self.assertNotIn("total_tokens", usage)
|
|
|
|
def test_extract_usage_with_both_none_tokens(self):
|
|
"""Test token extraction when both token counts are None."""
|
|
response = Mock()
|
|
response.usage_metadata = Mock()
|
|
response.usage_metadata.prompt_token_count = None
|
|
response.usage_metadata.candidates_token_count = None
|
|
|
|
usage = self.provider._extract_usage(response)
|
|
|
|
# Should return empty dict when all tokens are None
|
|
self.assertEqual(usage, {})
|
|
|
|
def test_extract_usage_without_usage_metadata(self):
|
|
"""Test token extraction when response has no usage_metadata."""
|
|
response = Mock(spec=[])
|
|
|
|
usage = self.provider._extract_usage(response)
|
|
|
|
# Should return empty dict
|
|
self.assertEqual(usage, {})
|
|
|
|
def test_extract_usage_with_zero_tokens(self):
|
|
"""Test token extraction with zero token counts."""
|
|
response = Mock()
|
|
response.usage_metadata = Mock()
|
|
response.usage_metadata.prompt_token_count = 0
|
|
response.usage_metadata.candidates_token_count = 0
|
|
|
|
usage = self.provider._extract_usage(response)
|
|
|
|
self.assertEqual(usage["input_tokens"], 0)
|
|
self.assertEqual(usage["output_tokens"], 0)
|
|
self.assertEqual(usage["total_tokens"], 0)
|
|
|
|
def test_extract_usage_missing_attributes(self):
|
|
"""Test token extraction when metadata lacks token count attributes."""
|
|
response = Mock()
|
|
response.usage_metadata = Mock(spec=[])
|
|
|
|
usage = self.provider._extract_usage(response)
|
|
|
|
# Should return empty dict when attributes are missing
|
|
self.assertEqual(usage, {})
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|
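Usage note: with the `unittest.main()` guard in place, the file can be run directly (`python <path-to-this-file>`) or via the standard runner, e.g. `python -m unittest discover tests` (the `tests/` directory name is an assumption; adjust to wherever this file lives in the repo).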