From 7fe9fc49f8e3cd92be4c45a6645d5d4ab3014091 Mon Sep 17 00:00:00 2001
From: Fahad
Date: Thu, 2 Oct 2025 11:35:29 +0400
Subject: [PATCH] refactor: clean up token counting

---
 docs/adding_providers.md       | 11 +++---
 providers/base.py              | 20 +++++++++--
 providers/gemini.py            |  9 -----
 providers/openai_compatible.py | 64 +++++++++++-----------------------
 4 files changed, 43 insertions(+), 61 deletions(-)

diff --git a/docs/adding_providers.md b/docs/adding_providers.md
index 8c49647..0a62c8c 100644
--- a/docs/adding_providers.md
+++ b/docs/adding_providers.md
@@ -15,7 +15,7 @@ Each provider:
 **Option A: Full Provider (`ModelProvider`)**
 - For APIs with unique features or custom authentication
 - Complete control over API calls and response handling
-- Required methods: `generate_content()`, `count_tokens()`, `get_capabilities()`, `validate_model_name()`, `get_provider_type()`
+- Required methods: `generate_content()`, `get_capabilities()`, `validate_model_name()`, `get_provider_type()` (override `count_tokens()` only when you have a provider-accurate tokenizer)
 
 **Option B: OpenAI-Compatible (`OpenAICompatibleProvider`)**
 - For APIs that follow OpenAI's chat completion format
@@ -120,10 +120,6 @@ class ExampleModelProvider(ModelProvider):
             friendly_name="Example",
             provider=ProviderType.EXAMPLE,
         )
-
-    def count_tokens(self, text: str, model_name: str) -> int:
-        return len(text) // 4  # Simple estimation
-
     def get_provider_type(self) -> ProviderType:
         return ProviderType.EXAMPLE
 
@@ -132,6 +128,11 @@ class ExampleModelProvider(ModelProvider):
         return resolved_name in self.MODEL_CAPABILITIES
 ```
 
+`ModelProvider.count_tokens()` uses a simple 4-characters-per-token estimate so
+providers work out of the box. Override the method only when you can call into
+the provider's real tokenizer (for example, the OpenAI-compatible base class
+already integrates `tiktoken`).
+
 #### Option B: OpenAI-Compatible Provider (Simplified)
 
 For OpenAI-compatible APIs:
diff --git a/providers/base.py b/providers/base.py
index e8e54f9..d959832 100644
--- a/providers/base.py
+++ b/providers/base.py
@@ -73,10 +73,24 @@ class ModelProvider(ABC):
         """
         pass
 
-    @abstractmethod
     def count_tokens(self, text: str, model_name: str) -> int:
-        """Count tokens for the given text using the specified model's tokenizer."""
-        pass
+        """Estimate token usage for a piece of text.
+
+        Providers can rely on this shared implementation or override it when
+        they expose a more accurate tokenizer. This default uses a simple
+        character-based heuristic so it works even without provider-specific
+        tooling.
+        """
+
+        resolved_model = self._resolve_model_name(model_name)
+
+        if not text:
+            return 0
+
+        # Rough estimation: ~4 characters per token for English text
+        estimated = max(1, len(text) // 4)
+        logger.debug("Estimating %s tokens for model %s via character heuristic", estimated, resolved_model)
+        return estimated
 
     @abstractmethod
     def get_provider_type(self) -> ProviderType:
diff --git a/providers/gemini.py b/providers/gemini.py
index 2bdc4da..952aab4 100644
--- a/providers/gemini.py
+++ b/providers/gemini.py
@@ -361,15 +361,6 @@ class GeminiModelProvider(ModelProvider):
         error_msg = f"Gemini API error for model {resolved_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
         raise RuntimeError(error_msg) from last_exception
 
-    def count_tokens(self, text: str, model_name: str) -> int:
-        """Count tokens for the given text using Gemini's tokenizer."""
-        self._resolve_model_name(model_name)
-
-        # For now, use a simple estimation
-        # TODO: Use actual Gemini tokenizer when available in SDK
-        # Rough estimation: ~4 characters per token for English text
-        return len(text) // 4
-
     def get_provider_type(self) -> ProviderType:
         """Get the provider type."""
         return ProviderType.GOOGLE
diff --git a/providers/openai_compatible.py b/providers/openai_compatible.py
index fd04e7d..1b0bdfd 100644
--- a/providers/openai_compatible.py
+++ b/providers/openai_compatible.py
@@ -622,50 +622,6 @@ class OpenAICompatibleProvider(ModelProvider):
             logging.error(error_msg)
             raise RuntimeError(error_msg) from last_exception
 
-    def count_tokens(self, text: str, model_name: str) -> int:
-        """Count tokens for the given text.
-
-        Uses a layered approach:
-        1. Try provider-specific token counting endpoint
-        2. Try tiktoken for known model families
-        3. Fall back to character-based estimation
-
-        Args:
-            text: Text to count tokens for
-            model_name: Model name for tokenizer selection
-
-        Returns:
-            Estimated token count
-        """
-        # 1. Check if provider has a remote token counting endpoint
-        if hasattr(self, "count_tokens_remote"):
-            try:
-                return self.count_tokens_remote(text, model_name)
-            except Exception as e:
-                logging.debug(f"Remote token counting failed: {e}")
-
-        # 2. Try tiktoken for known models
-        try:
-            import tiktoken
-
-            # Try to get encoding for the specific model
-            try:
-                encoding = tiktoken.encoding_for_model(model_name)
-            except KeyError:
-                encoding = tiktoken.get_encoding("cl100k_base")
-
-            return len(encoding.encode(text))
-
-        except (ImportError, Exception) as e:
-            logging.debug(f"Tiktoken not available or failed: {e}")
-
-        # 3. Fall back to character-based estimation
-        logging.warning(
-            f"No specific tokenizer available for '{model_name}'. "
-            "Using character-based estimation (~4 chars per token)."
-        )
-        return len(text) // 4
-
     def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
         """Validate model parameters.
 
@@ -712,6 +668,26 @@ class OpenAICompatibleProvider(ModelProvider):
 
         return usage
 
+    def count_tokens(self, text: str, model_name: str) -> int:
+        """Count tokens using OpenAI-compatible tokenizer tables when available."""
+
+        resolved_model = self._resolve_model_name(model_name)
+
+        try:
+            import tiktoken
+
+            try:
+                encoding = tiktoken.encoding_for_model(resolved_model)
+            except KeyError:
+                encoding = tiktoken.get_encoding("cl100k_base")
+
+            return len(encoding.encode(text))
+
+        except Exception as exc:
+            logging.debug("tiktoken unavailable for %s: %s", resolved_model, exc)
+
+        return super().count_tokens(text, model_name)
+
     @abstractmethod
     def get_capabilities(self, model_name: str) -> ModelCapabilities:
         """Get capabilities for a specific model.
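
The net effect of the patch is a two-level fallback: tiktoken when it resolves, otherwise the shared character heuristic, and callers never see an exception. A standalone sketch of that chain follows; the two class names are illustrative stand-ins, not the real `providers/` classes:

```python
# Minimal sketch of the fallback chain this patch introduces, assuming
# tiktoken may or may not be installed. Class names are hypothetical.


class BaseProvider:
    def count_tokens(self, text: str, model_name: str) -> int:
        # Shared default: ~4 characters per token, never below 1 for
        # non-empty input (mirrors ModelProvider.count_tokens).
        if not text:
            return 0
        return max(1, len(text) // 4)


class TiktokenBackedProvider(BaseProvider):
    def count_tokens(self, text: str, model_name: str) -> int:
        try:
            import tiktoken

            try:
                encoding = tiktoken.encoding_for_model(model_name)
            except KeyError:
                # Unknown model: fall back to a widely used encoding.
                encoding = tiktoken.get_encoding("cl100k_base")
            return len(encoding.encode(text))
        except Exception:
            # tiktoken missing or failed: defer to the shared heuristic.
            return super().count_tokens(text, model_name)


if __name__ == "__main__":
    print(TiktokenBackedProvider().count_tokens("Hello, world!", "gpt-4"))
    print(BaseProvider().count_tokens("Hello, world!", "any-model"))  # 13 // 4 -> 3
```

With tiktoken installed the first call returns an exact count; without it both calls go through the heuristic, so every provider yields a usable estimate out of the box.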