@@ -368,6 +368,57 @@ class TestLargePromptHandling:
         output = json.loads(result[0].text)
         assert output["status"] in ["success", "continuation_available"]

+    @pytest.mark.asyncio
+    async def test_large_file_context_does_not_trigger_mcp_prompt_limit(self, tmp_path):
+        """Large context files should not be blocked by MCP prompt limit enforcement."""
+        from tests.mock_helpers import create_mock_provider
+        from utils.model_context import TokenAllocation
+
+        tool = ChatTool()
+
+        # Create a file significantly larger than MCP_PROMPT_SIZE_LIMIT characters
+        large_content = "A" * (MCP_PROMPT_SIZE_LIMIT * 5)
+        large_file = tmp_path / "huge_context.txt"
+        large_file.write_text(large_content)
+
+        mock_provider = create_mock_provider(model_name="flash")
+        mock_provider.generate_content.return_value.content = "Processed large file context"
+
+        class DummyModelContext:
+            def __init__(self, provider):
+                self.model_name = "flash"
+                self._provider = provider
+                self.capabilities = provider.get_capabilities("flash")
+
+            @property
+            def provider(self):
+                return self._provider
+
+            def calculate_token_allocation(self):
+                return TokenAllocation(
+                    total_tokens=1_048_576,
+                    content_tokens=838_861,
+                    response_tokens=209_715,
+                    file_tokens=335_544,
+                    history_tokens=335_544,
+                )
+
+        dummy_context = DummyModelContext(mock_provider)
+
+        with patch.object(tool, "get_model_provider", return_value=mock_provider):
+            result = await tool.execute(
+                {
+                    "prompt": "Summarize the design decisions",
+                    "files": [str(large_file)],
+                    "model": "flash",
+                    "_model_context": dummy_context,
+                }
+            )
+
+        output = json.loads(result[0].text)
+        assert output["status"] in ["success", "continuation_available"]
+        assert "Processed large file context" in output["content"]
+
     @pytest.mark.asyncio
     async def test_mcp_boundary_with_large_internal_context(self):
         """
@@ -21,7 +21,7 @@ if TYPE_CHECKING:

 from config import MCP_PROMPT_SIZE_LIMIT
 from providers import ModelProvider, ModelProviderRegistry
-from utils import check_token_limit
+from utils import estimate_tokens
 from utils.conversation_memory import (
     ConversationTurn,
     get_conversation_file_list,
@@ -647,22 +647,38 @@ class BaseTool(ABC):

     def _validate_token_limit(self, content: str, content_type: str = "Content") -> None:
         """
-        Validate that content doesn't exceed the MCP prompt size limit.
+        Validate that user-provided content doesn't exceed the MCP prompt size limit.
+
+        This enforcement is strictly for text crossing the MCP transport boundary
+        (i.e., user input). Internal prompt construction may exceed this size and is
+        governed by model-specific token limits.

         Args:
-            content: The content to validate
+            content: The user-originated content to validate
             content_type: Description of the content type for error messages

         Raises:
-            ValueError: If content exceeds size limit
+            ValueError: If content exceeds the character size limit
         """
-        is_valid, token_count = check_token_limit(content, MCP_PROMPT_SIZE_LIMIT)
-        if not is_valid:
-            error_msg = f"~{token_count:,} tokens. Maximum is {MCP_PROMPT_SIZE_LIMIT:,} tokens."
+        if not content:
+            logger.debug(f"{self.name} tool {content_type.lower()} validation skipped (no content)")
+            return
+
+        char_count = len(content)
+        if char_count > MCP_PROMPT_SIZE_LIMIT:
+            token_estimate = estimate_tokens(content)
+            error_msg = (
+                f"{char_count:,} characters (~{token_estimate:,} tokens). "
+                f"Maximum is {MCP_PROMPT_SIZE_LIMIT:,} characters."
+            )
             logger.error(f"{self.name} tool {content_type.lower()} validation failed: {error_msg}")
             raise ValueError(f"{content_type} too large: {error_msg}")

-        logger.debug(f"{self.name} tool {content_type.lower()} token validation passed: {token_count:,} tokens")
+        token_estimate = estimate_tokens(content)
+        logger.debug(
+            f"{self.name} tool {content_type.lower()} validation passed: "
+            f"{char_count:,} characters (~{token_estimate:,} tokens)"
+        )

     def get_model_provider(self, model_name: str) -> ModelProvider:
         """
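
For illustration only (not part of the commit): a minimal standalone sketch of the character-based boundary check above. MCP_PROMPT_SIZE_LIMIT and estimate_tokens come from the project's config and utils modules; the example limit and the four-characters-per-token estimate used here are assumptions.

    # Standalone sketch of the new MCP boundary check (illustrative values only).
    MCP_PROMPT_SIZE_LIMIT = 50_000  # assumed example value; the real limit lives in config

    def estimate_tokens(text: str) -> int:
        # Crude stand-in for utils.estimate_tokens: roughly 4 characters per token.
        return len(text) // 4

    def validate_mcp_boundary(content: str, content_type: str = "Content") -> None:
        if not content:
            return  # nothing crossed the MCP boundary, nothing to enforce
        char_count = len(content)
        if char_count > MCP_PROMPT_SIZE_LIMIT:
            raise ValueError(
                f"{content_type} too large: {char_count:,} characters "
                f"(~{estimate_tokens(content):,} tokens). "
                f"Maximum is {MCP_PROMPT_SIZE_LIMIT:,} characters."
            )

    validate_mcp_boundary("Summarize the design decisions")  # passes silently
    try:
        validate_mcp_boundary("A" * (MCP_PROMPT_SIZE_LIMIT * 5))
    except ValueError as exc:
        print(exc)  # reports characters and an approximate token count
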
@@ -778,7 +778,11 @@ class SimpleTool(BaseTool):
         Returns:
             Complete formatted prompt ready for the AI model
         """
-        # Add context files if provided
+        # Check size limits against raw user input before enriching with internal context
+        content_to_validate = self.get_prompt_content_for_size_validation(user_content)
+        self._validate_token_limit(content_to_validate, "Content")
+
+        # Add context files if provided (does not affect MCP boundary enforcement)
         files = self.get_request_files(request)
         if files:
             file_content, processed_files = self._prepare_file_content_for_prompt(
@@ -791,10 +795,6 @@ class SimpleTool(BaseTool):
         if file_content:
             user_content = f"{user_content}\n\n=== {file_context_title} ===\n{file_content}\n=== END CONTEXT ===="

-        # Check token limits - only validate original user prompt, not conversation history
-        content_to_validate = self.get_prompt_content_for_size_validation(user_content)
-        self._validate_token_limit(content_to_validate, "Content")
-
         # Add standardized web search guidance
         websearch_instruction = self.get_websearch_instruction(self.get_websearch_guidance())

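
For illustration only (not part of the commit): the SimpleTool reordering means the MCP size check now runs against the raw user prompt before any file context is appended, so a large context file no longer trips the boundary. A hypothetical sketch of the flow; prepare_prompt_sketch and its plain-text file handling are simplified stand-ins, not the project's API.

    from pathlib import Path

    MCP_PROMPT_SIZE_LIMIT = 50_000  # assumed example value, as in the sketch above

    def prepare_prompt_sketch(user_content: str, files: list[str]) -> str:
        # 1. Enforce the MCP transport limit on the raw user input only.
        if len(user_content) > MCP_PROMPT_SIZE_LIMIT:
            raise ValueError("Content too large")

        # 2. Only afterwards enrich the prompt with file context; its size is
        #    governed by the model's token budget, not the MCP character limit.
        if files:
            file_content = "\n\n".join(Path(f).read_text() for f in files)
            user_content = f"{user_content}\n\n=== CONTEXT FILES ===\n{file_content}\n=== END CONTEXT ===="
        return user_content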