fix: consensus now advertises a short list of models to avoid the CLI getting the names wrong

This commit is contained in:
Fahad
2025-10-03 22:41:28 +04:00
parent 9780c8ef02
commit 8759edc817
4 changed files with 254 additions and 2 deletions

View File

@@ -1068,9 +1068,12 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
# Create model context early to use for history building
from utils.model_context import ModelContext
tool = TOOLS.get(context.tool_name)
requires_model = tool.requires_model() if tool else True
# Check if we should use the model from the previous conversation turn
model_from_args = arguments.get("model")
if not model_from_args and context.turns:
if requires_model and not model_from_args and context.turns:
# Find the last assistant turn to get the model used
for turn in reversed(context.turns):
if turn.role == "assistant" and turn.model_name:
@@ -1078,7 +1081,99 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
logger.debug(f"[CONVERSATION_DEBUG] Using model from previous turn: {turn.model_name}")
break
# Resolve an effective model for context reconstruction when DEFAULT_MODEL=auto
model_context = arguments.get("_model_context")
if requires_model:
if model_context is None:
try:
model_context = ModelContext.from_arguments(arguments)
arguments.setdefault("_resolved_model_name", model_context.model_name)
except ValueError as exc:
from providers.registry import ModelProviderRegistry
fallback_model = None
if tool is not None:
try:
fallback_model = ModelProviderRegistry.get_preferred_fallback_model(tool.get_model_category())
except Exception as fallback_exc: # pragma: no cover - defensive log
logger.debug(
f"[CONVERSATION_DEBUG] Unable to resolve fallback model for {context.tool_name}: {fallback_exc}"
)
if fallback_model is None:
available_models = ModelProviderRegistry.get_available_model_names()
if available_models:
fallback_model = available_models[0]
if fallback_model is None:
raise
logger.debug(
f"[CONVERSATION_DEBUG] Falling back to model '{fallback_model}' for context reconstruction after error: {exc}"
)
model_context = ModelContext(fallback_model)
arguments["_model_context"] = model_context
arguments["_resolved_model_name"] = fallback_model
from providers.registry import ModelProviderRegistry
provider = ModelProviderRegistry.get_provider_for_model(model_context.model_name)
if provider is None:
fallback_model = None
if tool is not None:
try:
fallback_model = ModelProviderRegistry.get_preferred_fallback_model(tool.get_model_category())
except Exception as fallback_exc: # pragma: no cover - defensive log
logger.debug(
f"[CONVERSATION_DEBUG] Unable to resolve fallback model for {context.tool_name}: {fallback_exc}"
)
if fallback_model is None:
available_models = ModelProviderRegistry.get_available_model_names()
if available_models:
fallback_model = available_models[0]
if fallback_model is None:
raise ValueError(
f"Conversation continuation failed: model '{model_context.model_name}' is not available with current API keys."
)
logger.debug(
f"[CONVERSATION_DEBUG] Model '{model_context.model_name}' unavailable; swapping to '{fallback_model}' for context reconstruction"
)
model_context = ModelContext(fallback_model)
arguments["_model_context"] = model_context
arguments["_resolved_model_name"] = fallback_model
else:
if model_context is None:
from providers.registry import ModelProviderRegistry
fallback_model = None
if tool is not None:
try:
fallback_model = ModelProviderRegistry.get_preferred_fallback_model(tool.get_model_category())
except Exception as fallback_exc: # pragma: no cover - defensive log
logger.debug(
f"[CONVERSATION_DEBUG] Unable to resolve fallback model for {context.tool_name}: {fallback_exc}"
)
if fallback_model is None:
available_models = ModelProviderRegistry.get_available_model_names()
if available_models:
fallback_model = available_models[0]
if fallback_model is None:
raise ValueError(
"Conversation continuation failed: no available models detected for context reconstruction."
)
logger.debug(
f"[CONVERSATION_DEBUG] Using fallback model '{fallback_model}' for context reconstruction of tool without model requirement"
)
model_context = ModelContext(fallback_model)
arguments["_model_context"] = model_context
arguments["_resolved_model_name"] = fallback_model
# Build conversation history with model-specific limits
logger.debug(f"[CONVERSATION_DEBUG] Building conversation history for thread {continuation_id}")

View File

@@ -164,3 +164,107 @@ async def test_consensus_multi_model_consultations(monkeypatch):
# Clean up provider registry state after test
ModelProviderRegistry.reset_for_testing()
@pytest.mark.asyncio
@pytest.mark.no_mock_provider
async def test_consensus_auto_mode_with_openrouter_and_gemini(monkeypatch):
"""Ensure continuation flow resolves to real models instead of leaking 'auto'."""
gemini_key = os.getenv("GEMINI_API_KEY", "").strip() or "dummy-key-for-replay"
openrouter_key = os.getenv("OPENROUTER_API_KEY", "").strip() or "dummy-key-for-replay"
with monkeypatch.context() as m:
m.setenv("DEFAULT_MODEL", "auto")
m.setenv("GEMINI_API_KEY", gemini_key)
m.setenv("OPENROUTER_API_KEY", openrouter_key)
for key in [
"OPENAI_API_KEY",
"XAI_API_KEY",
"DIAL_API_KEY",
"CUSTOM_API_KEY",
"CUSTOM_API_URL",
]:
m.delenv(key, raising=False)
import importlib
import config
m.setattr(config, "DEFAULT_MODEL", "auto")
import server as server_module
server = importlib.reload(server_module)
m.setattr(server, "DEFAULT_MODEL", "auto", raising=False)
ModelProviderRegistry.reset_for_testing()
from providers.gemini import GeminiModelProvider
from providers.openrouter import OpenRouterProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
from utils.storage_backend import get_storage_backend
# Clear conversation storage to avoid cross-test leakage
storage = get_storage_backend()
storage._store.clear()
models_to_consult = [
{"model": "claude-3-5-flash-20241022", "stance": "neutral"},
{"model": "gpt-5-mini", "stance": "neutral"},
]
step1_args = {
"step": "Evaluate framework options.",
"step_number": 1,
"total_steps": len(models_to_consult),
"next_step_required": True,
"findings": "Initial analysis requested.",
"models": models_to_consult,
}
step1_output = await server.handle_call_tool("consensus", step1_args)
assert step1_output and step1_output[0].type == "text"
step1_payload = json.loads(step1_output[0].text)
assert step1_payload["status"] == "analysis_and_first_model_consulted"
assert step1_payload["model_consulted"] == "claude-3-5-flash-20241022"
assert step1_payload["model_response"]["status"] == "error"
assert "claude-3-5-flash-20241022" in step1_payload["model_response"]["error"]
continuation_offer = step1_payload.get("continuation_offer")
assert continuation_offer is not None
continuation_id = continuation_offer["continuation_id"]
step2_args = {
"step": "Continue consultation sequence.",
"step_number": 2,
"total_steps": len(models_to_consult),
"next_step_required": False,
"findings": "Ready for next model.",
"continuation_id": continuation_id,
"models": models_to_consult,
}
try:
step2_output = await server.handle_call_tool("consensus", step2_args)
finally:
# Reset provider registry regardless of outcome to avoid cross-test bleed
ModelProviderRegistry.reset_for_testing()
assert step2_output and step2_output[0].type == "text"
step2_payload = json.loads(step2_output[0].text)
serialized = json.dumps(step2_payload)
assert "auto" not in serialized.lower(), "Auto model leakage should be resolved"
assert "gpt-5-mini" in serialized or "claude-3-5-flash-20241022" in serialized
# Restore server module to reflect original configuration for other tests
import importlib
import server as server_module
importlib.reload(server_module)

View File

@@ -0,0 +1,24 @@
"""Schema-related tests for ConsensusTool."""
from types import MethodType
from tools.consensus import ConsensusTool
def test_consensus_models_field_includes_available_models(monkeypatch):
"""Consensus schema should surface available model guidance like single-model tools."""
tool = ConsensusTool()
monkeypatch.setattr(
tool,
"_get_ranked_model_summaries",
MethodType(lambda self, limit=5: (["gemini-2.5-pro (score 100, 1.0M ctx, thinking)"], 1, False), tool),
)
monkeypatch.setattr(tool, "_get_restriction_note", MethodType(lambda self: None, tool))
schema = tool.get_input_schema()
models_field_description = schema["properties"]["models"]["description"]
assert "listmodels" in models_field_description
assert "Top models" in models_field_description

View File

@@ -258,6 +258,35 @@ of the evidence, even when it strongly points in one direction.""",
},
}
# Provide guidance on available models similar to single-model tools
model_description = (
"When the user names a model, you MUST use that exact value or report the "
"provider error—never swap in another option. Use the `listmodels` tool for the full roster."
)
summaries, total, restricted = self._get_ranked_model_summaries()
remainder = max(0, total - len(summaries))
if summaries:
label = "Allowed models" if restricted else "Top models"
top_line = "; ".join(summaries)
if remainder > 0:
top_line = f"{label}: {top_line}; +{remainder} more via `listmodels`."
else:
top_line = f"{label}: {top_line}."
model_description = f"{model_description} {top_line}"
else:
model_description = (
f"{model_description} No models detected—configure provider credentials or use the `listmodels` tool "
"to inspect availability."
)
restriction_note = self._get_restriction_note()
if restriction_note and (remainder > 0 or not summaries):
model_description = f"{model_description} {restriction_note}."
existing_models_desc = consensus_field_overrides["models"]["description"]
consensus_field_overrides["models"]["description"] = f"{existing_models_desc} {model_description}"
# Define excluded fields for consensus workflow
excluded_workflow_fields = [
"files_checked", # Not used in consensus workflow