Fixes the bug pointed out by @dsaluja (https://github.com/dsaluja): token counts in provider usage metadata can be None, which crashed the total-token calculation.

Applies the same fix to the other providers not covered by https://github.com/BeehiveInnovations/zen-mcp-server/pull/66.

Adds new regression tests.
Author: Fahad
Date:   2025-06-17 11:29:45 +04:00
parent  be7d80d7aa
commit  77da7b17e6
8 changed files with 270 additions and 258 deletions
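For context, the crash this commit addresses: the old Gemini extraction stored token counts from usage_metadata without checking for None, so summing input and output tokens could raise a TypeError. A minimal reproduction of the pre-fix behavior (FakeMetadata and FakeResponse are hypothetical stand-ins for an SDK response, not names from this repo):

# Minimal reproduction of the pre-fix crash (illustrative only).
class FakeMetadata:
    prompt_token_count = None        # field exists, but the value is None
    candidates_token_count = 100

class FakeResponse:
    usage_metadata = FakeMetadata()

usage = {}
metadata = FakeResponse().usage_metadata
if hasattr(metadata, "prompt_token_count"):
    usage["input_tokens"] = metadata.prompt_token_count       # stores None
if hasattr(metadata, "candidates_token_count"):
    usage["output_tokens"] = metadata.candidates_token_count
try:
    if "input_tokens" in usage and "output_tokens" in usage:
        usage["total_tokens"] = usage["input_tokens"] + usage["output_tokens"]
except TypeError as exc:
    # TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'
    print(f"crash reproduced: {exc}")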


@@ -303,12 +303,26 @@ class GeminiModelProvider(ModelProvider):
         # Note: The actual structure depends on the SDK version and response format
         if hasattr(response, "usage_metadata"):
             metadata = response.usage_metadata
+            # Extract token counts with explicit None checks
+            input_tokens = None
+            output_tokens = None
             if hasattr(metadata, "prompt_token_count"):
-                usage["input_tokens"] = metadata.prompt_token_count
+                value = metadata.prompt_token_count
+                if value is not None:
+                    input_tokens = value
+                    usage["input_tokens"] = value
             if hasattr(metadata, "candidates_token_count"):
-                usage["output_tokens"] = metadata.candidates_token_count
-            if "input_tokens" in usage and "output_tokens" in usage:
-                usage["total_tokens"] = usage["input_tokens"] + usage["output_tokens"]
+                value = metadata.candidates_token_count
+                if value is not None:
+                    output_tokens = value
+                    usage["output_tokens"] = value
+            # Calculate total only if both values are available and valid
+            if input_tokens is not None and output_tokens is not None:
+                usage["total_tokens"] = input_tokens + output_tokens

         return usage
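A regression-test sketch for the behavior above (assumes pytest; the provider method's name is not visible in this diff, so the patched logic is inlined in a local helper rather than calling a guessed API):

from types import SimpleNamespace

def _extract_gemini_usage(response):
    # Inlines the patched logic from the hunk above
    usage = {}
    if hasattr(response, "usage_metadata"):
        metadata = response.usage_metadata
        input_tokens = getattr(metadata, "prompt_token_count", None)
        output_tokens = getattr(metadata, "candidates_token_count", None)
        if input_tokens is not None:
            usage["input_tokens"] = input_tokens
        if output_tokens is not None:
            usage["output_tokens"] = output_tokens
        if input_tokens is not None and output_tokens is not None:
            usage["total_tokens"] = input_tokens + output_tokens
    return usage

def test_none_prompt_token_count_does_not_crash():
    metadata = SimpleNamespace(prompt_token_count=None, candidates_token_count=100)
    usage = _extract_gemini_usage(SimpleNamespace(usage_metadata=metadata))
    assert "input_tokens" not in usage       # None is skipped, not stored
    assert usage["output_tokens"] == 100
    assert "total_tokens" not in usage       # total requires both counts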


@@ -300,10 +300,13 @@ class OpenAICompatibleProvider(ModelProvider):
         if hasattr(response, "usage"):
             usage = self._extract_usage(response)
         elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"):
+            # Safely extract token counts with None handling
+            input_tokens = getattr(response, "input_tokens", 0) or 0
+            output_tokens = getattr(response, "output_tokens", 0) or 0
             usage = {
-                "input_tokens": getattr(response, "input_tokens", 0),
-                "output_tokens": getattr(response, "output_tokens", 0),
-                "total_tokens": getattr(response, "input_tokens", 0) + getattr(response, "output_tokens", 0),
+                "input_tokens": input_tokens,
+                "output_tokens": output_tokens,
+                "total_tokens": input_tokens + output_tokens,
             }

         return ModelResponse(
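The "x or 0" idiom introduced here coerces both a missing attribute and an explicit None to 0, while a real count passes through unchanged (a legitimate 0 is falsy, but the result is the same):

# Behavior of the "or 0" guard used above:
assert (None or 0) == 0   # explicit None coerced to 0
assert (25 or 0) == 25    # real counts pass through
assert (0 or 0) == 0      # zero stays zero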
@@ -607,9 +610,10 @@ class OpenAICompatibleProvider(ModelProvider):
         usage = {}

         if hasattr(response, "usage") and response.usage:
-            usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0)
-            usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0)
-            usage["total_tokens"] = getattr(response.usage, "total_tokens", 0)
+            # Safely extract token counts with None handling
+            usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0) or 0
+            usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0) or 0
+            usage["total_tokens"] = getattr(response.usage, "total_tokens", 0) or 0

         return usage
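And a matching regression-test sketch for this OpenAI-compatible path (assumes pytest; SimpleNamespace stands in for the SDK usage object, with the extraction logic inlined as in the hunk above):

from types import SimpleNamespace

def test_openai_usage_defaults_none_counts_to_zero():
    response = SimpleNamespace(
        usage=SimpleNamespace(prompt_tokens=None, completion_tokens=50, total_tokens=None)
    )
    usage = {}
    if hasattr(response, "usage") and response.usage:
        # Same "or 0" normalization as the hunk above
        usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0) or 0
        usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0) or 0
        usage["total_tokens"] = getattr(response.usage, "total_tokens", 0) or 0
    assert usage == {"input_tokens": 0, "output_tokens": 50, "total_tokens": 0}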