Breaking change: openrouter_models.json -> custom_models.json
* Support for Custom URLs and custom models, including locally hosted models such as ollama * Support for native + openrouter + local models (i.e. dozens of models) means you can start delegating sub-tasks to particular models or work to local models such as localizations or other boring work etc. * Several tests added * precommit to also include untracked (new) files * Logfile auto rollover * Improved logging
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""Model provider registry for managing available providers."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
@@ -10,13 +11,18 @@ class ModelProviderRegistry:
|
||||
"""Registry for managing model providers."""
|
||||
|
||||
_instance = None
|
||||
_providers: dict[ProviderType, type[ModelProvider]] = {}
|
||||
_initialized_providers: dict[ProviderType, ModelProvider] = {}
|
||||
|
||||
def __new__(cls):
|
||||
"""Singleton pattern for registry."""
|
||||
if cls._instance is None:
|
||||
logging.debug("REGISTRY: Creating new registry instance")
|
||||
cls._instance = super().__new__(cls)
|
||||
# Initialize instance dictionaries on first creation
|
||||
cls._instance._providers = {}
|
||||
cls._instance._initialized_providers = {}
|
||||
logging.debug(f"REGISTRY: Created instance {cls._instance}")
|
||||
else:
|
||||
logging.debug(f"REGISTRY: Returning existing instance {cls._instance}")
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
@@ -27,7 +33,8 @@ class ModelProviderRegistry:
|
||||
provider_type: Type of the provider (e.g., ProviderType.GOOGLE)
|
||||
provider_class: Class that implements ModelProvider interface
|
||||
"""
|
||||
cls._providers[provider_type] = provider_class
|
||||
instance = cls()
|
||||
instance._providers[provider_type] = provider_class
|
||||
|
||||
@classmethod
|
||||
def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]:
|
||||
@@ -40,25 +47,48 @@ class ModelProviderRegistry:
|
||||
Returns:
|
||||
Initialized ModelProvider instance or None if not available
|
||||
"""
|
||||
instance = cls()
|
||||
|
||||
# Return cached instance if available and not forcing new
|
||||
if not force_new and provider_type in cls._initialized_providers:
|
||||
return cls._initialized_providers[provider_type]
|
||||
if not force_new and provider_type in instance._initialized_providers:
|
||||
return instance._initialized_providers[provider_type]
|
||||
|
||||
# Check if provider class is registered
|
||||
if provider_type not in cls._providers:
|
||||
if provider_type not in instance._providers:
|
||||
return None
|
||||
|
||||
# Get API key from environment
|
||||
api_key = cls._get_api_key_for_provider(provider_type)
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
# Initialize provider
|
||||
provider_class = cls._providers[provider_type]
|
||||
provider = provider_class(api_key=api_key)
|
||||
# Get provider class or factory function
|
||||
provider_class = instance._providers[provider_type]
|
||||
|
||||
# For custom providers, handle special initialization requirements
|
||||
if provider_type == ProviderType.CUSTOM:
|
||||
# Check if it's a factory function (callable but not a class)
|
||||
if callable(provider_class) and not isinstance(provider_class, type):
|
||||
# Factory function - call it with api_key parameter
|
||||
provider = provider_class(api_key=api_key)
|
||||
else:
|
||||
# Regular class - need to handle URL requirement
|
||||
custom_url = os.getenv("CUSTOM_API_URL", "")
|
||||
if not custom_url:
|
||||
if api_key: # Key is set but URL is missing
|
||||
logging.warning("CUSTOM_API_KEY set but CUSTOM_API_URL missing – skipping Custom provider")
|
||||
return None
|
||||
# Use empty string as API key for custom providers that don't need auth (e.g., Ollama)
|
||||
# This allows the provider to be created even without CUSTOM_API_KEY being set
|
||||
api_key = api_key or ""
|
||||
# Initialize custom provider with both API key and base URL
|
||||
provider = provider_class(api_key=api_key, base_url=custom_url)
|
||||
else:
|
||||
if not api_key:
|
||||
return None
|
||||
# Initialize non-custom provider with just API key
|
||||
provider = provider_class(api_key=api_key)
|
||||
|
||||
# Cache the instance
|
||||
cls._initialized_providers[provider_type] = provider
|
||||
instance._initialized_providers[provider_type] = provider
|
||||
|
||||
return provider
|
||||
|
||||
@@ -66,25 +96,55 @@ class ModelProviderRegistry:
|
||||
def get_provider_for_model(cls, model_name: str) -> Optional[ModelProvider]:
|
||||
"""Get provider instance for a specific model name.
|
||||
|
||||
Provider priority order:
|
||||
1. Native APIs (GOOGLE, OPENAI) - Most direct and efficient
|
||||
2. CUSTOM - For local/private models with specific endpoints
|
||||
3. OPENROUTER - Catch-all for cloud models via unified API
|
||||
|
||||
Args:
|
||||
model_name: Name of the model (e.g., "gemini-2.5-flash-preview-05-20", "o3-mini")
|
||||
|
||||
Returns:
|
||||
ModelProvider instance that supports this model
|
||||
"""
|
||||
# Check each registered provider
|
||||
for provider_type, _provider_class in cls._providers.items():
|
||||
# Get or create provider instance
|
||||
provider = cls.get_provider(provider_type)
|
||||
if provider and provider.validate_model_name(model_name):
|
||||
return provider
|
||||
logging.debug(f"get_provider_for_model called with model_name='{model_name}'")
|
||||
|
||||
# Define explicit provider priority order
|
||||
# Native APIs first, then custom endpoints, then catch-all providers
|
||||
PROVIDER_PRIORITY_ORDER = [
|
||||
ProviderType.GOOGLE, # Direct Gemini access
|
||||
ProviderType.OPENAI, # Direct OpenAI access
|
||||
ProviderType.CUSTOM, # Local/self-hosted models
|
||||
ProviderType.OPENROUTER, # Catch-all for cloud models
|
||||
]
|
||||
|
||||
# Check providers in priority order
|
||||
instance = cls()
|
||||
logging.debug(f"Registry instance: {instance}")
|
||||
logging.debug(f"Available providers in registry: {list(instance._providers.keys())}")
|
||||
|
||||
for provider_type in PROVIDER_PRIORITY_ORDER:
|
||||
logging.debug(f"Checking provider_type: {provider_type}")
|
||||
if provider_type in instance._providers:
|
||||
logging.debug(f"Found {provider_type} in registry")
|
||||
# Get or create provider instance
|
||||
provider = cls.get_provider(provider_type)
|
||||
if provider and provider.validate_model_name(model_name):
|
||||
logging.debug(f"{provider_type} validates model {model_name}")
|
||||
return provider
|
||||
else:
|
||||
logging.debug(f"{provider_type} does not validate model {model_name}")
|
||||
else:
|
||||
logging.debug(f"{provider_type} not found in registry")
|
||||
|
||||
logging.debug(f"No provider found for model {model_name}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_available_providers(cls) -> list[ProviderType]:
|
||||
"""Get list of registered provider types."""
|
||||
return list(cls._providers.keys())
|
||||
instance = cls()
|
||||
return list(instance._providers.keys())
|
||||
|
||||
@classmethod
|
||||
def get_available_models(cls) -> dict[str, ProviderType]:
|
||||
@@ -94,8 +154,9 @@ class ModelProviderRegistry:
|
||||
Dict mapping model names to provider types
|
||||
"""
|
||||
models = {}
|
||||
instance = cls()
|
||||
|
||||
for provider_type in cls._providers:
|
||||
for provider_type in instance._providers:
|
||||
provider = cls.get_provider(provider_type)
|
||||
if provider:
|
||||
# This assumes providers have a method to list supported models
|
||||
@@ -118,6 +179,7 @@ class ModelProviderRegistry:
|
||||
ProviderType.GOOGLE: "GEMINI_API_KEY",
|
||||
ProviderType.OPENAI: "OPENAI_API_KEY",
|
||||
ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
|
||||
ProviderType.CUSTOM: "CUSTOM_API_KEY", # Can be empty for providers that don't need auth
|
||||
}
|
||||
|
||||
env_var = key_mapping.get(provider_type)
|
||||
@@ -165,7 +227,8 @@ class ModelProviderRegistry:
|
||||
List of ProviderType values for providers with valid API keys
|
||||
"""
|
||||
available = []
|
||||
for provider_type in cls._providers:
|
||||
instance = cls()
|
||||
for provider_type in instance._providers:
|
||||
if cls.get_provider(provider_type) is not None:
|
||||
available.append(provider_type)
|
||||
return available
|
||||
@@ -173,10 +236,12 @@ class ModelProviderRegistry:
|
||||
@classmethod
|
||||
def clear_cache(cls) -> None:
|
||||
"""Clear cached provider instances."""
|
||||
cls._initialized_providers.clear()
|
||||
instance = cls()
|
||||
instance._initialized_providers.clear()
|
||||
|
||||
@classmethod
|
||||
def unregister_provider(cls, provider_type: ProviderType) -> None:
|
||||
"""Unregister a provider (mainly for testing)."""
|
||||
cls._providers.pop(provider_type, None)
|
||||
cls._initialized_providers.pop(provider_type, None)
|
||||
instance = cls()
|
||||
instance._providers.pop(provider_type, None)
|
||||
instance._initialized_providers.pop(provider_type, None)
|
||||
|
||||
Reference in New Issue
Block a user