WIP major refactor and features
utils/model_context.py (new file, 130 lines)
@@ -0,0 +1,130 @@
"""
Model context management for dynamic token allocation.

This module provides a clean abstraction for model-specific token management,
ensuring that token limits are properly calculated based on the current model
being used, not global constants.
"""

import logging
from dataclasses import dataclass
from typing import Any, Dict, Optional

from providers import ModelProviderRegistry, ModelCapabilities
from config import DEFAULT_MODEL

logger = logging.getLogger(__name__)


@dataclass
class TokenAllocation:
    """Token allocation strategy for a model."""

    total_tokens: int
    content_tokens: int
    response_tokens: int
    file_tokens: int
    history_tokens: int

    @property
    def available_for_prompt(self) -> int:
        """Tokens available for the actual prompt after allocations."""
        return self.content_tokens - self.file_tokens - self.history_tokens
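
    # Illustrative figures (assumed for this comment, not defaults): with
    # content_tokens=100_000, file_tokens=40_000, and history_tokens=40_000,
    # available_for_prompt evaluates to 20_000.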


class ModelContext:
    """
    Encapsulates model-specific information and token calculations.

    This class provides a single source of truth for all model-related
    token calculations, ensuring consistency across the system.
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        self._provider = None
        self._capabilities = None
        self._token_allocation = None

    @property
    def provider(self):
        """Get the model provider lazily."""
        if self._provider is None:
            self._provider = ModelProviderRegistry.get_provider_for_model(self.model_name)
            if not self._provider:
                raise ValueError(f"No provider found for model: {self.model_name}")
        return self._provider

    @property
    def capabilities(self) -> ModelCapabilities:
        """Get model capabilities lazily."""
        if self._capabilities is None:
            self._capabilities = self.provider.get_capabilities(self.model_name)
        return self._capabilities

    def calculate_token_allocation(self, reserved_for_response: Optional[int] = None) -> TokenAllocation:
        """
        Calculate token allocation based on model capacity.

        Args:
            reserved_for_response: Override the response token reservation

        Returns:
            TokenAllocation with calculated budgets
        """
        total_tokens = self.capabilities.max_tokens

        # Dynamic allocation based on model capacity
        if total_tokens < 300_000:
            # Smaller-context models (O3, GPT-4o): conservative allocation
            content_ratio = 0.6  # 60% for content
            response_ratio = 0.4  # 40% for response
            file_ratio = 0.3  # 30% of content for files
            history_ratio = 0.5  # 50% of content for history
        else:
            # Larger-context models (Gemini): more generous allocation
            content_ratio = 0.8  # 80% for content
            response_ratio = 0.2  # 20% for response
            file_ratio = 0.4  # 40% of content for files
            history_ratio = 0.4  # 40% of content for history
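
        # Worked example (illustrative arithmetic, not additional defaults):
        # a 1,000,000-token model gets content=800,000 and response=200,000;
        # within content, files=320,000 and history=320,000, which leaves
        # available_for_prompt = 800,000 - 320,000 - 320,000 = 160,000.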

        # Calculate allocations
        content_tokens = int(total_tokens * content_ratio)
        response_tokens = reserved_for_response or int(total_tokens * response_ratio)

        # Sub-allocations within the content budget
        file_tokens = int(content_tokens * file_ratio)
        history_tokens = int(content_tokens * history_ratio)

        allocation = TokenAllocation(
            total_tokens=total_tokens,
            content_tokens=content_tokens,
            response_tokens=response_tokens,
            file_tokens=file_tokens,
            history_tokens=history_tokens,
        )

        logger.debug(f"Token allocation for {self.model_name}:")
        logger.debug(f"  Total: {allocation.total_tokens:,}")
        logger.debug(f"  Content: {allocation.content_tokens:,} ({content_ratio:.0%})")
        logger.debug(f"  Response: {allocation.response_tokens:,} ({response_ratio:.0%})")
        logger.debug(f"  Files: {allocation.file_tokens:,} ({file_ratio:.0%} of content)")
        logger.debug(f"  History: {allocation.history_tokens:,} ({history_ratio:.0%} of content)")

        return allocation

    def estimate_tokens(self, text: str) -> int:
        """
        Estimate the token count for text.

        For now, uses a simple heuristic. Can be enhanced with model-specific
        tokenizers (tiktoken for OpenAI, etc.) in the future.
        """
        # TODO: Integrate model-specific tokenizers
        # Assume roughly 3 characters per token. English text typically averages
        # closer to 4, so this overestimates the count, keeping budget checks
        # conservative.
        return len(text) // 3

    @classmethod
    def from_arguments(cls, arguments: Dict[str, Any]) -> "ModelContext":
        """Create a ModelContext from tool arguments."""
        model_name = arguments.get("model") or DEFAULT_MODEL
        return cls(model_name)
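
A minimal usage sketch (hypothetical caller, not part of this commit; prepare_request and the argument shapes are assumptions for illustration):

from utils.model_context import ModelContext

def prepare_request(arguments: dict, file_text: str) -> int:
    # Resolve the context for whichever model the tool call names,
    # falling back to DEFAULT_MODEL when no "model" argument is given.
    context = ModelContext.from_arguments(arguments)
    allocation = context.calculate_token_allocation()

    # Enforce the per-model file budget before building the prompt.
    if context.estimate_tokens(file_text) > allocation.file_tokens:
        raise ValueError(f"File exceeds token budget for {context.model_name}")

    return allocation.available_for_prompt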