Files
my-pal-mcp-server/providers/registry.py

474 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Model provider registry for managing available providers."""
import logging
from typing import TYPE_CHECKING, Optional
from utils.env import get_env
from .base import ModelProvider
from .shared import ProviderType
if TYPE_CHECKING:
from tools.models import ToolModelCategory
class ModelProviderRegistry:
"""Central catalogue of provider implementations used by the MCP server.
Role
Holds the mapping between :class:`ProviderType` values and concrete
:class:`ModelProvider` subclasses/factories. At runtime the registry
is responsible for instantiating providers, caching them for reuse, and
mediating lookup of providers and model names in provider priority
order.
Core responsibilities
* Resolve API keys and other runtime configuration for each provider
* Lazily create provider instances so unused backends incur no cost
* Expose convenience methods for enumerating available models and
locating which provider can service a requested model name or alias
* Honour the project-wide provider priority policy so namespaces (or
alias collisions) are resolved deterministically.
"""
_instance = None
# Provider priority order for model selection
# Native APIs first, then custom endpoints, then catch-all providers
PROVIDER_PRIORITY_ORDER = [
ProviderType.GOOGLE, # Direct Gemini access
ProviderType.OPENAI, # Direct OpenAI access
ProviderType.AZURE, # Azure-hosted OpenAI deployments
ProviderType.XAI, # Direct X.AI GROK access
ProviderType.ZEN, # OpenCode Zen curated models
ProviderType.DIAL, # DIAL unified API access
ProviderType.CUSTOM, # Local/self-hosted models
ProviderType.OPENROUTER, # Catch-all for cloud models
]
def __new__(cls):
"""Singleton pattern for registry."""
if cls._instance is None:
logging.debug("REGISTRY: Creating new registry instance")
cls._instance = super().__new__(cls)
# Initialize instance dictionaries on first creation
cls._instance._providers = {}
cls._instance._initialized_providers = {}
logging.debug(f"REGISTRY: Created instance {cls._instance}")
return cls._instance
@classmethod
def register_provider(cls, provider_type: ProviderType, provider_class: type[ModelProvider]) -> None:
"""Register a new provider class.
Args:
provider_type: Type of the provider (e.g., ProviderType.GOOGLE)
provider_class: Class that implements ModelProvider interface
"""
instance = cls()
instance._providers[provider_type] = provider_class
# Invalidate any cached instance so subsequent lookups use the new registration
instance._initialized_providers.pop(provider_type, None)
@classmethod
def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]:
"""Get an initialized provider instance.
Args:
provider_type: Type of provider to get
force_new: Force creation of new instance instead of using cached
Returns:
Initialized ModelProvider instance or None if not available
"""
instance = cls()
# Return cached instance if available and not forcing new
if not force_new and provider_type in instance._initialized_providers:
return instance._initialized_providers[provider_type]
# Check if provider class is registered
if provider_type not in instance._providers:
return None
# Get API key from environment
api_key = cls._get_api_key_for_provider(provider_type)
# Get provider class or factory function
provider_class = instance._providers[provider_type]
# For custom providers, handle special initialization requirements
if provider_type == ProviderType.CUSTOM:
# Check if it's a factory function (callable but not a class)
if callable(provider_class) and not isinstance(provider_class, type):
# Factory function - call it with api_key parameter
provider = provider_class(api_key=api_key)
else:
# Regular class - need to handle URL requirement
custom_url = get_env("CUSTOM_API_URL", "") or ""
if not custom_url:
if api_key: # Key is set but URL is missing
logging.warning("CUSTOM_API_KEY set but CUSTOM_API_URL missing skipping Custom provider")
return None
# Use empty string as API key for custom providers that don't need auth (e.g., Ollama)
# This allows the provider to be created even without CUSTOM_API_KEY being set
api_key = api_key or ""
# Initialize custom provider with both API key and base URL
provider = provider_class(api_key=api_key, base_url=custom_url)
elif provider_type == ProviderType.GOOGLE:
# For Gemini, check if custom base URL is configured
if not api_key:
return None
gemini_base_url = get_env("GEMINI_BASE_URL")
provider_kwargs = {"api_key": api_key}
if gemini_base_url:
provider_kwargs["base_url"] = gemini_base_url
logging.info(f"Initialized Gemini provider with custom endpoint: {gemini_base_url}")
provider = provider_class(**provider_kwargs)
elif provider_type == ProviderType.AZURE:
if not api_key:
return None
azure_endpoint = get_env("AZURE_OPENAI_ENDPOINT")
if not azure_endpoint:
logging.warning("AZURE_OPENAI_ENDPOINT missing skipping Azure OpenAI provider")
return None
azure_version = get_env("AZURE_OPENAI_API_VERSION")
provider = provider_class(
api_key=api_key,
azure_endpoint=azure_endpoint,
api_version=azure_version,
)
else:
if not api_key:
return None
# Initialize non-custom provider with just API key
provider = provider_class(api_key=api_key)
# Cache the instance
instance._initialized_providers[provider_type] = provider
return provider
@classmethod
def get_provider_for_model(cls, model_name: str) -> Optional[ModelProvider]:
"""Get provider instance for a specific model name.
Provider priority order:
1. Native APIs (GOOGLE, OPENAI) - Most direct and efficient
2. CUSTOM - For local/private models with specific endpoints
3. OPENROUTER - Catch-all for cloud models via unified API
Args:
model_name: Name of the model (e.g., "gemini-2.5-flash", "gpt5")
Returns:
ModelProvider instance that supports this model
"""
logging.debug(f"get_provider_for_model called with model_name='{model_name}'")
# Check providers in priority order
instance = cls()
logging.debug(f"Registry instance: {instance}")
logging.debug(f"Available providers in registry: {list(instance._providers.keys())}")
for provider_type in cls.PROVIDER_PRIORITY_ORDER:
if provider_type in instance._providers:
logging.debug(f"Found {provider_type} in registry")
# Get or create provider instance
provider = cls.get_provider(provider_type)
if provider and provider.validate_model_name(model_name):
logging.debug(f"{provider_type} validates model {model_name}")
return provider
else:
logging.debug(f"{provider_type} does not validate model {model_name}")
else:
logging.debug(f"{provider_type} not found in registry")
logging.debug(f"No provider found for model {model_name}")
return None
@classmethod
def get_available_providers(cls) -> list[ProviderType]:
"""Get list of registered provider types."""
instance = cls()
return list(instance._providers.keys())
@classmethod
def get_available_models(cls, respect_restrictions: bool = True) -> dict[str, ProviderType]:
"""Get mapping of all available models to their providers.
Args:
respect_restrictions: If True, filter out models not allowed by restrictions
Returns:
Dict mapping model names to provider types
"""
# Import here to avoid circular imports
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service() if respect_restrictions else None
models: dict[str, ProviderType] = {}
instance = cls()
for provider_type in instance._providers:
provider = cls.get_provider(provider_type)
if not provider:
continue
try:
available = provider.list_models(respect_restrictions=respect_restrictions)
except NotImplementedError:
logging.warning("Provider %s does not implement list_models", provider_type)
continue
if restriction_service and restriction_service.has_restrictions(provider_type):
restricted_display = cls._collect_restricted_display_names(
provider,
provider_type,
available,
restriction_service,
)
if restricted_display:
for model_name in restricted_display:
models[model_name] = provider_type
continue
for model_name in available:
# =====================================================================================
# CRITICAL: Prevent double restriction filtering (Fixed Issue #98)
# =====================================================================================
# Previously, both the provider AND registry applied restrictions, causing
# double-filtering that resulted in "no models available" errors.
#
# Logic: If respect_restrictions=True, provider already filtered models,
# so registry should NOT filter them again.
# TEST COVERAGE: tests/test_provider_routing_bugs.py::TestOpenRouterAliasRestrictions
# =====================================================================================
if (
restriction_service
and not respect_restrictions # Only filter if provider didn't already filter
and not restriction_service.is_allowed(provider_type, model_name)
):
logging.debug("Model %s filtered by restrictions", model_name)
continue
models[model_name] = provider_type
return models
@classmethod
def _collect_restricted_display_names(
cls,
provider: ModelProvider,
provider_type: ProviderType,
available: list[str],
restriction_service,
) -> list[str] | None:
"""Derive the human-facing model list when restrictions are active."""
allowed_models = restriction_service.get_allowed_models(provider_type)
if not allowed_models:
return None
allowed_details: list[tuple[str, int]] = []
for model_name in sorted(allowed_models):
try:
capabilities = provider.get_capabilities(model_name)
except (AttributeError, ValueError):
continue
try:
rank = capabilities.get_effective_capability_rank()
rank_value = float(rank)
except (AttributeError, TypeError, ValueError):
rank_value = 0.0
allowed_details.append((model_name, rank_value))
if allowed_details:
allowed_details.sort(key=lambda item: (-item[1], item[0]))
return [name for name, _ in allowed_details]
# Fallback: intersect the allowlist with the provider-advertised names.
available_lookup = {name.lower(): name for name in available}
display_names: list[str] = []
for model_name in sorted(allowed_models):
lowered = model_name.lower()
if lowered in available_lookup:
display_names.append(available_lookup[lowered])
return display_names
@classmethod
def get_available_model_names(cls, provider_type: Optional[ProviderType] = None) -> list[str]:
"""Get list of available model names, optionally filtered by provider.
This respects model restrictions automatically.
Args:
provider_type: Optional provider to filter by
Returns:
List of available model names
"""
available_models = cls.get_available_models(respect_restrictions=True)
if provider_type:
# Filter by specific provider
return [name for name, ptype in available_models.items() if ptype == provider_type]
else:
# Return all available models
return list(available_models.keys())
@classmethod
def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]:
"""Get API key for a provider from environment variables.
Args:
provider_type: Provider type to get API key for
Returns:
API key string or None if not found
"""
key_mapping = {
ProviderType.GOOGLE: "GEMINI_API_KEY",
ProviderType.OPENAI: "OPENAI_API_KEY",
ProviderType.AZURE: "AZURE_OPENAI_API_KEY",
ProviderType.XAI: "XAI_API_KEY",
ProviderType.ZEN: "ZEN_API_KEY",
ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
ProviderType.CUSTOM: "CUSTOM_API_KEY", # Can be empty for providers that don't need auth
ProviderType.DIAL: "DIAL_API_KEY",
}
env_var = key_mapping.get(provider_type)
if not env_var:
return None
return get_env(env_var)
@classmethod
def _get_allowed_models_for_provider(cls, provider: ModelProvider, provider_type: ProviderType) -> list[str]:
"""Get a list of allowed canonical model names for a given provider.
Args:
provider: The provider instance to get models for
provider_type: The provider type for restriction checking
Returns:
List of model names that are both supported and allowed
"""
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
allowed_models = []
# Get the provider's supported models
try:
# Use list_models to get all supported models (handles both regular and custom providers)
supported_models = provider.list_models(respect_restrictions=False)
except (NotImplementedError, AttributeError):
# Fallback to provider-declared capability maps if list_models not implemented
model_map = getattr(provider, "MODEL_CAPABILITIES", None)
supported_models = list(model_map.keys()) if isinstance(model_map, dict) else []
# Filter by restrictions
for model_name in supported_models:
if restriction_service.is_allowed(provider_type, model_name):
allowed_models.append(model_name)
return allowed_models
@classmethod
def get_preferred_fallback_model(cls, tool_category: Optional["ToolModelCategory"] = None) -> str:
"""Get the preferred fallback model based on provider priority and tool category.
This method orchestrates model selection by:
1. Getting allowed models for each provider (respecting restrictions)
2. Asking providers for their preference from the allowed list
3. Falling back to first available model if no preference given
Args:
tool_category: Optional category to influence model selection
Returns:
Model name string for fallback use
"""
from tools.models import ToolModelCategory
effective_category = tool_category or ToolModelCategory.BALANCED
first_available_model = None
# Ask each provider for their preference in priority order
for provider_type in cls.PROVIDER_PRIORITY_ORDER:
provider = cls.get_provider(provider_type)
if provider:
# 1. Registry filters the models first
allowed_models = cls._get_allowed_models_for_provider(provider, provider_type)
if not allowed_models:
continue
# 2. Keep track of the first available model as fallback
if not first_available_model:
first_available_model = sorted(allowed_models)[0]
# 3. Ask provider to pick from allowed list
preferred_model = provider.get_preferred_model(effective_category, allowed_models)
if preferred_model:
logging.debug(
f"Provider {provider_type.value} selected '{preferred_model}' for category '{effective_category.value}'"
)
return preferred_model
# If no provider returned a preference, use first available model
if first_available_model:
logging.debug(f"No provider preference, using first available: {first_available_model}")
return first_available_model
# Ultimate fallback if no providers have models
logging.warning("No models available from any provider, using default fallback")
return "gemini-2.5-flash"
@classmethod
def get_available_providers_with_keys(cls) -> list[ProviderType]:
"""Get list of provider types that have valid API keys.
Returns:
List of ProviderType values for providers with valid API keys
"""
available = []
instance = cls()
for provider_type in instance._providers:
if cls.get_provider(provider_type) is not None:
available.append(provider_type)
return available
@classmethod
def clear_cache(cls) -> None:
"""Clear cached provider instances."""
instance = cls()
instance._initialized_providers.clear()
@classmethod
def reset_for_testing(cls) -> None:
"""Reset the registry to a clean state for testing.
This provides a safe, public API for tests to clean up registry state
without directly manipulating private attributes.
"""
cls._instance = None
if hasattr(cls, "_providers"):
cls._providers = {}
@classmethod
def unregister_provider(cls, provider_type: ProviderType) -> None:
"""Unregister a provider (mainly for testing)."""
instance = cls()
instance._providers.pop(provider_type, None)
instance._initialized_providers.pop(provider_type, None)