WIP major refactor and features
This commit is contained in:
136
providers/registry.py
Normal file
136
providers/registry.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""Model provider registry for managing available providers."""
|
||||
|
||||
import os
|
||||
from typing import Dict, Optional, Type, List
|
||||
from .base import ModelProvider, ProviderType
|
||||
|
||||
|
||||
class ModelProviderRegistry:
|
||||
"""Registry for managing model providers."""
|
||||
|
||||
_instance = None
|
||||
_providers: Dict[ProviderType, Type[ModelProvider]] = {}
|
||||
_initialized_providers: Dict[ProviderType, ModelProvider] = {}
|
||||
|
||||
def __new__(cls):
|
||||
"""Singleton pattern for registry."""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def register_provider(cls, provider_type: ProviderType, provider_class: Type[ModelProvider]) -> None:
|
||||
"""Register a new provider class.
|
||||
|
||||
Args:
|
||||
provider_type: Type of the provider (e.g., ProviderType.GOOGLE)
|
||||
provider_class: Class that implements ModelProvider interface
|
||||
"""
|
||||
cls._providers[provider_type] = provider_class
|
||||
|
||||
@classmethod
|
||||
def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]:
|
||||
"""Get an initialized provider instance.
|
||||
|
||||
Args:
|
||||
provider_type: Type of provider to get
|
||||
force_new: Force creation of new instance instead of using cached
|
||||
|
||||
Returns:
|
||||
Initialized ModelProvider instance or None if not available
|
||||
"""
|
||||
# Return cached instance if available and not forcing new
|
||||
if not force_new and provider_type in cls._initialized_providers:
|
||||
return cls._initialized_providers[provider_type]
|
||||
|
||||
# Check if provider class is registered
|
||||
if provider_type not in cls._providers:
|
||||
return None
|
||||
|
||||
# Get API key from environment
|
||||
api_key = cls._get_api_key_for_provider(provider_type)
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
# Initialize provider
|
||||
provider_class = cls._providers[provider_type]
|
||||
provider = provider_class(api_key=api_key)
|
||||
|
||||
# Cache the instance
|
||||
cls._initialized_providers[provider_type] = provider
|
||||
|
||||
return provider
|
||||
|
||||
@classmethod
|
||||
def get_provider_for_model(cls, model_name: str) -> Optional[ModelProvider]:
|
||||
"""Get provider instance for a specific model name.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model (e.g., "gemini-2.0-flash-exp", "o3-mini")
|
||||
|
||||
Returns:
|
||||
ModelProvider instance that supports this model
|
||||
"""
|
||||
# Check each registered provider
|
||||
for provider_type, provider_class in cls._providers.items():
|
||||
# Get or create provider instance
|
||||
provider = cls.get_provider(provider_type)
|
||||
if provider and provider.validate_model_name(model_name):
|
||||
return provider
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_available_providers(cls) -> List[ProviderType]:
|
||||
"""Get list of registered provider types."""
|
||||
return list(cls._providers.keys())
|
||||
|
||||
@classmethod
|
||||
def get_available_models(cls) -> Dict[str, ProviderType]:
|
||||
"""Get mapping of all available models to their providers.
|
||||
|
||||
Returns:
|
||||
Dict mapping model names to provider types
|
||||
"""
|
||||
models = {}
|
||||
|
||||
for provider_type in cls._providers:
|
||||
provider = cls.get_provider(provider_type)
|
||||
if provider:
|
||||
# This assumes providers have a method to list supported models
|
||||
# We'll need to add this to the interface
|
||||
pass
|
||||
|
||||
return models
|
||||
|
||||
@classmethod
|
||||
def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]:
|
||||
"""Get API key for a provider from environment variables.
|
||||
|
||||
Args:
|
||||
provider_type: Provider type to get API key for
|
||||
|
||||
Returns:
|
||||
API key string or None if not found
|
||||
"""
|
||||
key_mapping = {
|
||||
ProviderType.GOOGLE: "GEMINI_API_KEY",
|
||||
ProviderType.OPENAI: "OPENAI_API_KEY",
|
||||
}
|
||||
|
||||
env_var = key_mapping.get(provider_type)
|
||||
if not env_var:
|
||||
return None
|
||||
|
||||
return os.getenv(env_var)
|
||||
|
||||
@classmethod
|
||||
def clear_cache(cls) -> None:
|
||||
"""Clear cached provider instances."""
|
||||
cls._initialized_providers.clear()
|
||||
|
||||
@classmethod
|
||||
def unregister_provider(cls, provider_type: ProviderType) -> None:
|
||||
"""Unregister a provider (mainly for testing)."""
|
||||
cls._providers.pop(provider_type, None)
|
||||
cls._initialized_providers.pop(provider_type, None)
|
||||
Reference in New Issue
Block a user