WIP major refactor and features

Fahad
2025-06-12 07:14:59 +04:00
parent e06a6fd1fc
commit 2a067a7f4e
46 changed files with 2960 additions and 1011 deletions


@@ -68,12 +68,15 @@ class ConversationTurn(BaseModel):
the content and metadata needed for cross-tool continuation.
Attributes:
role: "user" (Claude) or "assistant" (Gemini)
role: "user" (Claude) or "assistant" (Gemini/O3/etc)
content: The actual message content/response
timestamp: ISO timestamp when this turn was created
follow_up_question: Optional follow-up question from Gemini to Claude
follow_up_question: Optional follow-up question from assistant to Claude
files: List of file paths referenced in this specific turn
tool_name: Which tool generated this turn (for cross-tool tracking)
model_provider: Provider used (e.g., "google", "openai")
model_name: Specific model used (e.g., "gemini-2.0-flash-exp", "o3-mini")
model_metadata: Additional model-specific metadata (e.g., thinking mode, token usage)
"""
role: str # "user" or "assistant"
@@ -82,6 +85,9 @@ class ConversationTurn(BaseModel):
follow_up_question: Optional[str] = None
files: Optional[list[str]] = None # Files referenced in this turn
tool_name: Optional[str] = None # Tool used for this turn
model_provider: Optional[str] = None # Model provider (google, openai, etc)
model_name: Optional[str] = None # Specific model used
model_metadata: Optional[dict[str, Any]] = None # Additional model info
class ThreadContext(BaseModel):
@@ -94,6 +100,7 @@ class ThreadContext(BaseModel):
Attributes:
thread_id: UUID identifying this conversation thread
parent_thread_id: UUID of parent thread (for conversation chains)
created_at: ISO timestamp when thread was created
last_updated_at: ISO timestamp of last modification
tool_name: Name of the tool that initiated this thread
@@ -102,6 +109,7 @@ class ThreadContext(BaseModel):
"""
thread_id: str
parent_thread_id: Optional[str] = None # Parent thread for conversation chains
created_at: str
last_updated_at: str
tool_name: str # Tool that created this thread (preserved for attribution)
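For illustration, a turn carrying the new model attribution fields might be constructed like this. This is a sketch against the models above; the provider and model values are the examples given in the docstring, and the metadata keys are illustrative:

from datetime import datetime, timezone

turn = ConversationTurn(
    role="assistant",
    content="Analysis complete; see findings below.",
    timestamp=datetime.now(timezone.utc).isoformat(),
    files=["utils/model_context.py"],
    tool_name="analyze",
    model_provider="openai",                     # new field
    model_name="o3-mini",                        # new field
    model_metadata={"thinking_mode": "medium"},  # new field, free-form extras
)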
@@ -131,7 +139,7 @@ def get_redis_client():
raise ValueError("redis package required. Install with: pip install redis")
def create_thread(tool_name: str, initial_request: dict[str, Any]) -> str:
def create_thread(tool_name: str, initial_request: dict[str, Any], parent_thread_id: Optional[str] = None) -> str:
"""
Create new conversation thread and return thread ID
@@ -142,6 +150,7 @@ def create_thread(tool_name: str, initial_request: dict[str, Any]) -> str:
Args:
tool_name: Name of the tool creating this thread (e.g., "analyze", "chat")
initial_request: Original request parameters (will be filtered for serialization)
parent_thread_id: Optional parent thread ID for conversation chains
Returns:
str: UUID thread identifier that can be used for continuation
@@ -150,6 +159,7 @@ def create_thread(tool_name: str, initial_request: dict[str, Any]) -> str:
- Thread expires after 1 hour (3600 seconds)
- Non-serializable parameters are filtered out automatically
- Thread can be continued by any tool using the returned UUID
- Providing a parent thread links this thread into a chain for conversation history traversal
"""
thread_id = str(uuid.uuid4())
now = datetime.now(timezone.utc).isoformat()
@@ -163,6 +173,7 @@ def create_thread(tool_name: str, initial_request: dict[str, Any]) -> str:
context = ThreadContext(
thread_id=thread_id,
parent_thread_id=parent_thread_id, # Link to parent for conversation chains
created_at=now,
last_updated_at=now,
tool_name=tool_name, # Track which tool initiated this conversation
@@ -175,6 +186,8 @@ def create_thread(tool_name: str, initial_request: dict[str, Any]) -> str:
key = f"thread:{thread_id}"
client.setex(key, 3600, context.model_dump_json())
logger.debug(f"[THREAD] Created new thread {thread_id} with parent {parent_thread_id}")
return thread_id
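A minimal sketch of how the new parent_thread_id parameter could be used to chain conversations (arguments are illustrative; the helpers live in the module shown in this diff):

# Root thread started by one tool
root_id = create_thread("analyze", {"prompt": "Review the auth module"})

# A later thread that continues the same conversation chain
child_id = create_thread(
    "chat",
    {"prompt": "Follow up on the review above"},
    parent_thread_id=root_id,  # links the threads for history traversal
)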
@@ -221,34 +234,41 @@ def add_turn(
follow_up_question: Optional[str] = None,
files: Optional[list[str]] = None,
tool_name: Optional[str] = None,
model_provider: Optional[str] = None,
model_name: Optional[str] = None,
model_metadata: Optional[dict[str, Any]] = None,
) -> bool:
"""
Add turn to existing thread
Appends a new conversation turn to an existing thread. This is the core
function for building conversation history and enabling cross-tool
continuation. Each turn preserves the tool that generated it.
continuation. Each turn preserves the tool and model that generated it.
Args:
thread_id: UUID of the conversation thread
role: "user" (Claude) or "assistant" (Gemini)
role: "user" (Claude) or "assistant" (Gemini/O3/etc)
content: The actual message/response content
follow_up_question: Optional follow-up question from Gemini
follow_up_question: Optional follow-up question from assistant
files: Optional list of files referenced in this turn
tool_name: Name of the tool adding this turn (for attribution)
model_provider: Provider used (e.g., "google", "openai")
model_name: Specific model used (e.g., "gemini-2.0-flash-exp", "o3-mini")
model_metadata: Additional model info (e.g., thinking mode, token usage)
Returns:
bool: True if turn was successfully added, False otherwise
Failure cases:
- Thread doesn't exist or expired
- Maximum turn limit reached (5 turns)
- Maximum turn limit reached
- Redis connection failure
Note:
- Refreshes thread TTL to 1 hour on successful update
- Turn limits prevent runaway conversations
- File references are preserved for cross-tool access
- Model information enables cross-provider conversations
"""
logger.debug(f"[FLOW] Adding {role} turn to {thread_id} ({tool_name})")
@@ -270,6 +290,9 @@ def add_turn(
follow_up_question=follow_up_question,
files=files, # Preserved for cross-tool file context
tool_name=tool_name, # Track which tool generated this turn
model_provider=model_provider, # Track model provider
model_name=model_name, # Track specific model
model_metadata=model_metadata, # Additional model info
)
context.turns.append(turn)
@@ -286,6 +309,48 @@ def add_turn(
return False
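A sketch of recording an assistant turn with the new model attribution arguments (values are illustrative):

thread_id = create_thread("chat", {"prompt": "Review the token allocation logic"})
ok = add_turn(
    thread_id,
    "assistant",
    "The allocation logic looks correct; consider adding edge-case tests.",
    files=["utils/model_context.py"],
    tool_name="chat",
    model_provider="google",
    model_name="gemini-2.0-flash-exp",
    model_metadata={"thinking_mode": "high"},
)
if not ok:
    # Thread expired, turn limit reached, or Redis is unavailable
    ...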
def get_thread_chain(thread_id: str, max_depth: int = 20) -> list[ThreadContext]:
"""
Traverse the parent chain to get all threads in conversation sequence.
Retrieves the complete conversation chain by following parent_thread_id
links. Returns threads in chronological order (oldest first).
Args:
thread_id: Starting thread ID
max_depth: Maximum chain depth to prevent infinite loops
Returns:
list[ThreadContext]: All threads in chain, oldest first
"""
chain = []
current_id = thread_id
seen_ids = set()
# Build chain from current to oldest
while current_id and len(chain) < max_depth:
# Prevent circular references
if current_id in seen_ids:
logger.warning(f"[THREAD] Circular reference detected in thread chain at {current_id}")
break
seen_ids.add(current_id)
context = get_thread(current_id)
if not context:
logger.debug(f"[THREAD] Thread {current_id} not found in chain traversal")
break
chain.append(context)
current_id = context.parent_thread_id
# Reverse to get chronological order (oldest first)
chain.reverse()
logger.debug(f"[THREAD] Retrieved chain of {len(chain)} threads for {thread_id}")
return chain
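Walking a chain back to its root might then look like this (child_id as in the create_thread sketch above):

chain = get_thread_chain(child_id)   # oldest thread first, newest last
for ctx in chain:
    print(ctx.thread_id, ctx.tool_name, len(ctx.turns))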
def get_conversation_file_list(context: ThreadContext) -> list[str]:
"""
Get all unique files referenced across all turns in a conversation.
@@ -327,7 +392,7 @@ def get_conversation_file_list(context: ThreadContext) -> list[str]:
return unique_files
def build_conversation_history(context: ThreadContext, read_files_func=None) -> tuple[str, int]:
def build_conversation_history(context: ThreadContext, model_context=None, read_files_func=None) -> tuple[str, int]:
"""
Build formatted conversation history for tool prompts with embedded file contents.
@@ -335,9 +400,14 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
full file contents from all referenced files. Files are embedded only ONCE at the
start, even if referenced in multiple turns, to prevent duplication and optimize
token usage.
If the thread has a parent chain, this function traverses the entire chain to
include the complete conversation history.
Args:
context: ThreadContext containing the complete conversation
model_context: ModelContext for token allocation (optional, uses DEFAULT_MODEL if not provided)
read_files_func: Optional function to read files (for testing)
Returns:
tuple[str, int]: (formatted_conversation_history, total_tokens_used)
@@ -355,18 +425,57 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
file contents from previous tools, enabling true cross-tool collaboration
while preventing duplicate file embeddings.
"""
if not context.turns:
# Get the complete thread chain
if context.parent_thread_id:
# This thread has a parent, get the full chain
chain = get_thread_chain(context.thread_id)
# Collect all turns from all threads in chain
all_turns = []
all_files_set = set()
total_turns = 0
for thread in chain:
all_turns.extend(thread.turns)
total_turns += len(thread.turns)
# Collect files from this thread
for turn in thread.turns:
if turn.files:
all_files_set.update(turn.files)
all_files = list(all_files_set)
logger.debug(f"[THREAD] Built history from {len(chain)} threads with {total_turns} total turns")
else:
# Single thread, no parent chain
all_turns = context.turns
total_turns = len(context.turns)
all_files = get_conversation_file_list(context)
if not all_turns:
return "", 0
# Get all unique files referenced in this conversation
all_files = get_conversation_file_list(context)
logger.debug(f"[FILES] Found {len(all_files)} unique files in conversation history")
# Get model-specific token allocation early (needed for both files and turns)
if model_context is None:
from utils.model_context import ModelContext
from config import DEFAULT_MODEL
model_context = ModelContext(DEFAULT_MODEL)
token_allocation = model_context.calculate_token_allocation()
max_file_tokens = token_allocation.file_tokens
max_history_tokens = token_allocation.history_tokens
logger.debug(f"[HISTORY] Using model-specific limits for {model_context.model_name}:")
logger.debug(f"[HISTORY] Max file tokens: {max_file_tokens:,}")
logger.debug(f"[HISTORY] Max history tokens: {max_history_tokens:,}")
history_parts = [
"=== CONVERSATION HISTORY ===",
f"Thread: {context.thread_id}",
f"Tool: {context.tool_name}", # Original tool that started the conversation
f"Turn {len(context.turns)}/{MAX_CONVERSATION_TURNS}",
f"Turn {total_turns}/{MAX_CONVERSATION_TURNS}",
"",
]
@@ -382,9 +491,6 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
]
)
# Import required functions
from config import MAX_CONTENT_TOKENS
if read_files_func is None:
from utils.file_utils import read_file_content
@@ -402,7 +508,7 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
if formatted_content:
# read_file_content already returns formatted content, use it directly
# Check if adding this file would exceed the limit
if total_tokens + content_tokens <= MAX_CONTENT_TOKENS:
if total_tokens + content_tokens <= max_file_tokens:
file_contents.append(formatted_content)
total_tokens += content_tokens
files_included += 1
@@ -415,7 +521,7 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
else:
files_truncated += 1
logger.debug(
f"📄 File truncated due to token limit: {file_path} ({content_tokens:,} tokens, would exceed {MAX_CONTENT_TOKENS:,} limit)"
f"📄 File truncated due to token limit: {file_path} ({content_tokens:,} tokens, would exceed {max_file_tokens:,} limit)"
)
logger.debug(
f"[FILES] File {file_path} would exceed token limit - skipping (would be {total_tokens + content_tokens:,} tokens)"
@@ -464,7 +570,7 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
history_parts.append(files_content)
else:
# Handle token limit exceeded for conversation files
error_message = f"ERROR: The total size of files referenced in this conversation has exceeded the context limit and cannot be displayed.\nEstimated tokens: {estimated_tokens}, but limit is {MAX_CONTENT_TOKENS}."
error_message = f"ERROR: The total size of files referenced in this conversation has exceeded the context limit and cannot be displayed.\nEstimated tokens: {estimated_tokens}, but limit is {max_file_tokens}."
history_parts.append(error_message)
else:
history_parts.append("(No accessible files found)")
@@ -478,29 +584,79 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
)
history_parts.append("Previous conversation turns:")
for i, turn in enumerate(context.turns, 1):
# Build conversation turns bottom-up (most recent first) but present chronologically
# This ensures we include as many recent turns as possible within the token budget
turn_entries = [] # Will store (index, formatted_turn_content) for chronological ordering
total_turn_tokens = 0
file_embedding_tokens = sum(model_context.estimate_tokens(part) for part in history_parts)
# Process turns in reverse order (most recent first) to prioritize recent context
for idx in range(len(all_turns) - 1, -1, -1):
turn = all_turns[idx]
turn_num = idx + 1
role_label = "Claude" if turn.role == "user" else "Gemini"
# Build the complete turn content
turn_parts = []
# Add turn header with tool attribution for cross-tool tracking
turn_header = f"\n--- Turn {i} ({role_label}"
turn_header = f"\n--- Turn {turn_num} ({role_label}"
if turn.tool_name:
turn_header += f" using {turn.tool_name}"
# Add model info if available
if turn.model_provider and turn.model_name:
turn_header += f" via {turn.model_provider}/{turn.model_name}"
turn_header += ") ---"
history_parts.append(turn_header)
turn_parts.append(turn_header)
# Add files context if present - but just reference which files were used
# (the actual contents are already embedded above)
if turn.files:
history_parts.append(f"📁 Files used in this turn: {', '.join(turn.files)}")
history_parts.append("") # Empty line for readability
turn_parts.append(f"📁 Files used in this turn: {', '.join(turn.files)}")
turn_parts.append("") # Empty line for readability
# Add the actual content
history_parts.append(turn.content)
turn_parts.append(turn.content)
# Add follow-up question if present
if turn.follow_up_question:
history_parts.append(f"\n[Gemini's Follow-up: {turn.follow_up_question}]")
turn_parts.append(f"\n[Gemini's Follow-up: {turn.follow_up_question}]")
# Calculate tokens for this turn
turn_content = "\n".join(turn_parts)
turn_tokens = model_context.estimate_tokens(turn_content)
# Check if adding this turn would exceed history budget
if file_embedding_tokens + total_turn_tokens + turn_tokens > max_history_tokens:
# Stop adding turns - we've reached the limit
logger.debug(f"[HISTORY] Stopping at turn {turn_num} - would exceed history budget")
logger.debug(f"[HISTORY] File tokens: {file_embedding_tokens:,}")
logger.debug(f"[HISTORY] Turn tokens so far: {total_turn_tokens:,}")
logger.debug(f"[HISTORY] This turn: {turn_tokens:,}")
logger.debug(f"[HISTORY] Would total: {file_embedding_tokens + total_turn_tokens + turn_tokens:,}")
logger.debug(f"[HISTORY] Budget: {max_history_tokens:,}")
break
# Add this turn to our list (we'll reverse it later for chronological order)
turn_entries.append((idx, turn_content))
total_turn_tokens += turn_tokens
# Reverse to get chronological order (oldest first)
turn_entries.reverse()
# Add the turns in chronological order
for _, turn_content in turn_entries:
history_parts.append(turn_content)
# Log what we included
included_turns = len(turn_entries)
total_turns = len(all_turns)
if included_turns < total_turns:
logger.info(f"[HISTORY] Included {included_turns}/{total_turns} turns due to token limit")
history_parts.append(f"\n[Note: Showing {included_turns} most recent turns out of {total_turns} total]")
history_parts.extend(
["", "=== END CONVERSATION HISTORY ===", "", "Continue this conversation by building on the previous context."]
@@ -513,8 +669,8 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
total_conversation_tokens = estimate_tokens(complete_history)
# Summary log of what was built
user_turns = len([t for t in context.turns if t.role == "user"])
assistant_turns = len([t for t in context.turns if t.role == "assistant"])
user_turns = len([t for t in all_turns if t.role == "user"])
assistant_turns = len([t for t in all_turns if t.role == "assistant"])
logger.debug(
f"[FLOW] Built conversation history: {user_turns} user + {assistant_turns} assistant turns, {len(all_files)} files, {total_conversation_tokens:,} tokens"
)
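A hedged sketch of how a tool might now call build_conversation_history with an explicit ModelContext instead of relying on global constants (the thread ID is a placeholder):

from utils.model_context import ModelContext

context = get_thread(thread_id)  # ThreadContext, possibly part of a parent chain
if context:
    model_ctx = ModelContext("gemini-2.0-flash-exp")
    history, history_tokens = build_conversation_history(context, model_context=model_ctx)
    # history embeds each unique file once, then as many recent turns as fit
    # within model_ctx.calculate_token_allocation().history_tokens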

utils/model_context.py (new file, 130 lines)

@@ -0,0 +1,130 @@
"""
Model context management for dynamic token allocation.
This module provides a clean abstraction for model-specific token management,
ensuring that token limits are calculated from the model currently in use
rather than from global constants.
"""
from typing import Optional, Dict, Any
from dataclasses import dataclass
import logging
from providers import ModelProviderRegistry, ModelCapabilities
from config import DEFAULT_MODEL
logger = logging.getLogger(__name__)
@dataclass
class TokenAllocation:
"""Token allocation strategy for a model."""
total_tokens: int
content_tokens: int
response_tokens: int
file_tokens: int
history_tokens: int
@property
def available_for_prompt(self) -> int:
"""Tokens available for the actual prompt after allocations."""
return self.content_tokens - self.file_tokens - self.history_tokens
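For illustration, an allocation like the one calculate_token_allocation (below) would produce for a hypothetical 200,000-token model, showing how available_for_prompt is derived:

allocation = TokenAllocation(
    total_tokens=200_000,
    content_tokens=120_000,   # 60% of total
    response_tokens=80_000,   # 40% of total
    file_tokens=36_000,       # 30% of content
    history_tokens=60_000,    # 50% of content
)
assert allocation.available_for_prompt == 24_000  # content minus files and history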
class ModelContext:
"""
Encapsulates model-specific information and token calculations.
This class provides a single source of truth for all model-related
token calculations, ensuring consistency across the system.
"""
def __init__(self, model_name: str):
self.model_name = model_name
self._provider = None
self._capabilities = None
self._token_allocation = None
@property
def provider(self):
"""Get the model provider lazily."""
if self._provider is None:
self._provider = ModelProviderRegistry.get_provider_for_model(self.model_name)
if not self._provider:
raise ValueError(f"No provider found for model: {self.model_name}")
return self._provider
@property
def capabilities(self) -> ModelCapabilities:
"""Get model capabilities lazily."""
if self._capabilities is None:
self._capabilities = self.provider.get_capabilities(self.model_name)
return self._capabilities
def calculate_token_allocation(self, reserved_for_response: Optional[int] = None) -> TokenAllocation:
"""
Calculate token allocation based on model capacity.
Args:
reserved_for_response: Override response token reservation
Returns:
TokenAllocation with calculated budgets
"""
total_tokens = self.capabilities.max_tokens
# Dynamic allocation based on model capacity
if total_tokens < 300_000:
# Smaller context models (O3, GPT-4o): Conservative allocation
content_ratio = 0.6 # 60% for content
response_ratio = 0.4 # 40% for response
file_ratio = 0.3 # 30% of content for files
history_ratio = 0.5 # 50% of content for history
else:
# Larger context models (Gemini): More generous allocation
content_ratio = 0.8 # 80% for content
response_ratio = 0.2 # 20% for response
file_ratio = 0.4 # 40% of content for files
history_ratio = 0.4 # 40% of content for history
# Calculate allocations
content_tokens = int(total_tokens * content_ratio)
response_tokens = reserved_for_response or int(total_tokens * response_ratio)
# Sub-allocations within content budget
file_tokens = int(content_tokens * file_ratio)
history_tokens = int(content_tokens * history_ratio)
allocation = TokenAllocation(
total_tokens=total_tokens,
content_tokens=content_tokens,
response_tokens=response_tokens,
file_tokens=file_tokens,
history_tokens=history_tokens
)
logger.debug(f"Token allocation for {self.model_name}:")
logger.debug(f" Total: {allocation.total_tokens:,}")
logger.debug(f" Content: {allocation.content_tokens:,} ({content_ratio:.0%})")
logger.debug(f" Response: {allocation.response_tokens:,} ({response_ratio:.0%})")
logger.debug(f" Files: {allocation.file_tokens:,} ({file_ratio:.0%} of content)")
logger.debug(f" History: {allocation.history_tokens:,} ({history_ratio:.0%} of content)")
return allocation
def estimate_tokens(self, text: str) -> int:
"""
Estimate the token count for a piece of text.
For now this uses a simple character-based heuristic; it can be replaced
with model-specific tokenizers (tiktoken for OpenAI, etc.) in the future.
"""
# TODO: Integrate model-specific tokenizers
# For now, use conservative estimation
return len(text) // 3 # Conservative estimate
@classmethod
def from_arguments(cls, arguments: Dict[str, Any]) -> "ModelContext":
"""Create ModelContext from tool arguments."""
model_name = arguments.get("model") or DEFAULT_MODEL
return cls(model_name)
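A brief usage sketch of ModelContext, assuming a provider is registered for the model name (model names are illustrative):

ctx = ModelContext("o3-mini")
allocation = ctx.calculate_token_allocation()
print(allocation.file_tokens, allocation.history_tokens)

# Or resolve the model from tool arguments, falling back to DEFAULT_MODEL
ctx = ModelContext.from_arguments({"model": "gemini-2.0-flash-exp"})
prompt_tokens = ctx.estimate_tokens("Summarize the conversation so far.")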