From 5199dd6ead12f6fe1f24d7662e0d6e07889ba635 Mon Sep 17 00:00:00 2001 From: Fahad Date: Wed, 18 Jun 2025 06:40:35 +0400 Subject: [PATCH] Include custom models in model discovery for auto mode too --- providers/registry.py | 31 +++ tests/test_auto_mode_custom_provider_only.py | 208 +++++++++++++++++++ 2 files changed, 239 insertions(+) create mode 100644 tests/test_auto_mode_custom_provider_only.py diff --git a/providers/registry.py b/providers/registry.py index 6332466..b2e52da 100644 --- a/providers/registry.py +++ b/providers/registry.py @@ -200,6 +200,26 @@ class ModelProviderRegistry: continue models[model_name] = provider_type + elif provider_type == ProviderType.CUSTOM: + # Custom provider also uses a registry system (shared with OpenRouter) + if hasattr(provider, "_registry") and provider._registry: + # Get all models from the registry + all_models = provider._registry.list_models() + aliases = provider._registry.list_aliases() + + # Add models that are validated by the custom provider + for model_name in all_models + aliases: + # Use the provider's validation logic to determine if this model + # is appropriate for the custom endpoint + if provider.validate_model_name(model_name): + # Check restrictions if enabled + if restriction_service and not restriction_service.is_allowed( + provider_type, model_name + ): + logging.debug(f"Model {model_name} filtered by restrictions") + continue + + models[model_name] = provider_type return models @@ -274,11 +294,13 @@ class ModelProviderRegistry: gemini_models = [m for m, p in available_models.items() if p == ProviderType.GOOGLE] xai_models = [m for m, p in available_models.items() if p == ProviderType.XAI] openrouter_models = [m for m, p in available_models.items() if p == ProviderType.OPENROUTER] + custom_models = [m for m, p in available_models.items() if p == ProviderType.CUSTOM] openai_available = bool(openai_models) gemini_available = bool(gemini_models) xai_available = bool(xai_models) openrouter_available = bool(openrouter_models) + custom_available = bool(custom_models) if tool_category == ToolModelCategory.EXTENDED_REASONING: # Prefer thinking-capable models for deep reasoning tools @@ -305,6 +327,9 @@ class ModelProviderRegistry: return thinking_model # Fallback to first available OpenRouter model return openrouter_models[0] + elif custom_available: + # Fallback to custom models when available + return custom_models[0] else: # Fallback to pro if nothing found return "gemini-2.5-pro-preview-06-05" @@ -332,6 +357,9 @@ class ModelProviderRegistry: elif openrouter_available: # Fallback to first available OpenRouter model return openrouter_models[0] + elif custom_available: + # Fallback to custom models when available + return custom_models[0] else: # Default to flash return "gemini-2.5-flash-preview-05-20" @@ -353,6 +381,9 @@ class ModelProviderRegistry: return gemini_models[0] elif openrouter_available: return openrouter_models[0] + elif custom_available: + # Fallback to custom models when available + return custom_models[0] else: # No models available due to restrictions - check if any providers exist if not available_models: diff --git a/tests/test_auto_mode_custom_provider_only.py b/tests/test_auto_mode_custom_provider_only.py new file mode 100644 index 0000000..9b3abc7 --- /dev/null +++ b/tests/test_auto_mode_custom_provider_only.py @@ -0,0 +1,208 @@ +"""Test auto mode with only custom provider configured to reproduce the reported issue.""" + +import importlib +import os +from unittest.mock import patch + +import pytest + +from providers.base import ProviderType +from providers.registry import ModelProviderRegistry + + +@pytest.mark.no_mock_provider +class TestAutoModeCustomProviderOnly: + """Test auto mode when only custom provider is configured.""" + + def setup_method(self): + """Set up clean state before each test.""" + # Save original environment state for restoration + self._original_env = {} + for key in [ + "GEMINI_API_KEY", + "OPENAI_API_KEY", + "XAI_API_KEY", + "OPENROUTER_API_KEY", + "CUSTOM_API_URL", + "CUSTOM_API_KEY", + "DEFAULT_MODEL", + ]: + self._original_env[key] = os.environ.get(key) + + # Clear restriction service cache + import utils.model_restrictions + + utils.model_restrictions._restriction_service = None + + # Clear provider registry by resetting singleton instance + ModelProviderRegistry._instance = None + + def teardown_method(self): + """Clean up after each test.""" + # Restore original environment + for key, value in self._original_env.items(): + if value is not None: + os.environ[key] = value + elif key in os.environ: + del os.environ[key] + + # Reload config to pick up the restored environment + import config + + importlib.reload(config) + + # Clear restriction service cache + import utils.model_restrictions + + utils.model_restrictions._restriction_service = None + + # Clear provider registry by resetting singleton instance + ModelProviderRegistry._instance = None + + def test_reproduce_auto_mode_custom_provider_only_issue(self): + """Test the fix for auto mode failing when only custom provider is configured.""" + + # Set up environment with ONLY custom provider configured + test_env = { + "CUSTOM_API_URL": "http://localhost:11434/v1", + "CUSTOM_API_KEY": "", # Empty for Ollama-style + "DEFAULT_MODEL": "auto", + } + + # Clear all other provider keys + clear_keys = ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"] + + with patch.dict(os.environ, test_env, clear=False): + # Ensure other provider keys are not set + for key in clear_keys: + if key in os.environ: + del os.environ[key] + + # Reload config to pick up auto mode + import config + + importlib.reload(config) + + # Register only the custom provider (simulating server startup) + from providers.custom import CustomProvider + + ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider) + + # This should now work after the fix + # The fix added support for custom provider registry system in get_available_models() + available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True) + + # This assertion should now pass after the fix + assert available_models, ( + "Expected custom provider models to be available. " + "This test verifies the fix for auto mode failing with custom providers." + ) + + def test_custom_provider_models_available_via_registry(self): + """Test that custom provider has models available via its registry system.""" + + # Set up environment with only custom provider + test_env = { + "CUSTOM_API_URL": "http://localhost:11434/v1", + "CUSTOM_API_KEY": "", + } + + with patch.dict(os.environ, test_env, clear=False): + # Clear other provider keys + for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]: + if key in os.environ: + del os.environ[key] + + # Register custom provider + from providers.custom import CustomProvider + + ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider) + + # Get the provider instance + custom_provider = ModelProviderRegistry.get_provider(ProviderType.CUSTOM) + assert custom_provider is not None, "Custom provider should be available" + + # Verify it has a registry with models + assert hasattr(custom_provider, "_registry"), "Custom provider should have _registry" + assert custom_provider._registry is not None, "Registry should be initialized" + + # Get models from registry + models = custom_provider._registry.list_models() + aliases = custom_provider._registry.list_aliases() + + # Should have some models and aliases available + assert models, "Custom provider registry should have models" + assert aliases, "Custom provider registry should have aliases" + + print(f"Available models: {len(models)}") + print(f"Available aliases: {len(aliases)}") + + def test_custom_provider_validate_model_name(self): + """Test that custom provider can validate model names.""" + + # Set up environment with only custom provider + test_env = { + "CUSTOM_API_URL": "http://localhost:11434/v1", + "CUSTOM_API_KEY": "", + } + + with patch.dict(os.environ, test_env, clear=False): + # Register custom provider + from providers.custom import CustomProvider + + ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider) + + # Get the provider instance + custom_provider = ModelProviderRegistry.get_provider(ProviderType.CUSTOM) + assert custom_provider is not None + + # Test that it can validate some typical custom model names + test_models = ["llama3.2", "llama3.2:latest", "local-model", "ollama-model"] + + for model in test_models: + is_valid = custom_provider.validate_model_name(model) + print(f"Model '{model}' validation: {is_valid}") + # Should validate at least some local-style models + # (The exact validation logic may vary based on registry content) + + def test_auto_mode_fallback_with_custom_only_should_work(self): + """Test that auto mode fallback should work when only custom provider is available.""" + + # Set up environment with only custom provider + test_env = { + "CUSTOM_API_URL": "http://localhost:11434/v1", + "CUSTOM_API_KEY": "", + "DEFAULT_MODEL": "auto", + } + + with patch.dict(os.environ, test_env, clear=False): + # Clear other provider keys + for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]: + if key in os.environ: + del os.environ[key] + + # Reload config + import config + + importlib.reload(config) + + # Register custom provider + from providers.custom import CustomProvider + + ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider) + + # This should work and return a fallback model from custom provider + # Currently fails because get_preferred_fallback_model doesn't consider custom models + from tools.models import ToolModelCategory + + try: + fallback_model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE) + print(f"Fallback model for FAST_RESPONSE: {fallback_model}") + + # Should get a valid model name, not the hardcoded fallback + assert ( + fallback_model != "gemini-2.5-flash-preview-05-20" + ), "Should not fallback to hardcoded Gemini model when custom provider is available" + + except Exception as e: + pytest.fail(f"Getting fallback model failed: {e}")