diff --git a/tests/test_large_prompt_handling.py b/tests/test_large_prompt_handling.py
index 54d736b..1b9a92c 100644
--- a/tests/test_large_prompt_handling.py
+++ b/tests/test_large_prompt_handling.py
@@ -368,6 +368,57 @@ class TestLargePromptHandling:
         output = json.loads(result[0].text)
         assert output["status"] in ["success", "continuation_available"]
 
+    @pytest.mark.asyncio
+    async def test_large_file_context_does_not_trigger_mcp_prompt_limit(self, tmp_path):
+        """Large context files should not be blocked by MCP prompt limit enforcement."""
+        from tests.mock_helpers import create_mock_provider
+        from utils.model_context import TokenAllocation
+
+        tool = ChatTool()
+
+        # Create a file significantly larger than MCP_PROMPT_SIZE_LIMIT characters
+        large_content = "A" * (MCP_PROMPT_SIZE_LIMIT * 5)
+        large_file = tmp_path / "huge_context.txt"
+        large_file.write_text(large_content)
+
+        mock_provider = create_mock_provider(model_name="flash")
+        mock_provider.generate_content.return_value.content = "Processed large file context"
+
+        class DummyModelContext:
+            def __init__(self, provider):
+                self.model_name = "flash"
+                self._provider = provider
+                self.capabilities = provider.get_capabilities("flash")
+
+            @property
+            def provider(self):
+                return self._provider
+
+            def calculate_token_allocation(self):
+                return TokenAllocation(
+                    total_tokens=1_048_576,
+                    content_tokens=838_861,
+                    response_tokens=209_715,
+                    file_tokens=335_544,
+                    history_tokens=335_544,
+                )
+
+        dummy_context = DummyModelContext(mock_provider)
+
+        with patch.object(tool, "get_model_provider", return_value=mock_provider):
+            result = await tool.execute(
+                {
+                    "prompt": "Summarize the design decisions",
+                    "files": [str(large_file)],
+                    "model": "flash",
+                    "_model_context": dummy_context,
+                }
+            )
+
+        output = json.loads(result[0].text)
+        assert output["status"] in ["success", "continuation_available"]
+        assert "Processed large file context" in output["content"]
+
     @pytest.mark.asyncio
     async def test_mcp_boundary_with_large_internal_context(self):
         """
diff --git a/tools/shared/base_tool.py b/tools/shared/base_tool.py
index 927bd28..a297208 100644
--- a/tools/shared/base_tool.py
+++ b/tools/shared/base_tool.py
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
 
 from config import MCP_PROMPT_SIZE_LIMIT
 from providers import ModelProvider, ModelProviderRegistry
-from utils import check_token_limit
+from utils import estimate_tokens
 from utils.conversation_memory import (
     ConversationTurn,
     get_conversation_file_list,
@@ -647,22 +647,38 @@ class BaseTool(ABC):
 
     def _validate_token_limit(self, content: str, content_type: str = "Content") -> None:
         """
-        Validate that content doesn't exceed the MCP prompt size limit.
+        Validate that user-provided content doesn't exceed the MCP prompt size limit.
+
+        This enforcement is strictly for text crossing the MCP transport boundary
+        (i.e., user input). Internal prompt construction may exceed this size and is
+        governed by model-specific token limits.
 
         Args:
-            content: The content to validate
+            content: The user-originated content to validate
             content_type: Description of the content type for error messages
 
         Raises:
-            ValueError: If content exceeds size limit
+            ValueError: If content exceeds the character size limit
         """
-        is_valid, token_count = check_token_limit(content, MCP_PROMPT_SIZE_LIMIT)
-        if not is_valid:
-            error_msg = f"~{token_count:,} tokens. Maximum is {MCP_PROMPT_SIZE_LIMIT:,} tokens."
+        if not content:
+            logger.debug(f"{self.name} tool {content_type.lower()} validation skipped (no content)")
+            return
+
+        char_count = len(content)
+        if char_count > MCP_PROMPT_SIZE_LIMIT:
+            token_estimate = estimate_tokens(content)
+            error_msg = (
+                f"{char_count:,} characters (~{token_estimate:,} tokens). "
+                f"Maximum is {MCP_PROMPT_SIZE_LIMIT:,} characters."
+            )
             logger.error(f"{self.name} tool {content_type.lower()} validation failed: {error_msg}")
             raise ValueError(f"{content_type} too large: {error_msg}")
 
-        logger.debug(f"{self.name} tool {content_type.lower()} token validation passed: {token_count:,} tokens")
+        token_estimate = estimate_tokens(content)
+        logger.debug(
+            f"{self.name} tool {content_type.lower()} validation passed: "
+            f"{char_count:,} characters (~{token_estimate:,} tokens)"
+        )
 
     def get_model_provider(self, model_name: str) -> ModelProvider:
         """
diff --git a/tools/simple/base.py b/tools/simple/base.py
index 607f48d..7a04911 100644
--- a/tools/simple/base.py
+++ b/tools/simple/base.py
@@ -778,7 +778,11 @@ class SimpleTool(BaseTool):
         Returns:
             Complete formatted prompt ready for the AI model
         """
-        # Add context files if provided
+        # Check size limits against raw user input before enriching with internal context
+        content_to_validate = self.get_prompt_content_for_size_validation(user_content)
+        self._validate_token_limit(content_to_validate, "Content")
+
+        # Add context files if provided (does not affect MCP boundary enforcement)
         files = self.get_request_files(request)
         if files:
             file_content, processed_files = self._prepare_file_content_for_prompt(
@@ -791,10 +795,6 @@ class SimpleTool(BaseTool):
         if file_content:
             user_content = f"{user_content}\n\n=== {file_context_title} ===\n{file_content}\n=== END CONTEXT ===="
 
-        # Check token limits - only validate original user prompt, not conversation history
-        content_to_validate = self.get_prompt_content_for_size_validation(user_content)
-        self._validate_token_limit(content_to_validate, "Content")
-
         # Add standardized web search guidance
         websearch_instruction = self.get_websearch_instruction(self.get_websearch_guidance())