fix: consensus now advertises a short list of models to avoid the CLI getting the names wrong
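In practice, the schema hunk at the bottom of this commit appends a ranked-model summary to the consensus `models` field description. A client reading the schema would see something along these lines (a minimal sketch, not part of the commit; the printed string and model counts are illustrative, and the accessors mirror the new test below):

# Illustrative sketch: inspect the advertised model guidance after this change.
from tools.consensus import ConsensusTool

schema = ConsensusTool().get_input_schema()
print(schema["properties"]["models"]["description"])
# Expected shape (example values, actual summaries depend on configured providers):
# "... Use the `listmodels` tool for the full roster. Top models:
#  gemini-2.5-pro (score 100, 1.0M ctx, thinking); +4 more via `listmodels`."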
server.py | 99
@@ -1068,9 +1068,12 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
     # Create model context early to use for history building
     from utils.model_context import ModelContext
 
+    tool = TOOLS.get(context.tool_name)
+    requires_model = tool.requires_model() if tool else True
+
     # Check if we should use the model from the previous conversation turn
     model_from_args = arguments.get("model")
-    if not model_from_args and context.turns:
+    if requires_model and not model_from_args and context.turns:
         # Find the last assistant turn to get the model used
         for turn in reversed(context.turns):
             if turn.role == "assistant" and turn.model_name:
@@ -1078,7 +1081,99 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
                 logger.debug(f"[CONVERSATION_DEBUG] Using model from previous turn: {turn.model_name}")
                 break
 
-    model_context = ModelContext.from_arguments(arguments)
+    # Resolve an effective model for context reconstruction when DEFAULT_MODEL=auto
+    model_context = arguments.get("_model_context")
+
+    if requires_model:
+        if model_context is None:
+            try:
+                model_context = ModelContext.from_arguments(arguments)
+                arguments.setdefault("_resolved_model_name", model_context.model_name)
+            except ValueError as exc:
+                from providers.registry import ModelProviderRegistry
+
+                fallback_model = None
+                if tool is not None:
+                    try:
+                        fallback_model = ModelProviderRegistry.get_preferred_fallback_model(tool.get_model_category())
+                    except Exception as fallback_exc:  # pragma: no cover - defensive log
+                        logger.debug(
+                            f"[CONVERSATION_DEBUG] Unable to resolve fallback model for {context.tool_name}: {fallback_exc}"
+                        )
+
+                if fallback_model is None:
+                    available_models = ModelProviderRegistry.get_available_model_names()
+                    if available_models:
+                        fallback_model = available_models[0]
+
+                if fallback_model is None:
+                    raise
+
+                logger.debug(
+                    f"[CONVERSATION_DEBUG] Falling back to model '{fallback_model}' for context reconstruction after error: {exc}"
+                )
+                model_context = ModelContext(fallback_model)
+                arguments["_model_context"] = model_context
+                arguments["_resolved_model_name"] = fallback_model
+
+        from providers.registry import ModelProviderRegistry
+
+        provider = ModelProviderRegistry.get_provider_for_model(model_context.model_name)
+        if provider is None:
+            fallback_model = None
+            if tool is not None:
+                try:
+                    fallback_model = ModelProviderRegistry.get_preferred_fallback_model(tool.get_model_category())
+                except Exception as fallback_exc:  # pragma: no cover - defensive log
+                    logger.debug(
+                        f"[CONVERSATION_DEBUG] Unable to resolve fallback model for {context.tool_name}: {fallback_exc}"
+                    )
+
+            if fallback_model is None:
+                available_models = ModelProviderRegistry.get_available_model_names()
+                if available_models:
+                    fallback_model = available_models[0]
+
+            if fallback_model is None:
+                raise ValueError(
+                    f"Conversation continuation failed: model '{model_context.model_name}' is not available with current API keys."
+                )
+
+            logger.debug(
+                f"[CONVERSATION_DEBUG] Model '{model_context.model_name}' unavailable; swapping to '{fallback_model}' for context reconstruction"
+            )
+            model_context = ModelContext(fallback_model)
+            arguments["_model_context"] = model_context
+            arguments["_resolved_model_name"] = fallback_model
+    else:
+        if model_context is None:
+            from providers.registry import ModelProviderRegistry
+
+            fallback_model = None
+            if tool is not None:
+                try:
+                    fallback_model = ModelProviderRegistry.get_preferred_fallback_model(tool.get_model_category())
+                except Exception as fallback_exc:  # pragma: no cover - defensive log
+                    logger.debug(
+                        f"[CONVERSATION_DEBUG] Unable to resolve fallback model for {context.tool_name}: {fallback_exc}"
+                    )
+
+            if fallback_model is None:
+                available_models = ModelProviderRegistry.get_available_model_names()
+                if available_models:
+                    fallback_model = available_models[0]
+
+            if fallback_model is None:
+                raise ValueError(
+                    "Conversation continuation failed: no available models detected for context reconstruction."
+                )
+
+            logger.debug(
+                f"[CONVERSATION_DEBUG] Using fallback model '{fallback_model}' for context reconstruction of tool without model requirement"
+            )
+            model_context = ModelContext(fallback_model)
+            arguments["_model_context"] = model_context
+            arguments["_resolved_model_name"] = fallback_model
 
     # Build conversation history with model-specific limits
     logger.debug(f"[CONVERSATION_DEBUG] Building conversation history for thread {continuation_id}")
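The same fallback lookup is repeated in all three branches above (required model with no context, unavailable provider, and tools that do not require a model). As a minimal sketch of that shared pattern, assuming only the registry calls already used in the hunk; `_resolve_fallback_model` is a hypothetical helper, not part of this commit:

# Hypothetical helper (illustration only) capturing the repeated fallback lookup.
def _resolve_fallback_model(tool, registry):
    fallback = None
    if tool is not None:
        try:
            # Prefer a model suited to the tool's category, if the registry can rank one.
            fallback = registry.get_preferred_fallback_model(tool.get_model_category())
        except Exception:
            fallback = None  # defensive: never let category lookup break continuation
    if fallback is None:
        # Otherwise take the first model the configured providers expose, if any.
        available = registry.get_available_model_names()
        fallback = available[0] if available else None
    return fallback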
@@ -164,3 +164,107 @@ async def test_consensus_multi_model_consultations(monkeypatch):
 
     # Clean up provider registry state after test
     ModelProviderRegistry.reset_for_testing()
+
+
+@pytest.mark.asyncio
+@pytest.mark.no_mock_provider
+async def test_consensus_auto_mode_with_openrouter_and_gemini(monkeypatch):
+    """Ensure continuation flow resolves to real models instead of leaking 'auto'."""
+
+    gemini_key = os.getenv("GEMINI_API_KEY", "").strip() or "dummy-key-for-replay"
+    openrouter_key = os.getenv("OPENROUTER_API_KEY", "").strip() or "dummy-key-for-replay"
+
+    with monkeypatch.context() as m:
+        m.setenv("DEFAULT_MODEL", "auto")
+        m.setenv("GEMINI_API_KEY", gemini_key)
+        m.setenv("OPENROUTER_API_KEY", openrouter_key)
+
+        for key in [
+            "OPENAI_API_KEY",
+            "XAI_API_KEY",
+            "DIAL_API_KEY",
+            "CUSTOM_API_KEY",
+            "CUSTOM_API_URL",
+        ]:
+            m.delenv(key, raising=False)
+
+        import importlib
+
+        import config
+
+        m.setattr(config, "DEFAULT_MODEL", "auto")
+
+        import server as server_module
+
+        server = importlib.reload(server_module)
+        m.setattr(server, "DEFAULT_MODEL", "auto", raising=False)
+
+        ModelProviderRegistry.reset_for_testing()
+        from providers.gemini import GeminiModelProvider
+        from providers.openrouter import OpenRouterProvider
+
+        ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+        ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
+
+        from utils.storage_backend import get_storage_backend
+
+        # Clear conversation storage to avoid cross-test leakage
+        storage = get_storage_backend()
+        storage._store.clear()
+
+        models_to_consult = [
+            {"model": "claude-3-5-flash-20241022", "stance": "neutral"},
+            {"model": "gpt-5-mini", "stance": "neutral"},
+        ]
+
+        step1_args = {
+            "step": "Evaluate framework options.",
+            "step_number": 1,
+            "total_steps": len(models_to_consult),
+            "next_step_required": True,
+            "findings": "Initial analysis requested.",
+            "models": models_to_consult,
+        }
+
+        step1_output = await server.handle_call_tool("consensus", step1_args)
+        assert step1_output and step1_output[0].type == "text"
+        step1_payload = json.loads(step1_output[0].text)
+
+        assert step1_payload["status"] == "analysis_and_first_model_consulted"
+        assert step1_payload["model_consulted"] == "claude-3-5-flash-20241022"
+        assert step1_payload["model_response"]["status"] == "error"
+        assert "claude-3-5-flash-20241022" in step1_payload["model_response"]["error"]
+
+        continuation_offer = step1_payload.get("continuation_offer")
+        assert continuation_offer is not None
+        continuation_id = continuation_offer["continuation_id"]
+
+        step2_args = {
+            "step": "Continue consultation sequence.",
+            "step_number": 2,
+            "total_steps": len(models_to_consult),
+            "next_step_required": False,
+            "findings": "Ready for next model.",
+            "continuation_id": continuation_id,
+            "models": models_to_consult,
+        }
+
+        try:
+            step2_output = await server.handle_call_tool("consensus", step2_args)
+        finally:
+            # Reset provider registry regardless of outcome to avoid cross-test bleed
+            ModelProviderRegistry.reset_for_testing()
+
+        assert step2_output and step2_output[0].type == "text"
+        step2_payload = json.loads(step2_output[0].text)
+
+        serialized = json.dumps(step2_payload)
+        assert "auto" not in serialized.lower(), "Auto model leakage should be resolved"
+        assert "gpt-5-mini" in serialized or "claude-3-5-flash-20241022" in serialized
+
+    # Restore server module to reflect original configuration for other tests
+    import importlib
+
+    import server as server_module
+
+    importlib.reload(server_module)
tests/test_consensus_schema.py | 24 (new file)
@@ -0,0 +1,24 @@
+"""Schema-related tests for ConsensusTool."""
+
+from types import MethodType
+
+from tools.consensus import ConsensusTool
+
+
+def test_consensus_models_field_includes_available_models(monkeypatch):
+    """Consensus schema should surface available model guidance like single-model tools."""
+
+    tool = ConsensusTool()
+
+    monkeypatch.setattr(
+        tool,
+        "_get_ranked_model_summaries",
+        MethodType(lambda self, limit=5: (["gemini-2.5-pro (score 100, 1.0M ctx, thinking)"], 1, False), tool),
+    )
+    monkeypatch.setattr(tool, "_get_restriction_note", MethodType(lambda self: None, tool))
+
+    schema = tool.get_input_schema()
+    models_field_description = schema["properties"]["models"]["description"]
+
+    assert "listmodels" in models_field_description
+    assert "Top models" in models_field_description
@@ -258,6 +258,35 @@ of the evidence, even when it strongly points in one direction.""",
            },
        }
 
+        # Provide guidance on available models similar to single-model tools
+        model_description = (
+            "When the user names a model, you MUST use that exact value or report the "
+            "provider error—never swap in another option. Use the `listmodels` tool for the full roster."
+        )
+
+        summaries, total, restricted = self._get_ranked_model_summaries()
+        remainder = max(0, total - len(summaries))
+        if summaries:
+            label = "Allowed models" if restricted else "Top models"
+            top_line = "; ".join(summaries)
+            if remainder > 0:
+                top_line = f"{label}: {top_line}; +{remainder} more via `listmodels`."
+            else:
+                top_line = f"{label}: {top_line}."
+            model_description = f"{model_description} {top_line}"
+        else:
+            model_description = (
+                f"{model_description} No models detected—configure provider credentials or use the `listmodels` tool "
+                "to inspect availability."
+            )
+
+        restriction_note = self._get_restriction_note()
+        if restriction_note and (remainder > 0 or not summaries):
+            model_description = f"{model_description} {restriction_note}."
+
+        existing_models_desc = consensus_field_overrides["models"]["description"]
+        consensus_field_overrides["models"]["description"] = f"{existing_models_desc} {model_description}"
+
        # Define excluded fields for consensus workflow
        excluded_workflow_fields = [
            "files_checked",  # Not used in consensus workflow
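For reference, the description assembly in this hunk behaves roughly like this with illustrative registry output (a standalone sketch, not code from the commit; the summary string and total are made up):

# Standalone sketch of the new guidance string (illustrative values only).
summaries = ["gemini-2.5-pro (score 100, 1.0M ctx, thinking)"]
total, restricted = 5, False

remainder = max(0, total - len(summaries))  # 4
label = "Allowed models" if restricted else "Top models"
top_line = "; ".join(summaries)
if remainder > 0:
    top_line = f"{label}: {top_line}; +{remainder} more via `listmodels`."
else:
    top_line = f"{label}: {top_line}."
print(top_line)
# Top models: gemini-2.5-pro (score 100, 1.0M ctx, thinking); +4 more via `listmodels`.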