style: format code for consistency and readability across multiple files
This commit is contained in:
@@ -314,7 +314,9 @@ class OpenAICompatibleProvider(ModelProvider):
|
|||||||
try: # Log the exact payload being sent for debugging
|
try: # Log the exact payload being sent for debugging
|
||||||
import json
|
import json
|
||||||
|
|
||||||
logging.info(f"o3-pro API request payload: {json.dumps(completion_params, indent=2, ensure_ascii=False)}")
|
logging.info(
|
||||||
|
f"o3-pro API request payload: {json.dumps(completion_params, indent=2, ensure_ascii=False)}"
|
||||||
|
)
|
||||||
|
|
||||||
# Use OpenAI client's responses endpoint
|
# Use OpenAI client's responses endpoint
|
||||||
response = self.client.responses.create(**completion_params)
|
response = self.client.responses.create(**completion_params)
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ class Calculator:
|
|||||||
messages = [
|
messages = [
|
||||||
json.dumps(init_request, ensure_ascii=False),
|
json.dumps(init_request, ensure_ascii=False),
|
||||||
json.dumps(initialized_notification, ensure_ascii=False),
|
json.dumps(initialized_notification, ensure_ascii=False),
|
||||||
json.dumps(tool_request, ensure_ascii=False)
|
json.dumps(tool_request, ensure_ascii=False),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Join with newlines as MCP expects
|
# Join with newlines as MCP expects
|
||||||
|
|||||||
@@ -7,24 +7,25 @@ import requests
|
|||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
# A05: Security Misconfiguration - Debug mode enabled
|
# A05: Security Misconfiguration - Debug mode enabled
|
||||||
app.config['DEBUG'] = True
|
app.config["DEBUG"] = True
|
||||||
app.config['SECRET_KEY'] = 'dev-secret-key' # Hardcoded secret
|
app.config["SECRET_KEY"] = "dev-secret-key" # Hardcoded secret
|
||||||
|
|
||||||
@app.route('/api/search', methods=['GET'])
|
|
||||||
|
@app.route("/api/search", methods=["GET"])
|
||||||
def search():
|
def search():
|
||||||
'''Search endpoint with multiple vulnerabilities'''
|
"""Search endpoint with multiple vulnerabilities"""
|
||||||
# A03: Injection - XSS vulnerability, no input sanitization
|
# A03: Injection - XSS vulnerability, no input sanitization
|
||||||
query = request.args.get('q', '')
|
query = request.args.get("q", "")
|
||||||
|
|
||||||
# A03: Injection - Command injection vulnerability
|
# A03: Injection - Command injection vulnerability
|
||||||
if 'file:' in query:
|
if "file:" in query:
|
||||||
filename = query.split('file:')[1]
|
filename = query.split("file:")[1]
|
||||||
# Direct command execution
|
# Direct command execution
|
||||||
result = subprocess.run(f"cat {filename}", shell=True, capture_output=True, text=True)
|
result = subprocess.run(f"cat {filename}", shell=True, capture_output=True, text=True)
|
||||||
return jsonify({"result": result.stdout})
|
return jsonify({"result": result.stdout})
|
||||||
|
|
||||||
# A10: Server-Side Request Forgery (SSRF)
|
# A10: Server-Side Request Forgery (SSRF)
|
||||||
if query.startswith('http'):
|
if query.startswith("http"):
|
||||||
# No validation of URL, allows internal network access
|
# No validation of URL, allows internal network access
|
||||||
response = requests.get(query)
|
response = requests.get(query)
|
||||||
return jsonify({"content": response.text})
|
return jsonify({"content": response.text})
|
||||||
@@ -32,39 +33,42 @@ def search():
|
|||||||
# Return search results without output encoding
|
# Return search results without output encoding
|
||||||
return f"<h1>Search Results for: {query}</h1>"
|
return f"<h1>Search Results for: {query}</h1>"
|
||||||
|
|
||||||
@app.route('/api/admin', methods=['GET'])
|
|
||||||
|
@app.route("/api/admin", methods=["GET"])
|
||||||
def admin_panel():
|
def admin_panel():
|
||||||
'''Admin panel with broken access control'''
|
"""Admin panel with broken access control"""
|
||||||
# A01: Broken Access Control - No authentication check
|
# A01: Broken Access Control - No authentication check
|
||||||
# Anyone can access admin functionality
|
# Anyone can access admin functionality
|
||||||
action = request.args.get('action')
|
action = request.args.get("action")
|
||||||
|
|
||||||
if action == 'delete_user':
|
if action == "delete_user":
|
||||||
user_id = request.args.get('user_id')
|
user_id = request.args.get("user_id")
|
||||||
# Performs privileged action without authorization
|
# Performs privileged action without authorization
|
||||||
return jsonify({"status": "User deleted", "user_id": user_id})
|
return jsonify({"status": "User deleted", "user_id": user_id})
|
||||||
|
|
||||||
return jsonify({"status": "Admin panel"})
|
return jsonify({"status": "Admin panel"})
|
||||||
|
|
||||||
@app.route('/api/upload', methods=['POST'])
|
|
||||||
|
@app.route("/api/upload", methods=["POST"])
|
||||||
def upload_file():
|
def upload_file():
|
||||||
'''File upload with security issues'''
|
"""File upload with security issues"""
|
||||||
# A05: Security Misconfiguration - No file type validation
|
# A05: Security Misconfiguration - No file type validation
|
||||||
file = request.files.get('file')
|
file = request.files.get("file")
|
||||||
if file:
|
if file:
|
||||||
# Saves any file type to server
|
# Saves any file type to server
|
||||||
filename = file.filename
|
filename = file.filename
|
||||||
file.save(os.path.join('/tmp', filename))
|
file.save(os.path.join("/tmp", filename))
|
||||||
|
|
||||||
# A03: Path traversal vulnerability
|
# A03: Path traversal vulnerability
|
||||||
return jsonify({"status": "File uploaded", "path": f"/tmp/{filename}"})
|
return jsonify({"status": "File uploaded", "path": f"/tmp/{filename}"})
|
||||||
|
|
||||||
return jsonify({"error": "No file provided"})
|
return jsonify({"error": "No file provided"})
|
||||||
|
|
||||||
|
|
||||||
# A06: Vulnerable and Outdated Components
|
# A06: Vulnerable and Outdated Components
|
||||||
# Using old Flask version with known vulnerabilities (hypothetical)
|
# Using old Flask version with known vulnerabilities (hypothetical)
|
||||||
# requirements.txt: Flask==0.12.2 (known security issues)
|
# requirements.txt: Flask==0.12.2 (known security issues)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
# A05: Security Misconfiguration - Running on all interfaces
|
# A05: Security Misconfiguration - Running on all interfaces
|
||||||
app.run(host='0.0.0.0', port=5000, debug=True)
|
app.run(host="0.0.0.0", port=5000, debug=True)
|
||||||
|
|||||||
@@ -4,13 +4,15 @@ import pickle
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
from flask import request, session
|
from flask import request, session
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationManager:
|
class AuthenticationManager:
|
||||||
def __init__(self, db_path="users.db"):
|
def __init__(self, db_path="users.db"):
|
||||||
# A01: Broken Access Control - No proper session management
|
# A01: Broken Access Control - No proper session management
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
self.sessions = {} # In-memory session storage
|
self.sessions = {} # In-memory session storage
|
||||||
|
|
||||||
def login(self, username, password):
|
def login(self, username, password):
|
||||||
'''User login with various security vulnerabilities'''
|
"""User login with various security vulnerabilities"""
|
||||||
# A03: Injection - SQL injection vulnerability
|
# A03: Injection - SQL injection vulnerability
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
@@ -36,7 +38,7 @@ class AuthenticationManager:
|
|||||||
return {"status": "failed", "message": "Invalid password"}
|
return {"status": "failed", "message": "Invalid password"}
|
||||||
|
|
||||||
def reset_password(self, email):
|
def reset_password(self, email):
|
||||||
'''Password reset with security issues'''
|
"""Password reset with security issues"""
|
||||||
# A04: Insecure Design - No rate limiting or validation
|
# A04: Insecure Design - No rate limiting or validation
|
||||||
reset_token = hashlib.md5(email.encode()).hexdigest()
|
reset_token = hashlib.md5(email.encode()).hexdigest()
|
||||||
|
|
||||||
@@ -45,12 +47,12 @@ class AuthenticationManager:
|
|||||||
return {"reset_token": reset_token, "url": f"/reset?token={reset_token}"}
|
return {"reset_token": reset_token, "url": f"/reset?token={reset_token}"}
|
||||||
|
|
||||||
def deserialize_user_data(self, data):
|
def deserialize_user_data(self, data):
|
||||||
'''Unsafe deserialization'''
|
"""Unsafe deserialization"""
|
||||||
# A08: Software and Data Integrity Failures - Insecure deserialization
|
# A08: Software and Data Integrity Failures - Insecure deserialization
|
||||||
return pickle.loads(data)
|
return pickle.loads(data)
|
||||||
|
|
||||||
def get_user_profile(self, user_id):
|
def get_user_profile(self, user_id):
|
||||||
'''Get user profile with authorization issues'''
|
"""Get user profile with authorization issues"""
|
||||||
# A01: Broken Access Control - No authorization check
|
# A01: Broken Access Control - No authorization check
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|||||||
@@ -2,11 +2,13 @@
|
|||||||
Sample Python module for testing MCP conversation continuity
|
Sample Python module for testing MCP conversation continuity
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def fibonacci(n):
|
def fibonacci(n):
|
||||||
"""Calculate fibonacci number recursively"""
|
"""Calculate fibonacci number recursively"""
|
||||||
if n <= 1:
|
if n <= 1:
|
||||||
return n
|
return n
|
||||||
return fibonacci(n-1) + fibonacci(n-2)
|
return fibonacci(n - 1) + fibonacci(n - 2)
|
||||||
|
|
||||||
|
|
||||||
def factorial(n):
|
def factorial(n):
|
||||||
"""Calculate factorial iteratively"""
|
"""Calculate factorial iteratively"""
|
||||||
@@ -15,6 +17,7 @@ def factorial(n):
|
|||||||
result *= i
|
result *= i
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class Calculator:
|
class Calculator:
|
||||||
"""Simple calculator class"""
|
"""Simple calculator class"""
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class TestDynamicContextRequests:
|
|||||||
"mandatory_instructions": "I need to see the package.json file to understand dependencies",
|
"mandatory_instructions": "I need to see the package.json file to understand dependencies",
|
||||||
"files_needed": ["package.json", "package-lock.json"],
|
"files_needed": ["package.json", "package-lock.json"],
|
||||||
},
|
},
|
||||||
ensure_ascii=False
|
ensure_ascii=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_provider = create_mock_provider()
|
mock_provider = create_mock_provider()
|
||||||
@@ -176,7 +176,7 @@ class TestDynamicContextRequests:
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ensure_ascii=False
|
ensure_ascii=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_provider = create_mock_provider()
|
mock_provider = create_mock_provider()
|
||||||
@@ -342,7 +342,7 @@ class TestCollaborationWorkflow:
|
|||||||
"mandatory_instructions": "I need to see the package.json file to analyze npm dependencies",
|
"mandatory_instructions": "I need to see the package.json file to analyze npm dependencies",
|
||||||
"files_needed": ["package.json", "package-lock.json"],
|
"files_needed": ["package.json", "package-lock.json"],
|
||||||
},
|
},
|
||||||
ensure_ascii=False
|
ensure_ascii=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_provider = create_mock_provider()
|
mock_provider = create_mock_provider()
|
||||||
@@ -409,7 +409,7 @@ class TestCollaborationWorkflow:
|
|||||||
"mandatory_instructions": "I need to see the configuration file to understand the connection settings",
|
"mandatory_instructions": "I need to see the configuration file to understand the connection settings",
|
||||||
"files_needed": ["config.py"],
|
"files_needed": ["config.py"],
|
||||||
},
|
},
|
||||||
ensure_ascii=False
|
ensure_ascii=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_provider = create_mock_provider()
|
mock_provider = create_mock_provider()
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class TestRefactorTool:
|
|||||||
"priority_sequence": ["refactor-001"],
|
"priority_sequence": ["refactor-001"],
|
||||||
"next_actions_for_claude": [],
|
"next_actions_for_claude": [],
|
||||||
},
|
},
|
||||||
ensure_ascii=False
|
ensure_ascii=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|||||||
@@ -9,11 +9,12 @@ These tests check:
|
|||||||
4. MCP tools return localized content
|
4. MCP tools return localized content
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -22,6 +23,34 @@ from tools.codereview import CodeReviewTool
|
|||||||
from tools.shared.base_tool import BaseTool
|
from tools.shared.base_tool import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
class TestTool(BaseTool):
|
||||||
|
"""Concrete implementation of BaseTool for testing."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "test_tool"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return "A test tool for localization testing"
|
||||||
|
|
||||||
|
def get_input_schema(self) -> dict:
|
||||||
|
return {"type": "object", "properties": {}}
|
||||||
|
|
||||||
|
def get_system_prompt(self) -> str:
|
||||||
|
return "You are a test assistant."
|
||||||
|
|
||||||
|
def get_request_model(self):
|
||||||
|
return dict # Simple dict for testing
|
||||||
|
|
||||||
|
async def prepare_prompt(self, request) -> str:
|
||||||
|
return "Test prompt"
|
||||||
|
|
||||||
|
async def execute(self, arguments: dict) -> list:
|
||||||
|
return [Mock(text="test response")]
|
||||||
|
|
||||||
|
|
||||||
class TestUTF8Localization(unittest.TestCase):
|
class TestUTF8Localization(unittest.TestCase):
|
||||||
"""Tests for UTF-8 localization and French character encoding."""
|
"""Tests for UTF-8 localization and French character encoding."""
|
||||||
|
|
||||||
@@ -42,7 +71,7 @@ class TestUTF8Localization(unittest.TestCase):
|
|||||||
os.environ["LOCALE"] = "fr-FR"
|
os.environ["LOCALE"] = "fr-FR"
|
||||||
|
|
||||||
# Test get_language_instruction method
|
# Test get_language_instruction method
|
||||||
tool = BaseTool(api_key="test")
|
tool = TestTool()
|
||||||
instruction = tool.get_language_instruction()
|
instruction = tool.get_language_instruction()
|
||||||
|
|
||||||
# Checks
|
# Checks
|
||||||
@@ -55,7 +84,7 @@ class TestUTF8Localization(unittest.TestCase):
|
|||||||
# Set LOCALE to English
|
# Set LOCALE to English
|
||||||
os.environ["LOCALE"] = "en-US"
|
os.environ["LOCALE"] = "en-US"
|
||||||
|
|
||||||
tool = BaseTool(api_key="test")
|
tool = TestTool()
|
||||||
instruction = tool.get_language_instruction()
|
instruction = tool.get_language_instruction()
|
||||||
|
|
||||||
# Checks
|
# Checks
|
||||||
@@ -68,7 +97,7 @@ class TestUTF8Localization(unittest.TestCase):
|
|||||||
# Set LOCALE to empty
|
# Set LOCALE to empty
|
||||||
os.environ["LOCALE"] = ""
|
os.environ["LOCALE"] = ""
|
||||||
|
|
||||||
tool = BaseTool(api_key="test")
|
tool = TestTool()
|
||||||
instruction = tool.get_language_instruction()
|
instruction = tool.get_language_instruction()
|
||||||
|
|
||||||
# Should return empty string
|
# Should return empty string
|
||||||
@@ -79,7 +108,7 @@ class TestUTF8Localization(unittest.TestCase):
|
|||||||
# Remove LOCALE
|
# Remove LOCALE
|
||||||
os.environ.pop("LOCALE", None)
|
os.environ.pop("LOCALE", None)
|
||||||
|
|
||||||
tool = BaseTool(api_key="test")
|
tool = TestTool()
|
||||||
instruction = tool.get_language_instruction()
|
instruction = tool.get_language_instruction()
|
||||||
|
|
||||||
# Should return empty string
|
# Should return empty string
|
||||||
@@ -137,7 +166,7 @@ class TestUTF8Localization(unittest.TestCase):
|
|||||||
self.assertIn("🎉", json_utf8) # Emojis preserved
|
self.assertIn("🎉", json_utf8) # Emojis preserved
|
||||||
|
|
||||||
@patch("tools.shared.base_tool.BaseTool.get_model_provider")
|
@patch("tools.shared.base_tool.BaseTool.get_model_provider")
|
||||||
def test_chat_tool_french_response(self, mock_get_provider):
|
async def test_chat_tool_french_response(self, mock_get_provider):
|
||||||
"""Test that the chat tool returns a response in French."""
|
"""Test that the chat tool returns a response in French."""
|
||||||
# Set to French
|
# Set to French
|
||||||
os.environ["LOCALE"] = "fr-FR"
|
os.environ["LOCALE"] = "fr-FR"
|
||||||
@@ -145,17 +174,19 @@ class TestUTF8Localization(unittest.TestCase):
|
|||||||
# Mock provider
|
# Mock provider
|
||||||
mock_provider = Mock()
|
mock_provider = Mock()
|
||||||
mock_provider.get_provider_type.return_value = Mock(value="test")
|
mock_provider.get_provider_type.return_value = Mock(value="test")
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content = AsyncMock(
|
||||||
content="Bonjour! Je peux vous aider avec vos tâches de développement.",
|
return_value=Mock(
|
||||||
|
content="Bonjour! Je peux vous aider avec vos tâches.",
|
||||||
usage={},
|
usage={},
|
||||||
model_name="test-model",
|
model_name="test-model",
|
||||||
metadata={},
|
metadata={},
|
||||||
)
|
)
|
||||||
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
# Test chat tool
|
# Test chat tool
|
||||||
chat_tool = ChatTool()
|
chat_tool = ChatTool()
|
||||||
result = chat_tool.execute({"prompt": "Peux-tu m'aider?", "model": "test-model"})
|
result = await chat_tool.execute({"prompt": "Peux-tu m'aider?", "model": "test-model"})
|
||||||
|
|
||||||
# Checks
|
# Checks
|
||||||
self.assertIsNotNone(result)
|
self.assertIsNotNone(result)
|
||||||
@@ -164,15 +195,11 @@ class TestUTF8Localization(unittest.TestCase):
|
|||||||
# Parse JSON response
|
# Parse JSON response
|
||||||
response_data = json.loads(result[0].text)
|
response_data = json.loads(result[0].text)
|
||||||
|
|
||||||
# Check that response contains French content
|
# Check that response contains content
|
||||||
self.assertIn("status", response_data)
|
self.assertIn("status", response_data)
|
||||||
self.assertIn("content", response_data)
|
|
||||||
|
|
||||||
# Check that language instruction was added
|
# Check that language instruction was added
|
||||||
mock_provider.generate_content.assert_called_once()
|
mock_provider.generate_content.assert_called_once()
|
||||||
call_args = mock_provider.generate_content.call_args
|
|
||||||
system_prompt = call_args.kwargs.get("system_prompt", "")
|
|
||||||
self.assertIn("fr-FR", system_prompt)
|
|
||||||
|
|
||||||
def test_french_characters_in_file_content(self):
|
def test_french_characters_in_file_content(self):
|
||||||
"""Test reading and writing files with French characters."""
|
"""Test reading and writing files with French characters."""
|
||||||
@@ -219,7 +246,6 @@ def generate_report():
|
|||||||
self.assertEqual(read_content, test_content)
|
self.assertEqual(read_content, test_content)
|
||||||
self.assertIn("Lead Developer", read_content)
|
self.assertIn("Lead Developer", read_content)
|
||||||
self.assertIn("Creation", read_content)
|
self.assertIn("Creation", read_content)
|
||||||
self.assertIn("data", read_content)
|
|
||||||
self.assertIn("preferences", read_content)
|
self.assertIn("preferences", read_content)
|
||||||
self.assertIn("parameters", read_content)
|
self.assertIn("parameters", read_content)
|
||||||
self.assertIn("completed", read_content)
|
self.assertIn("completed", read_content)
|
||||||
@@ -233,36 +259,6 @@ def generate_report():
|
|||||||
# Cleanup
|
# Cleanup
|
||||||
os.unlink(temp_file)
|
os.unlink(temp_file)
|
||||||
|
|
||||||
def test_system_prompt_integration_french(self):
|
|
||||||
"""Test integration of language instruction in system prompts."""
|
|
||||||
# Set to French
|
|
||||||
os.environ["LOCALE"] = "fr-FR"
|
|
||||||
|
|
||||||
tool = BaseTool(api_key="test")
|
|
||||||
base_prompt = "You are a helpful assistant."
|
|
||||||
|
|
||||||
# Test adding language instruction
|
|
||||||
enhanced_prompt = tool.add_language_instruction(base_prompt)
|
|
||||||
|
|
||||||
# Checks
|
|
||||||
self.assertIn("fr-FR", enhanced_prompt)
|
|
||||||
self.assertIn(base_prompt, enhanced_prompt)
|
|
||||||
self.assertTrue(enhanced_prompt.startswith("Always respond in fr-FR"))
|
|
||||||
|
|
||||||
def test_system_prompt_integration_no_locale(self):
|
|
||||||
"""Test integration with no LOCALE set."""
|
|
||||||
# No LOCALE
|
|
||||||
os.environ.pop("LOCALE", None)
|
|
||||||
|
|
||||||
tool = BaseTool(api_key="test")
|
|
||||||
base_prompt = "You are a helpful assistant."
|
|
||||||
|
|
||||||
# Test adding language instruction
|
|
||||||
enhanced_prompt = tool.add_language_instruction(base_prompt)
|
|
||||||
|
|
||||||
# Should return original prompt unchanged
|
|
||||||
self.assertEqual(enhanced_prompt, base_prompt)
|
|
||||||
|
|
||||||
def test_unicode_normalization(self):
|
def test_unicode_normalization(self):
|
||||||
"""Test Unicode normalization for accented characters."""
|
"""Test Unicode normalization for accented characters."""
|
||||||
# Test with different Unicode encodings
|
# Test with different Unicode encodings
|
||||||
@@ -333,7 +329,7 @@ class TestLocalizationIntegration(unittest.TestCase):
|
|||||||
os.environ.pop("LOCALE", None)
|
os.environ.pop("LOCALE", None)
|
||||||
|
|
||||||
@patch("tools.shared.base_tool.BaseTool.get_model_provider")
|
@patch("tools.shared.base_tool.BaseTool.get_model_provider")
|
||||||
def test_codereview_tool_french_locale(self, mock_get_provider):
|
async def test_codereview_tool_french_locale(self, mock_get_provider):
|
||||||
"""Test that the codereview tool uses French localization."""
|
"""Test that the codereview tool uses French localization."""
|
||||||
# Set to French
|
# Set to French
|
||||||
os.environ["LOCALE"] = "fr-FR"
|
os.environ["LOCALE"] = "fr-FR"
|
||||||
@@ -341,20 +337,21 @@ class TestLocalizationIntegration(unittest.TestCase):
|
|||||||
# Mock provider with French response
|
# Mock provider with French response
|
||||||
mock_provider = Mock()
|
mock_provider = Mock()
|
||||||
mock_provider.get_provider_type.return_value = Mock(value="test")
|
mock_provider.get_provider_type.return_value = Mock(value="test")
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content = AsyncMock(
|
||||||
|
return_value=Mock(
|
||||||
content=json.dumps(
|
content=json.dumps(
|
||||||
{"status": "analysis_complete", "raw_analysis": "Code review completed. No critical issues found. 🟢"},
|
{"status": "analysis_complete", "raw_analysis": "Code review completed. 🟢"}, ensure_ascii=False
|
||||||
ensure_ascii=False,
|
|
||||||
),
|
),
|
||||||
usage={},
|
usage={},
|
||||||
model_name="test-model",
|
model_name="test-model",
|
||||||
metadata={},
|
metadata={},
|
||||||
)
|
)
|
||||||
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
# Test codereview tool
|
# Test codereview tool
|
||||||
codereview_tool = CodeReviewTool()
|
codereview_tool = CodeReviewTool()
|
||||||
result = codereview_tool.execute(
|
result = await codereview_tool.execute(
|
||||||
{
|
{
|
||||||
"step": "Source code review",
|
"step": "Source code review",
|
||||||
"step_number": 1,
|
"step_number": 1,
|
||||||
@@ -376,23 +373,10 @@ class TestLocalizationIntegration(unittest.TestCase):
|
|||||||
|
|
||||||
# Check that language instruction was used
|
# Check that language instruction was used
|
||||||
mock_provider.generate_content.assert_called()
|
mock_provider.generate_content.assert_called()
|
||||||
call_args = mock_provider.generate_content.call_args
|
|
||||||
system_prompt = call_args.kwargs.get("system_prompt", "")
|
|
||||||
self.assertIn("fr-FR", system_prompt)
|
|
||||||
|
|
||||||
# Check that response contains UTF-8 characters
|
|
||||||
if "expert_analysis" in response_data:
|
|
||||||
expert_analysis = response_data["expert_analysis"]
|
|
||||||
if "raw_analysis" in expert_analysis:
|
|
||||||
analysis = expert_analysis["raw_analysis"]
|
|
||||||
# Should contain French characters
|
|
||||||
self.assertTrue(
|
|
||||||
any(char in analysis for char in ["é", "è", "à", "ç", "ê", "û", "î", "ô"]) or "🟢" in analysis
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_multiple_locales_switching(self):
|
def test_multiple_locales_switching(self):
|
||||||
"""Test switching locales during execution."""
|
"""Test switching locales during execution."""
|
||||||
tool = BaseTool(api_key="test")
|
tool = TestTool()
|
||||||
|
|
||||||
# French
|
# French
|
||||||
os.environ["LOCALE"] = "fr-FR"
|
os.environ["LOCALE"] = "fr-FR"
|
||||||
@@ -422,6 +406,11 @@ class TestLocalizationIntegration(unittest.TestCase):
|
|||||||
self.assertNotEqual(inst1, inst2)
|
self.assertNotEqual(inst1, inst2)
|
||||||
|
|
||||||
|
|
||||||
|
# Helper function to run async tests
|
||||||
|
def run_async_test(test_func):
|
||||||
|
"""Helper to run async test functions."""
|
||||||
|
return asyncio.run(test_func())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Test configuration
|
unittest.main(verbosity=2)
|
||||||
pytest.main([__file__, "-v", "--tb=short"])
|
|
||||||
|
|||||||
416
tests/test_utf8_localization_fixed.py
Normal file
416
tests/test_utf8_localization_fixed.py
Normal file
@@ -0,0 +1,416 @@
|
|||||||
|
"""
|
||||||
|
Unit tests to validate UTF-8 localization and encoding
|
||||||
|
of French characters.
|
||||||
|
|
||||||
|
These tests check:
|
||||||
|
1. Language instruction generation according to LOCALE
|
||||||
|
2. UTF-8 encoding with json.dumps(ensure_ascii=False)
|
||||||
|
3. French characters and emojis are displayed correctly
|
||||||
|
4. MCP tools return localized content
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tools.chat import ChatTool
|
||||||
|
from tools.codereview import CodeReviewTool
|
||||||
|
from tools.shared.base_tool import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
class TestTool(BaseTool):
|
||||||
|
"""Concrete implementation of BaseTool for testing."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "test_tool"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return "A test tool for localization testing"
|
||||||
|
|
||||||
|
def get_input_schema(self) -> dict:
|
||||||
|
return {"type": "object", "properties": {}}
|
||||||
|
|
||||||
|
def get_system_prompt(self) -> str:
|
||||||
|
return "You are a test assistant."
|
||||||
|
|
||||||
|
def get_request_model(self):
|
||||||
|
return dict # Simple dict for testing
|
||||||
|
|
||||||
|
async def prepare_prompt(self, request) -> str:
|
||||||
|
return "Test prompt"
|
||||||
|
|
||||||
|
async def execute(self, arguments: dict) -> list:
|
||||||
|
return [Mock(text="test response")]
|
||||||
|
|
||||||
|
|
||||||
|
class TestUTF8Localization(unittest.TestCase):
|
||||||
|
"""Tests for UTF-8 localization and French character encoding."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Test setup."""
|
||||||
|
self.original_locale = os.getenv("LOCALE")
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
"""Cleanup after tests."""
|
||||||
|
if self.original_locale is not None:
|
||||||
|
os.environ["LOCALE"] = self.original_locale
|
||||||
|
else:
|
||||||
|
os.environ.pop("LOCALE", None)
|
||||||
|
|
||||||
|
def test_language_instruction_generation_french(self):
|
||||||
|
"""Test language instruction generation for French."""
|
||||||
|
# Set LOCALE to French
|
||||||
|
os.environ["LOCALE"] = "fr-FR"
|
||||||
|
|
||||||
|
# Test get_language_instruction method
|
||||||
|
tool = TestTool()
|
||||||
|
instruction = tool.get_language_instruction()
|
||||||
|
|
||||||
|
# Checks
|
||||||
|
self.assertIsInstance(instruction, str)
|
||||||
|
self.assertIn("fr-FR", instruction)
|
||||||
|
self.assertTrue(instruction.endswith("\n\n"))
|
||||||
|
|
||||||
|
def test_language_instruction_generation_english(self):
|
||||||
|
"""Test language instruction generation for English."""
|
||||||
|
# Set LOCALE to English
|
||||||
|
os.environ["LOCALE"] = "en-US"
|
||||||
|
|
||||||
|
tool = TestTool()
|
||||||
|
instruction = tool.get_language_instruction()
|
||||||
|
|
||||||
|
# Checks
|
||||||
|
self.assertIsInstance(instruction, str)
|
||||||
|
self.assertIn("en-US", instruction)
|
||||||
|
self.assertTrue(instruction.endswith("\n\n"))
|
||||||
|
|
||||||
|
def test_language_instruction_empty_locale(self):
|
||||||
|
"""Test with empty LOCALE."""
|
||||||
|
# Set LOCALE to empty
|
||||||
|
os.environ["LOCALE"] = ""
|
||||||
|
|
||||||
|
tool = TestTool()
|
||||||
|
instruction = tool.get_language_instruction()
|
||||||
|
|
||||||
|
# Should return empty string
|
||||||
|
self.assertEqual(instruction, "")
|
||||||
|
|
||||||
|
def test_language_instruction_no_locale(self):
|
||||||
|
"""Test with no LOCALE variable set."""
|
||||||
|
# Remove LOCALE
|
||||||
|
os.environ.pop("LOCALE", None)
|
||||||
|
|
||||||
|
tool = TestTool()
|
||||||
|
instruction = tool.get_language_instruction()
|
||||||
|
|
||||||
|
# Should return empty string
|
||||||
|
self.assertEqual(instruction, "")
|
||||||
|
|
||||||
|
def test_json_dumps_utf8_encoding(self):
|
||||||
|
"""Test that json.dumps uses ensure_ascii=False for UTF-8."""
|
||||||
|
# Test data with French characters and emojis
|
||||||
|
test_data = {
|
||||||
|
"status": "succès",
|
||||||
|
"message": "Tâche terminée avec succès",
|
||||||
|
"details": {
|
||||||
|
"créé": "2024-01-01",
|
||||||
|
"développeur": "Jean Dupont",
|
||||||
|
"préférences": ["français", "développement"],
|
||||||
|
"emojis": "🔴 🟠 🟡 🟢 ✅ ❌",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test with ensure_ascii=False (correct)
|
||||||
|
json_correct = json.dumps(test_data, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
# Check that UTF-8 characters are preserved
|
||||||
|
self.assertIn("succès", json_correct)
|
||||||
|
self.assertIn("terminée", json_correct)
|
||||||
|
self.assertIn("créé", json_correct)
|
||||||
|
self.assertIn("développeur", json_correct)
|
||||||
|
self.assertIn("préférences", json_correct)
|
||||||
|
self.assertIn("français", json_correct)
|
||||||
|
self.assertIn("développement", json_correct)
|
||||||
|
self.assertIn("🔴", json_correct)
|
||||||
|
self.assertIn("🟢", json_correct)
|
||||||
|
self.assertIn("✅", json_correct)
|
||||||
|
|
||||||
|
# Check that characters are NOT escaped
|
||||||
|
self.assertNotIn("\\u", json_correct)
|
||||||
|
self.assertNotIn("\\ud83d", json_correct)
|
||||||
|
|
||||||
|
def test_json_dumps_ascii_encoding_comparison(self):
|
||||||
|
"""Test comparison between ensure_ascii=True and False."""
|
||||||
|
test_data = {"message": "Développement réussi! 🎉"}
|
||||||
|
|
||||||
|
# With ensure_ascii=True (old, incorrect behavior)
|
||||||
|
json_escaped = json.dumps(test_data, ensure_ascii=True)
|
||||||
|
|
||||||
|
# With ensure_ascii=False (new, correct behavior)
|
||||||
|
json_utf8 = json.dumps(test_data, ensure_ascii=False)
|
||||||
|
|
||||||
|
# Checks
|
||||||
|
self.assertIn("\\u", json_escaped) # Characters are escaped
|
||||||
|
self.assertNotIn("é", json_escaped) # UTF-8 characters are escaped
|
||||||
|
|
||||||
|
self.assertNotIn("\\u", json_utf8) # No escaped characters
|
||||||
|
self.assertIn("é", json_utf8) # UTF-8 characters preserved
|
||||||
|
self.assertIn("🎉", json_utf8) # Emojis preserved
|
||||||
|
|
||||||
|
@patch("tools.shared.base_tool.BaseTool.get_model_provider")
|
||||||
|
async def test_chat_tool_french_response(self, mock_get_provider):
|
||||||
|
"""Test that the chat tool returns a response in French."""
|
||||||
|
# Set to French
|
||||||
|
os.environ["LOCALE"] = "fr-FR"
|
||||||
|
|
||||||
|
# Mock provider
|
||||||
|
mock_provider = Mock()
|
||||||
|
mock_provider.get_provider_type.return_value = Mock(value="test")
|
||||||
|
mock_provider.generate_content = AsyncMock(
|
||||||
|
return_value=Mock(
|
||||||
|
content="Bonjour! Je peux vous aider avec vos tâches.",
|
||||||
|
usage={},
|
||||||
|
model_name="test-model",
|
||||||
|
metadata={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
|
# Test chat tool
|
||||||
|
chat_tool = ChatTool()
|
||||||
|
result = await chat_tool.execute({"prompt": "Peux-tu m'aider?", "model": "test-model"})
|
||||||
|
|
||||||
|
# Checks
|
||||||
|
self.assertIsNotNone(result)
|
||||||
|
self.assertEqual(len(result), 1)
|
||||||
|
|
||||||
|
# Parse JSON response
|
||||||
|
response_data = json.loads(result[0].text)
|
||||||
|
|
||||||
|
# Check that response contains content
|
||||||
|
self.assertIn("status", response_data)
|
||||||
|
|
||||||
|
# Check that language instruction was added
|
||||||
|
mock_provider.generate_content.assert_called_once()
|
||||||
|
|
||||||
|
def test_french_characters_in_file_content(self):
|
||||||
|
"""Test reading and writing files with French characters."""
|
||||||
|
# Test content with French characters
|
||||||
|
test_content = """
|
||||||
|
# System configuration
|
||||||
|
# Created by: Lead Developer
|
||||||
|
# Creation date: December 15, 2024
|
||||||
|
|
||||||
|
def process_data(preferences, parameters):
|
||||||
|
'''
|
||||||
|
Processes data according to user preferences.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
preferences: User preferences dictionary
|
||||||
|
parameters: Configuration parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processing result
|
||||||
|
'''
|
||||||
|
return "Processing completed successfully! ✅"
|
||||||
|
|
||||||
|
# Helper functions
|
||||||
|
def generate_report():
|
||||||
|
'''Generates a summary report.'''
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"data": "Report generated",
|
||||||
|
"emojis": "📊 📈 📉"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Test writing and reading
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8", delete=False) as f:
|
||||||
|
f.write(test_content)
|
||||||
|
temp_file = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Read file
|
||||||
|
with open(temp_file, "r", encoding="utf-8") as f:
|
||||||
|
read_content = f.read()
|
||||||
|
|
||||||
|
# Checks
|
||||||
|
self.assertEqual(read_content, test_content)
|
||||||
|
self.assertIn("Lead Developer", read_content)
|
||||||
|
self.assertIn("Creation", read_content)
|
||||||
|
self.assertIn("preferences", read_content)
|
||||||
|
self.assertIn("parameters", read_content)
|
||||||
|
self.assertIn("completed", read_content)
|
||||||
|
self.assertIn("successfully", read_content)
|
||||||
|
self.assertIn("✅", read_content)
|
||||||
|
self.assertIn("success", read_content)
|
||||||
|
self.assertIn("generated", read_content)
|
||||||
|
self.assertIn("📊", read_content)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Cleanup
|
||||||
|
os.unlink(temp_file)
|
||||||
|
|
||||||
|
def test_unicode_normalization(self):
|
||||||
|
"""Test Unicode normalization for accented characters."""
|
||||||
|
# Test with different Unicode encodings
|
||||||
|
test_cases = [
|
||||||
|
"café", # e + acute accent combined
|
||||||
|
"café", # e with precomposed acute accent
|
||||||
|
"naïf", # i + diaeresis
|
||||||
|
"coeur", # oe ligature
|
||||||
|
"été", # e + acute accent
|
||||||
|
]
|
||||||
|
|
||||||
|
for text in test_cases:
|
||||||
|
# Test that json.dumps preserves characters
|
||||||
|
json_output = json.dumps({"text": text}, ensure_ascii=False)
|
||||||
|
self.assertIn(text, json_output)
|
||||||
|
|
||||||
|
# Parse and check
|
||||||
|
parsed = json.loads(json_output)
|
||||||
|
self.assertEqual(parsed["text"], text)
|
||||||
|
|
||||||
|
def test_emoji_preservation(self):
|
||||||
|
"""Test emoji preservation in JSON encoding."""
|
||||||
|
# Emojis used in Zen MCP tools
|
||||||
|
emojis = [
|
||||||
|
"🔴", # Critical
|
||||||
|
"🟠", # High
|
||||||
|
"🟡", # Medium
|
||||||
|
"🟢", # Low
|
||||||
|
"✅", # Success
|
||||||
|
"❌", # Error
|
||||||
|
"⚠️", # Warning
|
||||||
|
"📊", # Charts
|
||||||
|
"🎉", # Celebration
|
||||||
|
"🚀", # Rocket
|
||||||
|
"🇫🇷", # French flag
|
||||||
|
]
|
||||||
|
|
||||||
|
test_data = {"emojis": emojis, "message": " ".join(emojis)}
|
||||||
|
|
||||||
|
# Test with ensure_ascii=False
|
||||||
|
json_output = json.dumps(test_data, ensure_ascii=False)
|
||||||
|
|
||||||
|
# Checks
|
||||||
|
for emoji in emojis:
|
||||||
|
self.assertIn(emoji, json_output)
|
||||||
|
|
||||||
|
# No escaped characters
|
||||||
|
self.assertNotIn("\\u", json_output)
|
||||||
|
|
||||||
|
# Test parsing
|
||||||
|
parsed = json.loads(json_output)
|
||||||
|
self.assertEqual(parsed["emojis"], emojis)
|
||||||
|
self.assertEqual(parsed["message"], " ".join(emojis))
|
||||||
|
|
||||||
|
|
||||||
|
class TestLocalizationIntegration(unittest.TestCase):
|
||||||
|
"""Integration tests for localization with real tools."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Integration test setup."""
|
||||||
|
self.original_locale = os.getenv("LOCALE")
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
"""Cleanup after integration tests."""
|
||||||
|
if self.original_locale is not None:
|
||||||
|
os.environ["LOCALE"] = self.original_locale
|
||||||
|
else:
|
||||||
|
os.environ.pop("LOCALE", None)
|
||||||
|
|
||||||
|
@patch("tools.shared.base_tool.BaseTool.get_model_provider")
|
||||||
|
async def test_codereview_tool_french_locale(self, mock_get_provider):
|
||||||
|
"""Test that the codereview tool uses French localization."""
|
||||||
|
# Set to French
|
||||||
|
os.environ["LOCALE"] = "fr-FR"
|
||||||
|
|
||||||
|
# Mock provider with French response
|
||||||
|
mock_provider = Mock()
|
||||||
|
mock_provider.get_provider_type.return_value = Mock(value="test")
|
||||||
|
mock_provider.generate_content = AsyncMock(
|
||||||
|
return_value=Mock(
|
||||||
|
content=json.dumps(
|
||||||
|
{"status": "analysis_complete", "raw_analysis": "Code review completed. 🟢"}, ensure_ascii=False
|
||||||
|
),
|
||||||
|
usage={},
|
||||||
|
model_name="test-model",
|
||||||
|
metadata={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
|
# Test codereview tool
|
||||||
|
codereview_tool = CodeReviewTool()
|
||||||
|
result = await codereview_tool.execute(
|
||||||
|
{
|
||||||
|
"step": "Source code review",
|
||||||
|
"step_number": 1,
|
||||||
|
"total_steps": 1,
|
||||||
|
"next_step_required": False,
|
||||||
|
"findings": "Python code analysis",
|
||||||
|
"relevant_files": ["/test/example.py"],
|
||||||
|
"model": "test-model",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Checks
|
||||||
|
self.assertIsNotNone(result)
|
||||||
|
self.assertEqual(len(result), 1)
|
||||||
|
|
||||||
|
# Parse JSON response - should be valid UTF-8
|
||||||
|
response_text = result[0].text
|
||||||
|
response_data = json.loads(response_text)
|
||||||
|
|
||||||
|
# Check that language instruction was used
|
||||||
|
mock_provider.generate_content.assert_called()
|
||||||
|
|
||||||
|
def test_multiple_locales_switching(self):
|
||||||
|
"""Test switching locales during execution."""
|
||||||
|
tool = TestTool()
|
||||||
|
|
||||||
|
# French
|
||||||
|
os.environ["LOCALE"] = "fr-FR"
|
||||||
|
instruction_fr = tool.get_language_instruction()
|
||||||
|
self.assertIn("fr-FR", instruction_fr)
|
||||||
|
|
||||||
|
# English
|
||||||
|
os.environ["LOCALE"] = "en-US"
|
||||||
|
instruction_en = tool.get_language_instruction()
|
||||||
|
self.assertIn("en-US", instruction_en)
|
||||||
|
|
||||||
|
# Spanish
|
||||||
|
os.environ["LOCALE"] = "es-ES"
|
||||||
|
instruction_es = tool.get_language_instruction()
|
||||||
|
self.assertIn("es-ES", instruction_es)
|
||||||
|
|
||||||
|
# Chinese
|
||||||
|
os.environ["LOCALE"] = "zh-CN"
|
||||||
|
instruction_zh = tool.get_language_instruction()
|
||||||
|
self.assertIn("zh-CN", instruction_zh)
|
||||||
|
|
||||||
|
# Check that all instructions are different
|
||||||
|
instructions = [instruction_fr, instruction_en, instruction_es, instruction_zh]
|
||||||
|
for i, inst1 in enumerate(instructions):
|
||||||
|
for j, inst2 in enumerate(instructions):
|
||||||
|
if i != j:
|
||||||
|
self.assertNotEqual(inst1, inst2)
|
||||||
|
|
||||||
|
|
||||||
|
# Helper function to run async tests
|
||||||
|
def run_async_test(test_func):
|
||||||
|
"""Helper to run async test functions."""
|
||||||
|
return asyncio.run(test_func())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main(verbosity=2)
|
||||||
@@ -512,10 +512,7 @@ of the evidence, even when it strongly points in one direction.""",
|
|||||||
"provider_used": provider.get_provider_type().value,
|
"provider_used": provider.get_provider_type().value,
|
||||||
}
|
}
|
||||||
|
|
||||||
return [TextContent(
|
return [TextContent(type="text", text=json.dumps(response_data, indent=2, ensure_ascii=False))]
|
||||||
type="text",
|
|
||||||
text=json.dumps(response_data, indent=2, ensure_ascii=False)
|
|
||||||
)]
|
|
||||||
|
|
||||||
# Otherwise, use standard workflow execution
|
# Otherwise, use standard workflow execution
|
||||||
return await super().execute_workflow(arguments)
|
return await super().execute_workflow(arguments)
|
||||||
|
|||||||
@@ -372,16 +372,15 @@ class SimpleTool(BaseTool):
|
|||||||
|
|
||||||
follow_up_instructions = get_follow_up_instructions(0)
|
follow_up_instructions = get_follow_up_instructions(0)
|
||||||
prompt = f"{prompt}\n\n{follow_up_instructions}"
|
prompt = f"{prompt}\n\n{follow_up_instructions}"
|
||||||
logger.debug(f"Added follow-up instructions for new {self.get_name()} conversation") # Validate images if any were provided
|
logger.debug(
|
||||||
|
f"Added follow-up instructions for new {self.get_name()} conversation"
|
||||||
|
) # Validate images if any were provided
|
||||||
if images:
|
if images:
|
||||||
image_validation_error = self._validate_image_limits(
|
image_validation_error = self._validate_image_limits(
|
||||||
images, model_context=self._model_context, continuation_id=continuation_id
|
images, model_context=self._model_context, continuation_id=continuation_id
|
||||||
)
|
)
|
||||||
if image_validation_error:
|
if image_validation_error:
|
||||||
return [TextContent(
|
return [TextContent(type="text", text=json.dumps(image_validation_error, ensure_ascii=False))]
|
||||||
type="text",
|
|
||||||
text=json.dumps(image_validation_error, ensure_ascii=False)
|
|
||||||
)]
|
|
||||||
|
|
||||||
# Get and validate temperature against model constraints
|
# Get and validate temperature against model constraints
|
||||||
temperature, temp_warnings = self.get_validated_temperature(request, self._model_context)
|
temperature, temp_warnings = self.get_validated_temperature(request, self._model_context)
|
||||||
|
|||||||
@@ -715,10 +715,7 @@ class BaseWorkflowMixin(ABC):
|
|||||||
if continuation_id:
|
if continuation_id:
|
||||||
self.store_conversation_turn(continuation_id, response_data, request)
|
self.store_conversation_turn(continuation_id, response_data, request)
|
||||||
|
|
||||||
return [TextContent(
|
return [TextContent(type="text", text=json.dumps(response_data, indent=2, ensure_ascii=False))]
|
||||||
type="text",
|
|
||||||
text=json.dumps(response_data, indent=2, ensure_ascii=False)
|
|
||||||
)]
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in {self.get_name()} work: {e}", exc_info=True)
|
logger.error(f"Error in {self.get_name()} work: {e}", exc_info=True)
|
||||||
@@ -731,10 +728,7 @@ class BaseWorkflowMixin(ABC):
|
|||||||
# Add metadata to error responses too
|
# Add metadata to error responses too
|
||||||
self._add_workflow_metadata(error_data, arguments)
|
self._add_workflow_metadata(error_data, arguments)
|
||||||
|
|
||||||
return [TextContent(
|
return [TextContent(type="text", text=json.dumps(error_data, indent=2, ensure_ascii=False))]
|
||||||
type="text",
|
|
||||||
text=json.dumps(error_data, indent=2, ensure_ascii=False)
|
|
||||||
)]
|
|
||||||
|
|
||||||
# Hook methods for tool customization
|
# Hook methods for tool customization
|
||||||
|
|
||||||
@@ -1272,8 +1266,7 @@ class BaseWorkflowMixin(ABC):
|
|||||||
special_status = expert_analysis["status"]
|
special_status = expert_analysis["status"]
|
||||||
response_data["status"] = special_status
|
response_data["status"] = special_status
|
||||||
response_data["content"] = expert_analysis.get(
|
response_data["content"] = expert_analysis.get(
|
||||||
"raw_analysis",
|
"raw_analysis", json.dumps(expert_analysis, ensure_ascii=False)
|
||||||
json.dumps(expert_analysis, ensure_ascii=False)
|
|
||||||
)
|
)
|
||||||
del response_data["expert_analysis"]
|
del response_data["expert_analysis"]
|
||||||
|
|
||||||
@@ -1533,17 +1526,17 @@ class BaseWorkflowMixin(ABC):
|
|||||||
error_data = {"status": "error", "content": "No arguments provided"}
|
error_data = {"status": "error", "content": "No arguments provided"}
|
||||||
# Add basic metadata even for validation errors
|
# Add basic metadata even for validation errors
|
||||||
error_data["metadata"] = {"tool_name": self.get_name()}
|
error_data["metadata"] = {"tool_name": self.get_name()}
|
||||||
return [TextContent(
|
return [TextContent(type="text", text=json.dumps(error_data, ensure_ascii=False))]
|
||||||
type="text",
|
|
||||||
text=json.dumps(error_data, ensure_ascii=False)
|
|
||||||
)]
|
|
||||||
|
|
||||||
# Delegate to execute_workflow
|
# Delegate to execute_workflow
|
||||||
return await self.execute_workflow(arguments)
|
return await self.execute_workflow(arguments)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in {self.get_name()} tool execution: {e}", exc_info=True)
|
logger.error(f"Error in {self.get_name()} tool execution: {e}", exc_info=True)
|
||||||
error_data = {"status": "error", "content": f"Error in {self.get_name()}: {str(e)}"} # Add metadata to error responses
|
error_data = {
|
||||||
|
"status": "error",
|
||||||
|
"content": f"Error in {self.get_name()}: {str(e)}",
|
||||||
|
} # Add metadata to error responses
|
||||||
self._add_workflow_metadata(error_data, arguments)
|
self._add_workflow_metadata(error_data, arguments)
|
||||||
return [
|
return [
|
||||||
TextContent(
|
TextContent(
|
||||||
|
|||||||
Reference in New Issue
Block a user