Breaking change: openrouter_models.json -> custom_models.json

* Support for Custom URLs and custom models, including locally hosted models such as ollama
* Support for native + openrouter + local models (i.e. dozens of models) means you can start delegating sub-tasks to particular models or work to local models such as localizations or other boring work etc.
* Several tests added
* precommit to also include untracked (new) files
* Logfile auto rollover
* Improved logging
This commit is contained in:
Fahad
2025-06-13 15:22:09 +04:00
parent f5fdf7b2ed
commit f44ca326ef
27 changed files with 1692 additions and 351 deletions

View File

@@ -1,5 +1,6 @@
"""Model provider registry for managing available providers."""
import logging
import os
from typing import Optional
@@ -10,13 +11,18 @@ class ModelProviderRegistry:
"""Registry for managing model providers."""
_instance = None
_providers: dict[ProviderType, type[ModelProvider]] = {}
_initialized_providers: dict[ProviderType, ModelProvider] = {}
def __new__(cls):
"""Singleton pattern for registry."""
if cls._instance is None:
logging.debug("REGISTRY: Creating new registry instance")
cls._instance = super().__new__(cls)
# Initialize instance dictionaries on first creation
cls._instance._providers = {}
cls._instance._initialized_providers = {}
logging.debug(f"REGISTRY: Created instance {cls._instance}")
else:
logging.debug(f"REGISTRY: Returning existing instance {cls._instance}")
return cls._instance
@classmethod
@@ -27,7 +33,8 @@ class ModelProviderRegistry:
provider_type: Type of the provider (e.g., ProviderType.GOOGLE)
provider_class: Class that implements ModelProvider interface
"""
cls._providers[provider_type] = provider_class
instance = cls()
instance._providers[provider_type] = provider_class
@classmethod
def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]:
@@ -40,25 +47,48 @@ class ModelProviderRegistry:
Returns:
Initialized ModelProvider instance or None if not available
"""
instance = cls()
# Return cached instance if available and not forcing new
if not force_new and provider_type in cls._initialized_providers:
return cls._initialized_providers[provider_type]
if not force_new and provider_type in instance._initialized_providers:
return instance._initialized_providers[provider_type]
# Check if provider class is registered
if provider_type not in cls._providers:
if provider_type not in instance._providers:
return None
# Get API key from environment
api_key = cls._get_api_key_for_provider(provider_type)
if not api_key:
return None
# Initialize provider
provider_class = cls._providers[provider_type]
provider = provider_class(api_key=api_key)
# Get provider class or factory function
provider_class = instance._providers[provider_type]
# For custom providers, handle special initialization requirements
if provider_type == ProviderType.CUSTOM:
# Check if it's a factory function (callable but not a class)
if callable(provider_class) and not isinstance(provider_class, type):
# Factory function - call it with api_key parameter
provider = provider_class(api_key=api_key)
else:
# Regular class - need to handle URL requirement
custom_url = os.getenv("CUSTOM_API_URL", "")
if not custom_url:
if api_key: # Key is set but URL is missing
logging.warning("CUSTOM_API_KEY set but CUSTOM_API_URL missing skipping Custom provider")
return None
# Use empty string as API key for custom providers that don't need auth (e.g., Ollama)
# This allows the provider to be created even without CUSTOM_API_KEY being set
api_key = api_key or ""
# Initialize custom provider with both API key and base URL
provider = provider_class(api_key=api_key, base_url=custom_url)
else:
if not api_key:
return None
# Initialize non-custom provider with just API key
provider = provider_class(api_key=api_key)
# Cache the instance
cls._initialized_providers[provider_type] = provider
instance._initialized_providers[provider_type] = provider
return provider
@@ -66,25 +96,55 @@ class ModelProviderRegistry:
def get_provider_for_model(cls, model_name: str) -> Optional[ModelProvider]:
"""Get provider instance for a specific model name.
Provider priority order:
1. Native APIs (GOOGLE, OPENAI) - Most direct and efficient
2. CUSTOM - For local/private models with specific endpoints
3. OPENROUTER - Catch-all for cloud models via unified API
Args:
model_name: Name of the model (e.g., "gemini-2.5-flash-preview-05-20", "o3-mini")
Returns:
ModelProvider instance that supports this model
"""
# Check each registered provider
for provider_type, _provider_class in cls._providers.items():
# Get or create provider instance
provider = cls.get_provider(provider_type)
if provider and provider.validate_model_name(model_name):
return provider
logging.debug(f"get_provider_for_model called with model_name='{model_name}'")
# Define explicit provider priority order
# Native APIs first, then custom endpoints, then catch-all providers
PROVIDER_PRIORITY_ORDER = [
ProviderType.GOOGLE, # Direct Gemini access
ProviderType.OPENAI, # Direct OpenAI access
ProviderType.CUSTOM, # Local/self-hosted models
ProviderType.OPENROUTER, # Catch-all for cloud models
]
# Check providers in priority order
instance = cls()
logging.debug(f"Registry instance: {instance}")
logging.debug(f"Available providers in registry: {list(instance._providers.keys())}")
for provider_type in PROVIDER_PRIORITY_ORDER:
logging.debug(f"Checking provider_type: {provider_type}")
if provider_type in instance._providers:
logging.debug(f"Found {provider_type} in registry")
# Get or create provider instance
provider = cls.get_provider(provider_type)
if provider and provider.validate_model_name(model_name):
logging.debug(f"{provider_type} validates model {model_name}")
return provider
else:
logging.debug(f"{provider_type} does not validate model {model_name}")
else:
logging.debug(f"{provider_type} not found in registry")
logging.debug(f"No provider found for model {model_name}")
return None
@classmethod
def get_available_providers(cls) -> list[ProviderType]:
"""Get list of registered provider types."""
return list(cls._providers.keys())
instance = cls()
return list(instance._providers.keys())
@classmethod
def get_available_models(cls) -> dict[str, ProviderType]:
@@ -94,8 +154,9 @@ class ModelProviderRegistry:
Dict mapping model names to provider types
"""
models = {}
instance = cls()
for provider_type in cls._providers:
for provider_type in instance._providers:
provider = cls.get_provider(provider_type)
if provider:
# This assumes providers have a method to list supported models
@@ -118,6 +179,7 @@ class ModelProviderRegistry:
ProviderType.GOOGLE: "GEMINI_API_KEY",
ProviderType.OPENAI: "OPENAI_API_KEY",
ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
ProviderType.CUSTOM: "CUSTOM_API_KEY", # Can be empty for providers that don't need auth
}
env_var = key_mapping.get(provider_type)
@@ -165,7 +227,8 @@ class ModelProviderRegistry:
List of ProviderType values for providers with valid API keys
"""
available = []
for provider_type in cls._providers:
instance = cls()
for provider_type in instance._providers:
if cls.get_provider(provider_type) is not None:
available.append(provider_type)
return available
@@ -173,10 +236,12 @@ class ModelProviderRegistry:
@classmethod
def clear_cache(cls) -> None:
"""Clear cached provider instances."""
cls._initialized_providers.clear()
instance = cls()
instance._initialized_providers.clear()
@classmethod
def unregister_provider(cls, provider_type: ProviderType) -> None:
"""Unregister a provider (mainly for testing)."""
cls._providers.pop(provider_type, None)
cls._initialized_providers.pop(provider_type, None)
instance = cls()
instance._providers.pop(provider_type, None)
instance._initialized_providers.pop(provider_type, None)