refactor: cleanup token counting
@@ -15,7 +15,7 @@ Each provider:
 **Option A: Full Provider (`ModelProvider`)**
 - For APIs with unique features or custom authentication
 - Complete control over API calls and response handling
-- Required methods: `generate_content()`, `count_tokens()`, `get_capabilities()`, `validate_model_name()`, `get_provider_type()`
+- Required methods: `generate_content()`, `get_capabilities()`, `validate_model_name()`, `get_provider_type()` (override `count_tokens()` only when you have a provider-accurate tokenizer)
 
 **Option B: OpenAI-Compatible (`OpenAICompatibleProvider`)**
 - For APIs that follow OpenAI's chat completion format
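For orientation, the smallest Option A provider under the revised list would look roughly like this. This is a sketch, not repo code: it assumes the repo's `ModelProvider`, `ModelCapabilities`, and `ProviderType` are importable, and the method signatures are approximations.

```python
class MinimalProvider(ModelProvider):
    def generate_content(self, prompt: str, model_name: str, **kwargs):
        ...  # call the vendor API and wrap the raw response

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        return self.MODEL_CAPABILITIES[self._resolve_model_name(model_name)]

    def validate_model_name(self, model_name: str) -> bool:
        return self._resolve_model_name(model_name) in self.MODEL_CAPABILITIES

    def get_provider_type(self) -> ProviderType:
        return ProviderType.EXAMPLE

    # count_tokens() is deliberately absent: the inherited character
    # heuristic applies until a provider-accurate tokenizer exists.
```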
@@ -120,10 +120,6 @@ class ExampleModelProvider(ModelProvider):
             friendly_name="Example",
             provider=ProviderType.EXAMPLE,
         )
 
-    def count_tokens(self, text: str, model_name: str) -> int:
-        return len(text) // 4  # Simple estimation
-
     def get_provider_type(self) -> ProviderType:
         return ProviderType.EXAMPLE
 
@@ -132,6 +128,11 @@ class ExampleModelProvider(ModelProvider):
         return resolved_name in self.MODEL_CAPABILITIES
 ```
 
+`ModelProvider.count_tokens()` uses a simple 4-characters-per-token estimate so
+providers work out of the box. Override the method only when you can call into
+the provider's real tokenizer (for example, the OpenAI-compatible base class
+already integrates `tiktoken`).
+
 #### Option B: OpenAI-Compatible Provider (Simplified)
 
 For OpenAI-compatible APIs:
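To make the override guidance above concrete, a provider with access to an exact tokenizer might replace the heuristic like this (a sketch assuming `tiktoken` is installed; the other required methods are omitted):

```python
import tiktoken

class TokenizerBackedProvider(ModelProvider):
    def count_tokens(self, text: str, model_name: str) -> int:
        # Exact count from a real tokenizer instead of the inherited
        # ~4-characters-per-token estimate.
        encoding = tiktoken.get_encoding("cl100k_base")
        return len(encoding.encode(text))
```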
@@ -73,10 +73,24 @@ class ModelProvider(ABC):
         """
         pass
 
-    @abstractmethod
     def count_tokens(self, text: str, model_name: str) -> int:
-        """Count tokens for the given text using the specified model's tokenizer."""
-        pass
+        """Estimate token usage for a piece of text.
+
+        Providers can rely on this shared implementation or override it when
+        they expose a more accurate tokenizer. This default uses a simple
+        character-based heuristic so it works even without provider-specific
+        tooling.
+        """
+
+        resolved_model = self._resolve_model_name(model_name)
+
+        if not text:
+            return 0
+
+        # Rough estimation: ~4 characters per token for English text
+        estimated = max(1, len(text) // 4)
+        logger.debug("Estimating %s tokens for model %s via character heuristic", estimated, resolved_model)
+        return estimated
 
     @abstractmethod
     def get_provider_type(self) -> ProviderType:
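The edge cases of the new default are easy to sanity-check with a standalone mirror of the heuristic (this reproduces, rather than imports, the code above):

```python
def estimate_tokens(text: str) -> int:
    # Mirror of the new ModelProvider default: empty text counts as 0,
    # anything else is at least 1 token at roughly 4 characters per token.
    if not text:
        return 0
    return max(1, len(text) // 4)

assert estimate_tokens("") == 0
assert estimate_tokens("abc") == 1       # shorter than 4 chars still costs 1 token
assert estimate_tokens("a" * 100) == 25  # 100 chars // 4 = 25 tokens
```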
@@ -361,15 +361,6 @@ class GeminiModelProvider(ModelProvider):
         error_msg = f"Gemini API error for model {resolved_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
         raise RuntimeError(error_msg) from last_exception
 
-    def count_tokens(self, text: str, model_name: str) -> int:
-        """Count tokens for the given text using Gemini's tokenizer."""
-        self._resolve_model_name(model_name)
-
-        # For now, use a simple estimation
-        # TODO: Use actual Gemini tokenizer when available in SDK
-        # Rough estimation: ~4 characters per token for English text
-        return len(text) // 4
-
     def get_provider_type(self) -> ProviderType:
         """Get the provider type."""
         return ProviderType.GOOGLE
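The deleted override never resolved its TODO, so Gemini now simply inherits the shared heuristic. If exact counts become necessary later, the `google-generativeai` SDK does expose a counting call along these lines (a hedged sketch; it needs a configured API key and a network round trip, and the attribute names should be verified against the installed SDK version):

```python
import google.generativeai as genai

genai.configure(api_key="...")  # placeholder key, not a real credential
model = genai.GenerativeModel("gemini-1.5-flash")

# count_tokens() performs a remote call and reports the model's own count.
result = model.count_tokens("How many tokens is this sentence?")
print(result.total_tokens)
```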
@@ -622,50 +622,6 @@ class OpenAICompatibleProvider(ModelProvider):
         logging.error(error_msg)
         raise RuntimeError(error_msg) from last_exception
 
-    def count_tokens(self, text: str, model_name: str) -> int:
-        """Count tokens for the given text.
-
-        Uses a layered approach:
-        1. Try provider-specific token counting endpoint
-        2. Try tiktoken for known model families
-        3. Fall back to character-based estimation
-
-        Args:
-            text: Text to count tokens for
-            model_name: Model name for tokenizer selection
-
-        Returns:
-            Estimated token count
-        """
-        # 1. Check if provider has a remote token counting endpoint
-        if hasattr(self, "count_tokens_remote"):
-            try:
-                return self.count_tokens_remote(text, model_name)
-            except Exception as e:
-                logging.debug(f"Remote token counting failed: {e}")
-
-        # 2. Try tiktoken for known models
-        try:
-            import tiktoken
-
-            # Try to get encoding for the specific model
-            try:
-                encoding = tiktoken.encoding_for_model(model_name)
-            except KeyError:
-                encoding = tiktoken.get_encoding("cl100k_base")
-
-            return len(encoding.encode(text))
-
-        except (ImportError, Exception) as e:
-            logging.debug(f"Tiktoken not available or failed: {e}")
-
-        # 3. Fall back to character-based estimation
-        logging.warning(
-            f"No specific tokenizer available for '{model_name}'. "
-            "Using character-based estimation (~4 chars per token)."
-        )
-        return len(text) // 4
-
     def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
         """Validate model parameters.
 
@@ -712,6 +668,26 @@ class OpenAICompatibleProvider(ModelProvider):
 
         return usage
 
+    def count_tokens(self, text: str, model_name: str) -> int:
+        """Count tokens using OpenAI-compatible tokenizer tables when available."""
+
+        resolved_model = self._resolve_model_name(model_name)
+
+        try:
+            import tiktoken
+
+            try:
+                encoding = tiktoken.encoding_for_model(resolved_model)
+            except KeyError:
+                encoding = tiktoken.get_encoding("cl100k_base")
+
+            return len(encoding.encode(text))
+
+        except (ImportError, Exception) as exc:
+            logging.debug("tiktoken unavailable for %s: %s", resolved_model, exc)
+
+        return super().count_tokens(text, model_name)
+
     @abstractmethod
     def get_capabilities(self, model_name: str) -> ModelCapabilities:
         """Get capabilities for a specific model.
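Taken together, the new override gives OpenAI-compatible providers a graceful chain: exact `tiktoken` counts when the dependency is available, the shared character heuristic otherwise. A standalone sketch of that control flow (it mirrors, rather than imports, the method above):

```python
import logging

def count_tokens(text: str, model_name: str) -> int:
    # Step 1: prefer tiktoken, using the model-specific encoding when known.
    try:
        import tiktoken
        try:
            encoding = tiktoken.encoding_for_model(model_name)
        except KeyError:
            encoding = tiktoken.get_encoding("cl100k_base")
        return len(encoding.encode(text))
    except Exception as exc:  # tiktoken missing or encoding failed
        logging.debug("tiktoken unavailable for %s: %s", model_name, exc)

    # Step 2: fall back to the shared ~4-characters-per-token heuristic.
    return max(1, len(text) // 4) if text else 0

print(count_tokens("hello world", "gpt-4"))  # exact if tiktoken is installed
```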