refactor: cleanup token counting

This commit is contained in:
Fahad
2025-10-02 11:35:29 +04:00
parent 14a35afa1d
commit 7fe9fc49f8
4 changed files with 43 additions and 61 deletions

View File

@@ -73,10 +73,24 @@ class ModelProvider(ABC):
"""
pass
@abstractmethod
def count_tokens(self, text: str, model_name: str) -> int:
"""Count tokens for the given text using the specified model's tokenizer."""
pass
"""Estimate token usage for a piece of text.
Providers can rely on this shared implementation or override it when
they expose a more accurate tokenizer. This default uses a simple
character-based heuristic so it works even without provider-specific
tooling.
"""
resolved_model = self._resolve_model_name(model_name)
if not text:
return 0
# Rough estimation: ~4 characters per token for English text
estimated = max(1, len(text) // 4)
logger.debug("Estimating %s tokens for model %s via character heuristic", estimated, resolved_model)
return estimated
@abstractmethod
def get_provider_type(self) -> ProviderType:

View File

@@ -361,15 +361,6 @@ class GeminiModelProvider(ModelProvider):
error_msg = f"Gemini API error for model {resolved_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
raise RuntimeError(error_msg) from last_exception
def count_tokens(self, text: str, model_name: str) -> int:
"""Count tokens for the given text using Gemini's tokenizer."""
self._resolve_model_name(model_name)
# For now, use a simple estimation
# TODO: Use actual Gemini tokenizer when available in SDK
# Rough estimation: ~4 characters per token for English text
return len(text) // 4
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.GOOGLE

View File

@@ -622,50 +622,6 @@ class OpenAICompatibleProvider(ModelProvider):
logging.error(error_msg)
raise RuntimeError(error_msg) from last_exception
def count_tokens(self, text: str, model_name: str) -> int:
"""Count tokens for the given text.
Uses a layered approach:
1. Try provider-specific token counting endpoint
2. Try tiktoken for known model families
3. Fall back to character-based estimation
Args:
text: Text to count tokens for
model_name: Model name for tokenizer selection
Returns:
Estimated token count
"""
# 1. Check if provider has a remote token counting endpoint
if hasattr(self, "count_tokens_remote"):
try:
return self.count_tokens_remote(text, model_name)
except Exception as e:
logging.debug(f"Remote token counting failed: {e}")
# 2. Try tiktoken for known models
try:
import tiktoken
# Try to get encoding for the specific model
try:
encoding = tiktoken.encoding_for_model(model_name)
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")
return len(encoding.encode(text))
except (ImportError, Exception) as e:
logging.debug(f"Tiktoken not available or failed: {e}")
# 3. Fall back to character-based estimation
logging.warning(
f"No specific tokenizer available for '{model_name}'. "
"Using character-based estimation (~4 chars per token)."
)
return len(text) // 4
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
"""Validate model parameters.
@@ -712,6 +668,26 @@ class OpenAICompatibleProvider(ModelProvider):
return usage
def count_tokens(self, text: str, model_name: str) -> int:
"""Count tokens using OpenAI-compatible tokenizer tables when available."""
resolved_model = self._resolve_model_name(model_name)
try:
import tiktoken
try:
encoding = tiktoken.encoding_for_model(resolved_model)
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")
return len(encoding.encode(text))
except (ImportError, Exception) as exc:
logging.debug("tiktoken unavailable for %s: %s", resolved_model, exc)
return super().count_tokens(text, model_name)
@abstractmethod
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific model.