Include custom models in model discovery for auto mode too
This commit is contained in:
@@ -200,6 +200,26 @@ class ModelProviderRegistry:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
models[model_name] = provider_type
|
models[model_name] = provider_type
|
||||||
|
elif provider_type == ProviderType.CUSTOM:
|
||||||
|
# Custom provider also uses a registry system (shared with OpenRouter)
|
||||||
|
if hasattr(provider, "_registry") and provider._registry:
|
||||||
|
# Get all models from the registry
|
||||||
|
all_models = provider._registry.list_models()
|
||||||
|
aliases = provider._registry.list_aliases()
|
||||||
|
|
||||||
|
# Add models that are validated by the custom provider
|
||||||
|
for model_name in all_models + aliases:
|
||||||
|
# Use the provider's validation logic to determine if this model
|
||||||
|
# is appropriate for the custom endpoint
|
||||||
|
if provider.validate_model_name(model_name):
|
||||||
|
# Check restrictions if enabled
|
||||||
|
if restriction_service and not restriction_service.is_allowed(
|
||||||
|
provider_type, model_name
|
||||||
|
):
|
||||||
|
logging.debug(f"Model {model_name} filtered by restrictions")
|
||||||
|
continue
|
||||||
|
|
||||||
|
models[model_name] = provider_type
|
||||||
|
|
||||||
return models
|
return models
|
||||||
|
|
||||||
@@ -274,11 +294,13 @@ class ModelProviderRegistry:
|
|||||||
gemini_models = [m for m, p in available_models.items() if p == ProviderType.GOOGLE]
|
gemini_models = [m for m, p in available_models.items() if p == ProviderType.GOOGLE]
|
||||||
xai_models = [m for m, p in available_models.items() if p == ProviderType.XAI]
|
xai_models = [m for m, p in available_models.items() if p == ProviderType.XAI]
|
||||||
openrouter_models = [m for m, p in available_models.items() if p == ProviderType.OPENROUTER]
|
openrouter_models = [m for m, p in available_models.items() if p == ProviderType.OPENROUTER]
|
||||||
|
custom_models = [m for m, p in available_models.items() if p == ProviderType.CUSTOM]
|
||||||
|
|
||||||
openai_available = bool(openai_models)
|
openai_available = bool(openai_models)
|
||||||
gemini_available = bool(gemini_models)
|
gemini_available = bool(gemini_models)
|
||||||
xai_available = bool(xai_models)
|
xai_available = bool(xai_models)
|
||||||
openrouter_available = bool(openrouter_models)
|
openrouter_available = bool(openrouter_models)
|
||||||
|
custom_available = bool(custom_models)
|
||||||
|
|
||||||
if tool_category == ToolModelCategory.EXTENDED_REASONING:
|
if tool_category == ToolModelCategory.EXTENDED_REASONING:
|
||||||
# Prefer thinking-capable models for deep reasoning tools
|
# Prefer thinking-capable models for deep reasoning tools
|
||||||
@@ -305,6 +327,9 @@ class ModelProviderRegistry:
|
|||||||
return thinking_model
|
return thinking_model
|
||||||
# Fallback to first available OpenRouter model
|
# Fallback to first available OpenRouter model
|
||||||
return openrouter_models[0]
|
return openrouter_models[0]
|
||||||
|
elif custom_available:
|
||||||
|
# Fallback to custom models when available
|
||||||
|
return custom_models[0]
|
||||||
else:
|
else:
|
||||||
# Fallback to pro if nothing found
|
# Fallback to pro if nothing found
|
||||||
return "gemini-2.5-pro-preview-06-05"
|
return "gemini-2.5-pro-preview-06-05"
|
||||||
@@ -332,6 +357,9 @@ class ModelProviderRegistry:
|
|||||||
elif openrouter_available:
|
elif openrouter_available:
|
||||||
# Fallback to first available OpenRouter model
|
# Fallback to first available OpenRouter model
|
||||||
return openrouter_models[0]
|
return openrouter_models[0]
|
||||||
|
elif custom_available:
|
||||||
|
# Fallback to custom models when available
|
||||||
|
return custom_models[0]
|
||||||
else:
|
else:
|
||||||
# Default to flash
|
# Default to flash
|
||||||
return "gemini-2.5-flash-preview-05-20"
|
return "gemini-2.5-flash-preview-05-20"
|
||||||
@@ -353,6 +381,9 @@ class ModelProviderRegistry:
|
|||||||
return gemini_models[0]
|
return gemini_models[0]
|
||||||
elif openrouter_available:
|
elif openrouter_available:
|
||||||
return openrouter_models[0]
|
return openrouter_models[0]
|
||||||
|
elif custom_available:
|
||||||
|
# Fallback to custom models when available
|
||||||
|
return custom_models[0]
|
||||||
else:
|
else:
|
||||||
# No models available due to restrictions - check if any providers exist
|
# No models available due to restrictions - check if any providers exist
|
||||||
if not available_models:
|
if not available_models:
|
||||||
|
|||||||
208
tests/test_auto_mode_custom_provider_only.py
Normal file
208
tests/test_auto_mode_custom_provider_only.py
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
"""Test auto mode with only custom provider configured to reproduce the reported issue."""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from providers.base import ProviderType
|
||||||
|
from providers.registry import ModelProviderRegistry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.no_mock_provider
|
||||||
|
class TestAutoModeCustomProviderOnly:
|
||||||
|
"""Test auto mode when only custom provider is configured."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Set up clean state before each test."""
|
||||||
|
# Save original environment state for restoration
|
||||||
|
self._original_env = {}
|
||||||
|
for key in [
|
||||||
|
"GEMINI_API_KEY",
|
||||||
|
"OPENAI_API_KEY",
|
||||||
|
"XAI_API_KEY",
|
||||||
|
"OPENROUTER_API_KEY",
|
||||||
|
"CUSTOM_API_URL",
|
||||||
|
"CUSTOM_API_KEY",
|
||||||
|
"DEFAULT_MODEL",
|
||||||
|
]:
|
||||||
|
self._original_env[key] = os.environ.get(key)
|
||||||
|
|
||||||
|
# Clear restriction service cache
|
||||||
|
import utils.model_restrictions
|
||||||
|
|
||||||
|
utils.model_restrictions._restriction_service = None
|
||||||
|
|
||||||
|
# Clear provider registry by resetting singleton instance
|
||||||
|
ModelProviderRegistry._instance = None
|
||||||
|
|
||||||
|
def teardown_method(self):
|
||||||
|
"""Clean up after each test."""
|
||||||
|
# Restore original environment
|
||||||
|
for key, value in self._original_env.items():
|
||||||
|
if value is not None:
|
||||||
|
os.environ[key] = value
|
||||||
|
elif key in os.environ:
|
||||||
|
del os.environ[key]
|
||||||
|
|
||||||
|
# Reload config to pick up the restored environment
|
||||||
|
import config
|
||||||
|
|
||||||
|
importlib.reload(config)
|
||||||
|
|
||||||
|
# Clear restriction service cache
|
||||||
|
import utils.model_restrictions
|
||||||
|
|
||||||
|
utils.model_restrictions._restriction_service = None
|
||||||
|
|
||||||
|
# Clear provider registry by resetting singleton instance
|
||||||
|
ModelProviderRegistry._instance = None
|
||||||
|
|
||||||
|
def test_reproduce_auto_mode_custom_provider_only_issue(self):
|
||||||
|
"""Test the fix for auto mode failing when only custom provider is configured."""
|
||||||
|
|
||||||
|
# Set up environment with ONLY custom provider configured
|
||||||
|
test_env = {
|
||||||
|
"CUSTOM_API_URL": "http://localhost:11434/v1",
|
||||||
|
"CUSTOM_API_KEY": "", # Empty for Ollama-style
|
||||||
|
"DEFAULT_MODEL": "auto",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Clear all other provider keys
|
||||||
|
clear_keys = ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]
|
||||||
|
|
||||||
|
with patch.dict(os.environ, test_env, clear=False):
|
||||||
|
# Ensure other provider keys are not set
|
||||||
|
for key in clear_keys:
|
||||||
|
if key in os.environ:
|
||||||
|
del os.environ[key]
|
||||||
|
|
||||||
|
# Reload config to pick up auto mode
|
||||||
|
import config
|
||||||
|
|
||||||
|
importlib.reload(config)
|
||||||
|
|
||||||
|
# Register only the custom provider (simulating server startup)
|
||||||
|
from providers.custom import CustomProvider
|
||||||
|
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider)
|
||||||
|
|
||||||
|
# This should now work after the fix
|
||||||
|
# The fix added support for custom provider registry system in get_available_models()
|
||||||
|
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
|
||||||
|
|
||||||
|
# This assertion should now pass after the fix
|
||||||
|
assert available_models, (
|
||||||
|
"Expected custom provider models to be available. "
|
||||||
|
"This test verifies the fix for auto mode failing with custom providers."
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_custom_provider_models_available_via_registry(self):
|
||||||
|
"""Test that custom provider has models available via its registry system."""
|
||||||
|
|
||||||
|
# Set up environment with only custom provider
|
||||||
|
test_env = {
|
||||||
|
"CUSTOM_API_URL": "http://localhost:11434/v1",
|
||||||
|
"CUSTOM_API_KEY": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.dict(os.environ, test_env, clear=False):
|
||||||
|
# Clear other provider keys
|
||||||
|
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||||
|
if key in os.environ:
|
||||||
|
del os.environ[key]
|
||||||
|
|
||||||
|
# Register custom provider
|
||||||
|
from providers.custom import CustomProvider
|
||||||
|
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider)
|
||||||
|
|
||||||
|
# Get the provider instance
|
||||||
|
custom_provider = ModelProviderRegistry.get_provider(ProviderType.CUSTOM)
|
||||||
|
assert custom_provider is not None, "Custom provider should be available"
|
||||||
|
|
||||||
|
# Verify it has a registry with models
|
||||||
|
assert hasattr(custom_provider, "_registry"), "Custom provider should have _registry"
|
||||||
|
assert custom_provider._registry is not None, "Registry should be initialized"
|
||||||
|
|
||||||
|
# Get models from registry
|
||||||
|
models = custom_provider._registry.list_models()
|
||||||
|
aliases = custom_provider._registry.list_aliases()
|
||||||
|
|
||||||
|
# Should have some models and aliases available
|
||||||
|
assert models, "Custom provider registry should have models"
|
||||||
|
assert aliases, "Custom provider registry should have aliases"
|
||||||
|
|
||||||
|
print(f"Available models: {len(models)}")
|
||||||
|
print(f"Available aliases: {len(aliases)}")
|
||||||
|
|
||||||
|
def test_custom_provider_validate_model_name(self):
|
||||||
|
"""Test that custom provider can validate model names."""
|
||||||
|
|
||||||
|
# Set up environment with only custom provider
|
||||||
|
test_env = {
|
||||||
|
"CUSTOM_API_URL": "http://localhost:11434/v1",
|
||||||
|
"CUSTOM_API_KEY": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.dict(os.environ, test_env, clear=False):
|
||||||
|
# Register custom provider
|
||||||
|
from providers.custom import CustomProvider
|
||||||
|
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider)
|
||||||
|
|
||||||
|
# Get the provider instance
|
||||||
|
custom_provider = ModelProviderRegistry.get_provider(ProviderType.CUSTOM)
|
||||||
|
assert custom_provider is not None
|
||||||
|
|
||||||
|
# Test that it can validate some typical custom model names
|
||||||
|
test_models = ["llama3.2", "llama3.2:latest", "local-model", "ollama-model"]
|
||||||
|
|
||||||
|
for model in test_models:
|
||||||
|
is_valid = custom_provider.validate_model_name(model)
|
||||||
|
print(f"Model '{model}' validation: {is_valid}")
|
||||||
|
# Should validate at least some local-style models
|
||||||
|
# (The exact validation logic may vary based on registry content)
|
||||||
|
|
||||||
|
def test_auto_mode_fallback_with_custom_only_should_work(self):
|
||||||
|
"""Test that auto mode fallback should work when only custom provider is available."""
|
||||||
|
|
||||||
|
# Set up environment with only custom provider
|
||||||
|
test_env = {
|
||||||
|
"CUSTOM_API_URL": "http://localhost:11434/v1",
|
||||||
|
"CUSTOM_API_KEY": "",
|
||||||
|
"DEFAULT_MODEL": "auto",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.dict(os.environ, test_env, clear=False):
|
||||||
|
# Clear other provider keys
|
||||||
|
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||||
|
if key in os.environ:
|
||||||
|
del os.environ[key]
|
||||||
|
|
||||||
|
# Reload config
|
||||||
|
import config
|
||||||
|
|
||||||
|
importlib.reload(config)
|
||||||
|
|
||||||
|
# Register custom provider
|
||||||
|
from providers.custom import CustomProvider
|
||||||
|
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider)
|
||||||
|
|
||||||
|
# This should work and return a fallback model from custom provider
|
||||||
|
# Currently fails because get_preferred_fallback_model doesn't consider custom models
|
||||||
|
from tools.models import ToolModelCategory
|
||||||
|
|
||||||
|
try:
|
||||||
|
fallback_model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
|
||||||
|
print(f"Fallback model for FAST_RESPONSE: {fallback_model}")
|
||||||
|
|
||||||
|
# Should get a valid model name, not the hardcoded fallback
|
||||||
|
assert (
|
||||||
|
fallback_model != "gemini-2.5-flash-preview-05-20"
|
||||||
|
), "Should not fallback to hardcoded Gemini model when custom provider is available"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Getting fallback model failed: {e}")
|
||||||
Reference in New Issue
Block a user