Include custom models in model discovery for auto mode too
This commit is contained in:
@@ -200,6 +200,26 @@ class ModelProviderRegistry:
|
||||
continue
|
||||
|
||||
models[model_name] = provider_type
|
||||
elif provider_type == ProviderType.CUSTOM:
|
||||
# Custom provider also uses a registry system (shared with OpenRouter)
|
||||
if hasattr(provider, "_registry") and provider._registry:
|
||||
# Get all models from the registry
|
||||
all_models = provider._registry.list_models()
|
||||
aliases = provider._registry.list_aliases()
|
||||
|
||||
# Add models that are validated by the custom provider
|
||||
for model_name in all_models + aliases:
|
||||
# Use the provider's validation logic to determine if this model
|
||||
# is appropriate for the custom endpoint
|
||||
if provider.validate_model_name(model_name):
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(
|
||||
provider_type, model_name
|
||||
):
|
||||
logging.debug(f"Model {model_name} filtered by restrictions")
|
||||
continue
|
||||
|
||||
models[model_name] = provider_type
|
||||
|
||||
return models
|
||||
|
||||
@@ -274,11 +294,13 @@ class ModelProviderRegistry:
|
||||
gemini_models = [m for m, p in available_models.items() if p == ProviderType.GOOGLE]
|
||||
xai_models = [m for m, p in available_models.items() if p == ProviderType.XAI]
|
||||
openrouter_models = [m for m, p in available_models.items() if p == ProviderType.OPENROUTER]
|
||||
custom_models = [m for m, p in available_models.items() if p == ProviderType.CUSTOM]
|
||||
|
||||
openai_available = bool(openai_models)
|
||||
gemini_available = bool(gemini_models)
|
||||
xai_available = bool(xai_models)
|
||||
openrouter_available = bool(openrouter_models)
|
||||
custom_available = bool(custom_models)
|
||||
|
||||
if tool_category == ToolModelCategory.EXTENDED_REASONING:
|
||||
# Prefer thinking-capable models for deep reasoning tools
|
||||
@@ -305,6 +327,9 @@ class ModelProviderRegistry:
|
||||
return thinking_model
|
||||
# Fallback to first available OpenRouter model
|
||||
return openrouter_models[0]
|
||||
elif custom_available:
|
||||
# Fallback to custom models when available
|
||||
return custom_models[0]
|
||||
else:
|
||||
# Fallback to pro if nothing found
|
||||
return "gemini-2.5-pro-preview-06-05"
|
||||
@@ -332,6 +357,9 @@ class ModelProviderRegistry:
|
||||
elif openrouter_available:
|
||||
# Fallback to first available OpenRouter model
|
||||
return openrouter_models[0]
|
||||
elif custom_available:
|
||||
# Fallback to custom models when available
|
||||
return custom_models[0]
|
||||
else:
|
||||
# Default to flash
|
||||
return "gemini-2.5-flash-preview-05-20"
|
||||
@@ -353,6 +381,9 @@ class ModelProviderRegistry:
|
||||
return gemini_models[0]
|
||||
elif openrouter_available:
|
||||
return openrouter_models[0]
|
||||
elif custom_available:
|
||||
# Fallback to custom models when available
|
||||
return custom_models[0]
|
||||
else:
|
||||
# No models available due to restrictions - check if any providers exist
|
||||
if not available_models:
|
||||
|
||||
208
tests/test_auto_mode_custom_provider_only.py
Normal file
208
tests/test_auto_mode_custom_provider_only.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""Test auto mode with only custom provider configured to reproduce the reported issue."""
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
|
||||
@pytest.mark.no_mock_provider
|
||||
class TestAutoModeCustomProviderOnly:
|
||||
"""Test auto mode when only custom provider is configured."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up clean state before each test."""
|
||||
# Save original environment state for restoration
|
||||
self._original_env = {}
|
||||
for key in [
|
||||
"GEMINI_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"XAI_API_KEY",
|
||||
"OPENROUTER_API_KEY",
|
||||
"CUSTOM_API_URL",
|
||||
"CUSTOM_API_KEY",
|
||||
"DEFAULT_MODEL",
|
||||
]:
|
||||
self._original_env[key] = os.environ.get(key)
|
||||
|
||||
# Clear restriction service cache
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
|
||||
# Clear provider registry by resetting singleton instance
|
||||
ModelProviderRegistry._instance = None
|
||||
|
||||
def teardown_method(self):
|
||||
"""Clean up after each test."""
|
||||
# Restore original environment
|
||||
for key, value in self._original_env.items():
|
||||
if value is not None:
|
||||
os.environ[key] = value
|
||||
elif key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
# Reload config to pick up the restored environment
|
||||
import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
# Clear restriction service cache
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
|
||||
# Clear provider registry by resetting singleton instance
|
||||
ModelProviderRegistry._instance = None
|
||||
|
||||
def test_reproduce_auto_mode_custom_provider_only_issue(self):
|
||||
"""Test the fix for auto mode failing when only custom provider is configured."""
|
||||
|
||||
# Set up environment with ONLY custom provider configured
|
||||
test_env = {
|
||||
"CUSTOM_API_URL": "http://localhost:11434/v1",
|
||||
"CUSTOM_API_KEY": "", # Empty for Ollama-style
|
||||
"DEFAULT_MODEL": "auto",
|
||||
}
|
||||
|
||||
# Clear all other provider keys
|
||||
clear_keys = ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]
|
||||
|
||||
with patch.dict(os.environ, test_env, clear=False):
|
||||
# Ensure other provider keys are not set
|
||||
for key in clear_keys:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
# Reload config to pick up auto mode
|
||||
import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
# Register only the custom provider (simulating server startup)
|
||||
from providers.custom import CustomProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider)
|
||||
|
||||
# This should now work after the fix
|
||||
# The fix added support for custom provider registry system in get_available_models()
|
||||
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
|
||||
|
||||
# This assertion should now pass after the fix
|
||||
assert available_models, (
|
||||
"Expected custom provider models to be available. "
|
||||
"This test verifies the fix for auto mode failing with custom providers."
|
||||
)
|
||||
|
||||
def test_custom_provider_models_available_via_registry(self):
|
||||
"""Test that custom provider has models available via its registry system."""
|
||||
|
||||
# Set up environment with only custom provider
|
||||
test_env = {
|
||||
"CUSTOM_API_URL": "http://localhost:11434/v1",
|
||||
"CUSTOM_API_KEY": "",
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, test_env, clear=False):
|
||||
# Clear other provider keys
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
# Register custom provider
|
||||
from providers.custom import CustomProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider)
|
||||
|
||||
# Get the provider instance
|
||||
custom_provider = ModelProviderRegistry.get_provider(ProviderType.CUSTOM)
|
||||
assert custom_provider is not None, "Custom provider should be available"
|
||||
|
||||
# Verify it has a registry with models
|
||||
assert hasattr(custom_provider, "_registry"), "Custom provider should have _registry"
|
||||
assert custom_provider._registry is not None, "Registry should be initialized"
|
||||
|
||||
# Get models from registry
|
||||
models = custom_provider._registry.list_models()
|
||||
aliases = custom_provider._registry.list_aliases()
|
||||
|
||||
# Should have some models and aliases available
|
||||
assert models, "Custom provider registry should have models"
|
||||
assert aliases, "Custom provider registry should have aliases"
|
||||
|
||||
print(f"Available models: {len(models)}")
|
||||
print(f"Available aliases: {len(aliases)}")
|
||||
|
||||
def test_custom_provider_validate_model_name(self):
|
||||
"""Test that custom provider can validate model names."""
|
||||
|
||||
# Set up environment with only custom provider
|
||||
test_env = {
|
||||
"CUSTOM_API_URL": "http://localhost:11434/v1",
|
||||
"CUSTOM_API_KEY": "",
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, test_env, clear=False):
|
||||
# Register custom provider
|
||||
from providers.custom import CustomProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider)
|
||||
|
||||
# Get the provider instance
|
||||
custom_provider = ModelProviderRegistry.get_provider(ProviderType.CUSTOM)
|
||||
assert custom_provider is not None
|
||||
|
||||
# Test that it can validate some typical custom model names
|
||||
test_models = ["llama3.2", "llama3.2:latest", "local-model", "ollama-model"]
|
||||
|
||||
for model in test_models:
|
||||
is_valid = custom_provider.validate_model_name(model)
|
||||
print(f"Model '{model}' validation: {is_valid}")
|
||||
# Should validate at least some local-style models
|
||||
# (The exact validation logic may vary based on registry content)
|
||||
|
||||
def test_auto_mode_fallback_with_custom_only_should_work(self):
|
||||
"""Test that auto mode fallback should work when only custom provider is available."""
|
||||
|
||||
# Set up environment with only custom provider
|
||||
test_env = {
|
||||
"CUSTOM_API_URL": "http://localhost:11434/v1",
|
||||
"CUSTOM_API_KEY": "",
|
||||
"DEFAULT_MODEL": "auto",
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, test_env, clear=False):
|
||||
# Clear other provider keys
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
# Reload config
|
||||
import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
# Register custom provider
|
||||
from providers.custom import CustomProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider)
|
||||
|
||||
# This should work and return a fallback model from custom provider
|
||||
# Currently fails because get_preferred_fallback_model doesn't consider custom models
|
||||
from tools.models import ToolModelCategory
|
||||
|
||||
try:
|
||||
fallback_model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
|
||||
print(f"Fallback model for FAST_RESPONSE: {fallback_model}")
|
||||
|
||||
# Should get a valid model name, not the hardcoded fallback
|
||||
assert (
|
||||
fallback_model != "gemini-2.5-flash-preview-05-20"
|
||||
), "Should not fallback to hardcoded Gemini model when custom provider is available"
|
||||
|
||||
except Exception as e:
|
||||
pytest.fail(f"Getting fallback model failed: {e}")
|
||||
Reference in New Issue
Block a user