diff --git a/server.py b/server.py
index 81e5562..a8bf47e 100644
--- a/server.py
+++ b/server.py
@@ -1068,9 +1068,12 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
     # Create model context early to use for history building
     from utils.model_context import ModelContext
 
+    tool = TOOLS.get(context.tool_name)
+    requires_model = tool.requires_model() if tool else True
+
     # Check if we should use the model from the previous conversation turn
     model_from_args = arguments.get("model")
-    if not model_from_args and context.turns:
+    if requires_model and not model_from_args and context.turns:
         # Find the last assistant turn to get the model used
         for turn in reversed(context.turns):
             if turn.role == "assistant" and turn.model_name:
@@ -1078,7 +1081,99 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
                 logger.debug(f"[CONVERSATION_DEBUG] Using model from previous turn: {turn.model_name}")
                 break
 
-    model_context = ModelContext.from_arguments(arguments)
+    # Resolve an effective model for context reconstruction when DEFAULT_MODEL=auto
+    model_context = arguments.get("_model_context")
+
+    if requires_model:
+        if model_context is None:
+            try:
+                model_context = ModelContext.from_arguments(arguments)
+                arguments.setdefault("_resolved_model_name", model_context.model_name)
+            except ValueError as exc:
+                from providers.registry import ModelProviderRegistry
+
+                fallback_model = None
+                if tool is not None:
+                    try:
+                        fallback_model = ModelProviderRegistry.get_preferred_fallback_model(tool.get_model_category())
+                    except Exception as fallback_exc:  # pragma: no cover - defensive log
+                        logger.debug(
+                            f"[CONVERSATION_DEBUG] Unable to resolve fallback model for {context.tool_name}: {fallback_exc}"
+                        )
+
+                if fallback_model is None:
+                    available_models = ModelProviderRegistry.get_available_model_names()
+                    if available_models:
+                        fallback_model = available_models[0]
+
+                if fallback_model is None:
+                    raise
+
+                logger.debug(
+                    f"[CONVERSATION_DEBUG] Falling back to model '{fallback_model}' for context reconstruction after error: {exc}"
+                )
+                model_context = ModelContext(fallback_model)
+                arguments["_model_context"] = model_context
+                arguments["_resolved_model_name"] = fallback_model
+
+        from providers.registry import ModelProviderRegistry
+
+        provider = ModelProviderRegistry.get_provider_for_model(model_context.model_name)
+        if provider is None:
+            fallback_model = None
+            if tool is not None:
+                try:
+                    fallback_model = ModelProviderRegistry.get_preferred_fallback_model(tool.get_model_category())
+                except Exception as fallback_exc:  # pragma: no cover - defensive log
+                    logger.debug(
+                        f"[CONVERSATION_DEBUG] Unable to resolve fallback model for {context.tool_name}: {fallback_exc}"
+                    )
+
+            if fallback_model is None:
+                available_models = ModelProviderRegistry.get_available_model_names()
+                if available_models:
+                    fallback_model = available_models[0]
+
+            if fallback_model is None:
+                raise ValueError(
+                    f"Conversation continuation failed: model '{model_context.model_name}' is not available with current API keys."
+                )
+
+            logger.debug(
+                f"[CONVERSATION_DEBUG] Model '{model_context.model_name}' unavailable; swapping to '{fallback_model}' for context reconstruction"
+            )
+            model_context = ModelContext(fallback_model)
+            arguments["_model_context"] = model_context
+            arguments["_resolved_model_name"] = fallback_model
+    else:
+        if model_context is None:
+            from providers.registry import ModelProviderRegistry
+
+            fallback_model = None
+            if tool is not None:
+                try:
+                    fallback_model = ModelProviderRegistry.get_preferred_fallback_model(tool.get_model_category())
+                except Exception as fallback_exc:  # pragma: no cover - defensive log
+                    logger.debug(
+                        f"[CONVERSATION_DEBUG] Unable to resolve fallback model for {context.tool_name}: {fallback_exc}"
+                    )
+
+            if fallback_model is None:
+                available_models = ModelProviderRegistry.get_available_model_names()
+                if available_models:
+                    fallback_model = available_models[0]
+
+            if fallback_model is None:
+                raise ValueError(
+                    "Conversation continuation failed: no available models detected for context reconstruction."
+                )
+
+            logger.debug(
+                f"[CONVERSATION_DEBUG] Using fallback model '{fallback_model}' for context reconstruction of tool without model requirement"
+            )
+            model_context = ModelContext(fallback_model)
+            arguments["_model_context"] = model_context
+            arguments["_resolved_model_name"] = fallback_model
 
     # Build conversation history with model-specific limits
     logger.debug(f"[CONVERSATION_DEBUG] Building conversation history for thread {continuation_id}")
diff --git a/tests/test_consensus_integration.py b/tests/test_consensus_integration.py
index a652173..37025e6 100644
--- a/tests/test_consensus_integration.py
+++ b/tests/test_consensus_integration.py
@@ -164,3 +164,107 @@ async def test_consensus_multi_model_consultations(monkeypatch):
 
     # Clean up provider registry state after test
     ModelProviderRegistry.reset_for_testing()
+
+
+@pytest.mark.asyncio
+@pytest.mark.no_mock_provider
+async def test_consensus_auto_mode_with_openrouter_and_gemini(monkeypatch):
+    """Ensure continuation flow resolves to real models instead of leaking 'auto'."""
+
+    gemini_key = os.getenv("GEMINI_API_KEY", "").strip() or "dummy-key-for-replay"
+    openrouter_key = os.getenv("OPENROUTER_API_KEY", "").strip() or "dummy-key-for-replay"
+
+    with monkeypatch.context() as m:
+        m.setenv("DEFAULT_MODEL", "auto")
+        m.setenv("GEMINI_API_KEY", gemini_key)
+        m.setenv("OPENROUTER_API_KEY", openrouter_key)
+
+        for key in [
+            "OPENAI_API_KEY",
+            "XAI_API_KEY",
+            "DIAL_API_KEY",
+            "CUSTOM_API_KEY",
+            "CUSTOM_API_URL",
+        ]:
+            m.delenv(key, raising=False)
+
+        import importlib
+
+        import config
+
+        m.setattr(config, "DEFAULT_MODEL", "auto")
+
+        import server as server_module
+
+        server = importlib.reload(server_module)
+        m.setattr(server, "DEFAULT_MODEL", "auto", raising=False)
+
+        ModelProviderRegistry.reset_for_testing()
+        from providers.gemini import GeminiModelProvider
+        from providers.openrouter import OpenRouterProvider
+
+        ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+        ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
+
+        from utils.storage_backend import get_storage_backend
+
+        # Clear conversation storage to avoid cross-test leakage
+        storage = get_storage_backend()
+        storage._store.clear()
+
+        models_to_consult = [
+            {"model": "claude-3-5-flash-20241022", "stance": "neutral"},
+            {"model": "gpt-5-mini", "stance": "neutral"},
+        ]
+
+        step1_args = {
+            "step": "Evaluate framework options.",
+            "step_number": 1,
+            "total_steps": len(models_to_consult),
+            "next_step_required": True,
+            "findings": "Initial analysis requested.",
+            "models": models_to_consult,
+        }
+
+        step1_output = await server.handle_call_tool("consensus", step1_args)
+        assert step1_output and step1_output[0].type == "text"
+        step1_payload = json.loads(step1_output[0].text)
+
+        assert step1_payload["status"] == "analysis_and_first_model_consulted"
+        assert step1_payload["model_consulted"] == "claude-3-5-flash-20241022"
+        assert step1_payload["model_response"]["status"] == "error"
+        assert "claude-3-5-flash-20241022" in step1_payload["model_response"]["error"]
+
+        continuation_offer = step1_payload.get("continuation_offer")
+        assert continuation_offer is not None
+        continuation_id = continuation_offer["continuation_id"]
+
+        step2_args = {
+            "step": "Continue consultation sequence.",
+            "step_number": 2,
+            "total_steps": len(models_to_consult),
+            "next_step_required": False,
+            "findings": "Ready for next model.",
+            "continuation_id": continuation_id,
+            "models": models_to_consult,
+        }
+
+        try:
+            step2_output = await server.handle_call_tool("consensus", step2_args)
+        finally:
+            # Reset provider registry regardless of outcome to avoid cross-test bleed
+            ModelProviderRegistry.reset_for_testing()
+
+        assert step2_output and step2_output[0].type == "text"
+        step2_payload = json.loads(step2_output[0].text)
+
+        serialized = json.dumps(step2_payload)
+        assert "auto" not in serialized.lower(), "Auto model leakage should be resolved"
+        assert "gpt-5-mini" in serialized or "claude-3-5-flash-20241022" in serialized
+
+    # Restore server module to reflect original configuration for other tests
+    import importlib
+
+    import server as server_module
+
+    importlib.reload(server_module)
diff --git a/tests/test_consensus_schema.py b/tests/test_consensus_schema.py
new file mode 100644
index 0000000..c0a5fd4
--- /dev/null
+++ b/tests/test_consensus_schema.py
@@ -0,0 +1,24 @@
+"""Schema-related tests for ConsensusTool."""
+
+from types import MethodType
+
+from tools.consensus import ConsensusTool
+
+
+def test_consensus_models_field_includes_available_models(monkeypatch):
+    """Consensus schema should surface available model guidance like single-model tools."""
+
+    tool = ConsensusTool()
+
+    monkeypatch.setattr(
+        tool,
+        "_get_ranked_model_summaries",
+        MethodType(lambda self, limit=5: (["gemini-2.5-pro (score 100, 1.0M ctx, thinking)"], 1, False), tool),
+    )
+    monkeypatch.setattr(tool, "_get_restriction_note", MethodType(lambda self: None, tool))
+
+    schema = tool.get_input_schema()
+    models_field_description = schema["properties"]["models"]["description"]
+
+    assert "listmodels" in models_field_description
+    assert "Top models" in models_field_description
diff --git a/tools/consensus.py b/tools/consensus.py
index 3965eb3..bf7b9d1 100644
--- a/tools/consensus.py
+++ b/tools/consensus.py
@@ -258,6 +258,35 @@ of the evidence, even when it strongly points in one direction.""",
             },
         }
 
+        # Provide guidance on available models similar to single-model tools
+        model_description = (
+            "When the user names a model, you MUST use that exact value or report the "
+            "provider error—never swap in another option. Use the `listmodels` tool for the full roster."
+        )
+
+        summaries, total, restricted = self._get_ranked_model_summaries()
+        remainder = max(0, total - len(summaries))
+        if summaries:
+            label = "Allowed models" if restricted else "Top models"
+            top_line = "; ".join(summaries)
+            if remainder > 0:
+                top_line = f"{label}: {top_line}; +{remainder} more via `listmodels`."
+            else:
+                top_line = f"{label}: {top_line}."
+            model_description = f"{model_description} {top_line}"
+        else:
+            model_description = (
+                f"{model_description} No models detected—configure provider credentials or use the `listmodels` tool "
+                "to inspect availability."
+            )
+
+        restriction_note = self._get_restriction_note()
+        if restriction_note and (remainder > 0 or not summaries):
+            model_description = f"{model_description} {restriction_note}."
+
+        existing_models_desc = consensus_field_overrides["models"]["description"]
+        consensus_field_overrides["models"]["description"] = f"{existing_models_desc} {model_description}"
+
         # Define excluded fields for consensus workflow
         excluded_workflow_fields = [
             "files_checked",  # Not used in consensus workflow