WIP major refactor and features
providers/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
"""Model provider abstractions for supporting multiple AI providers."""

from .base import ModelProvider, ModelResponse, ModelCapabilities
from .registry import ModelProviderRegistry
from .gemini import GeminiModelProvider
from .openai import OpenAIModelProvider

__all__ = [
    "ModelProvider",
    "ModelResponse",
    "ModelCapabilities",
    "ModelProviderRegistry",
    "GeminiModelProvider",
    "OpenAIModelProvider",
]
providers/base.py (new file, 122 lines)
@@ -0,0 +1,122 @@
"""Base model provider interface and data classes."""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Tuple
from enum import Enum


class ProviderType(Enum):
    """Supported model provider types."""
    GOOGLE = "google"
    OPENAI = "openai"


@dataclass
class ModelCapabilities:
    """Capabilities and constraints for a specific model."""
    provider: ProviderType
    model_name: str
    friendly_name: str  # Human-friendly name like "Gemini" or "OpenAI"
    max_tokens: int
    supports_extended_thinking: bool = False
    supports_system_prompts: bool = True
    supports_streaming: bool = True
    supports_function_calling: bool = False
    temperature_range: Tuple[float, float] = (0.0, 2.0)


@dataclass
class ModelResponse:
    """Response from a model provider."""
    content: str
    usage: Dict[str, int] = field(default_factory=dict)  # input_tokens, output_tokens, total_tokens
    model_name: str = ""
    friendly_name: str = ""  # Human-friendly name like "Gemini" or "OpenAI"
    provider: ProviderType = ProviderType.GOOGLE
    metadata: Dict[str, Any] = field(default_factory=dict)  # Provider-specific metadata

    @property
    def total_tokens(self) -> int:
        """Get total tokens used."""
        return self.usage.get("total_tokens", 0)


class ModelProvider(ABC):
    """Abstract base class for model providers."""

    def __init__(self, api_key: str, **kwargs):
        """Initialize the provider with API key and optional configuration."""
        self.api_key = api_key
        self.config = kwargs

    @abstractmethod
    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific model."""
        pass

    @abstractmethod
    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        **kwargs
    ) -> ModelResponse:
        """Generate content using the model.

        Args:
            prompt: User prompt to send to the model
            model_name: Name of the model to use
            system_prompt: Optional system prompt for model behavior
            temperature: Sampling temperature (0-2)
            max_output_tokens: Maximum tokens to generate
            **kwargs: Provider-specific parameters

        Returns:
            ModelResponse with generated content and metadata
        """
        pass

    @abstractmethod
    def count_tokens(self, text: str, model_name: str) -> int:
        """Count tokens for the given text using the specified model's tokenizer."""
        pass

    @abstractmethod
    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        pass

    @abstractmethod
    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is supported by this provider."""
        pass

    def validate_parameters(
        self,
        model_name: str,
        temperature: float,
        **kwargs
    ) -> None:
        """Validate model parameters against capabilities.

        Raises:
            ValueError: If parameters are invalid
        """
        capabilities = self.get_capabilities(model_name)

        # Validate temperature
        min_temp, max_temp = capabilities.temperature_range
        if not min_temp <= temperature <= max_temp:
            raise ValueError(
                f"Temperature {temperature} out of range [{min_temp}, {max_temp}] "
                f"for model {model_name}"
            )

    @abstractmethod
    def supports_thinking_mode(self, model_name: str) -> bool:
        """Check if the model supports extended thinking mode."""
        pass
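For reference, a minimal sketch (not part of this commit) of how the dataclasses above compose; the literal values are illustrative only:

from providers.base import ModelCapabilities, ModelResponse, ProviderType

caps = ModelCapabilities(
    provider=ProviderType.GOOGLE,
    model_name="gemini-2.0-flash-exp",
    friendly_name="Gemini",
    max_tokens=1_048_576,
)

resp = ModelResponse(
    content="Hello",
    usage={"input_tokens": 12, "output_tokens": 5, "total_tokens": 17},
    model_name=caps.model_name,
    friendly_name=caps.friendly_name,
    provider=caps.provider,
)
assert resp.total_tokens == 17  # usage.get("total_tokens", 0) returns 0 when usage is empty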
providers/gemini.py (new file, 185 lines)
@@ -0,0 +1,185 @@
"""Gemini model provider implementation."""

import os
from typing import Dict, Optional, List
from google import genai
from google.genai import types

from .base import ModelProvider, ModelResponse, ModelCapabilities, ProviderType


class GeminiModelProvider(ModelProvider):
    """Google Gemini model provider implementation."""

    # Model configurations
    SUPPORTED_MODELS = {
        "gemini-2.0-flash-exp": {
            "max_tokens": 1_048_576,  # 1M tokens
            "supports_extended_thinking": False,
        },
        "gemini-2.5-pro-preview-06-05": {
            "max_tokens": 1_048_576,  # 1M tokens
            "supports_extended_thinking": True,
        },
        # Shorthands
        "flash": "gemini-2.0-flash-exp",
        "pro": "gemini-2.5-pro-preview-06-05",
    }

    # Thinking mode configurations for models that support it
    THINKING_BUDGETS = {
        "minimal": 128,  # Minimum for 2.5 Pro - fast responses
        "low": 2048,     # Light reasoning tasks
        "medium": 8192,  # Balanced reasoning (default)
        "high": 16384,   # Complex analysis
        "max": 32768,    # Maximum reasoning depth
    }

    def __init__(self, api_key: str, **kwargs):
        """Initialize Gemini provider with API key."""
        super().__init__(api_key, **kwargs)
        self._client = None
        self._token_counters = {}  # Cache for token counting

    @property
    def client(self):
        """Lazy initialization of Gemini client."""
        if self._client is None:
            self._client = genai.Client(api_key=self.api_key)
        return self._client

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific Gemini model."""
        # Resolve shorthand
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported Gemini model: {model_name}")

        config = self.SUPPORTED_MODELS[resolved_name]

        return ModelCapabilities(
            provider=ProviderType.GOOGLE,
            model_name=resolved_name,
            friendly_name="Gemini",
            max_tokens=config["max_tokens"],
            supports_extended_thinking=config["supports_extended_thinking"],
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            temperature_range=(0.0, 2.0),
        )

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        thinking_mode: str = "medium",
        **kwargs
    ) -> ModelResponse:
        """Generate content using Gemini model."""
        # Validate parameters
        resolved_name = self._resolve_model_name(model_name)
        self.validate_parameters(resolved_name, temperature)

        # Combine system prompt with user prompt if provided
        if system_prompt:
            full_prompt = f"{system_prompt}\n\n{prompt}"
        else:
            full_prompt = prompt

        # Prepare generation config
        generation_config = types.GenerateContentConfig(
            temperature=temperature,
            candidate_count=1,
        )

        # Add max output tokens if specified
        if max_output_tokens:
            generation_config.max_output_tokens = max_output_tokens

        # Add thinking configuration for models that support it
        capabilities = self.get_capabilities(resolved_name)
        if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
            generation_config.thinking_config = types.ThinkingConfig(
                thinking_budget=self.THINKING_BUDGETS[thinking_mode]
            )

        try:
            # Generate content
            response = self.client.models.generate_content(
                model=resolved_name,
                contents=full_prompt,
                config=generation_config,
            )

            # Extract usage information if available
            usage = self._extract_usage(response)

            return ModelResponse(
                content=response.text,
                usage=usage,
                model_name=resolved_name,
                friendly_name="Gemini",
                provider=ProviderType.GOOGLE,
                metadata={
                    "thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
                    "finish_reason": getattr(response.candidates[0], "finish_reason", "STOP") if response.candidates else "STOP",
                }
            )

        except Exception as e:
            # Log error and re-raise with more context
            error_msg = f"Gemini API error for model {resolved_name}: {str(e)}"
            raise RuntimeError(error_msg) from e

    def count_tokens(self, text: str, model_name: str) -> int:
        """Count tokens for the given text using Gemini's tokenizer."""
        resolved_name = self._resolve_model_name(model_name)

        # For now, use a simple estimation
        # TODO: Use actual Gemini tokenizer when available in SDK
        # Rough estimation: ~4 characters per token for English text
        return len(text) // 4

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.GOOGLE

    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is supported."""
        resolved_name = self._resolve_model_name(model_name)
        return resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict)

    def supports_thinking_mode(self, model_name: str) -> bool:
        """Check if the model supports extended thinking mode."""
        capabilities = self.get_capabilities(model_name)
        return capabilities.supports_extended_thinking

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model shorthand to full name."""
        # Check if it's a shorthand
        shorthand_value = self.SUPPORTED_MODELS.get(model_name.lower())
        if isinstance(shorthand_value, str):
            return shorthand_value
        return model_name

    def _extract_usage(self, response) -> Dict[str, int]:
        """Extract token usage from Gemini response."""
        usage = {}

        # Try to extract usage metadata from response
        # Note: The actual structure depends on the SDK version and response format
        if hasattr(response, "usage_metadata"):
            metadata = response.usage_metadata
            if hasattr(metadata, "prompt_token_count"):
                usage["input_tokens"] = metadata.prompt_token_count
            if hasattr(metadata, "candidates_token_count"):
                usage["output_tokens"] = metadata.candidates_token_count
            if "input_tokens" in usage and "output_tokens" in usage:
                usage["total_tokens"] = usage["input_tokens"] + usage["output_tokens"]

        return usage
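A usage sketch for the Gemini provider (illustrative, not part of this commit), assuming GEMINI_API_KEY is set and the model/thinking-mode names match SUPPORTED_MODELS and THINKING_BUDGETS above:

import os

from providers.gemini import GeminiModelProvider

provider = GeminiModelProvider(api_key=os.environ["GEMINI_API_KEY"])

response = provider.generate_content(
    prompt="Summarize this refactor in one sentence.",
    model_name="pro",  # shorthand, resolved to gemini-2.5-pro-preview-06-05
    system_prompt="You are a concise release-notes writer.",
    temperature=0.2,
    thinking_mode="low",  # only applied when the model supports extended thinking
)
print(response.content, response.total_tokens)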
providers/openai.py (new file, 163 lines)
@@ -0,0 +1,163 @@
"""OpenAI model provider implementation."""

import os
from typing import Dict, Optional, List, Any
import logging

from openai import OpenAI

from .base import ModelProvider, ModelResponse, ModelCapabilities, ProviderType


class OpenAIModelProvider(ModelProvider):
    """OpenAI model provider implementation."""

    # Model configurations
    SUPPORTED_MODELS = {
        "o3": {
            "max_tokens": 200_000,  # 200K tokens
            "supports_extended_thinking": False,
        },
        "o3-mini": {
            "max_tokens": 200_000,  # 200K tokens
            "supports_extended_thinking": False,
        },
    }

    def __init__(self, api_key: str, **kwargs):
        """Initialize OpenAI provider with API key."""
        super().__init__(api_key, **kwargs)
        self._client = None
        self.base_url = kwargs.get("base_url")  # Support custom endpoints
        self.organization = kwargs.get("organization")

    @property
    def client(self):
        """Lazy initialization of OpenAI client."""
        if self._client is None:
            client_kwargs = {"api_key": self.api_key}
            if self.base_url:
                client_kwargs["base_url"] = self.base_url
            if self.organization:
                client_kwargs["organization"] = self.organization

            self._client = OpenAI(**client_kwargs)
        return self._client

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific OpenAI model."""
        if model_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported OpenAI model: {model_name}")

        config = self.SUPPORTED_MODELS[model_name]

        return ModelCapabilities(
            provider=ProviderType.OPENAI,
            model_name=model_name,
            friendly_name="OpenAI",
            max_tokens=config["max_tokens"],
            supports_extended_thinking=config["supports_extended_thinking"],
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            temperature_range=(0.0, 2.0),
        )

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        **kwargs
    ) -> ModelResponse:
        """Generate content using OpenAI model."""
        # Validate parameters
        self.validate_parameters(model_name, temperature)

        # Prepare messages
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        # Prepare completion parameters
        completion_params = {
            "model": model_name,
            "messages": messages,
            "temperature": temperature,
        }

        # Add max tokens if specified
        if max_output_tokens:
            completion_params["max_tokens"] = max_output_tokens

        # Add any additional OpenAI-specific parameters
        for key, value in kwargs.items():
            if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop"]:
                completion_params[key] = value

        try:
            # Generate completion
            response = self.client.chat.completions.create(**completion_params)

            # Extract content and usage
            content = response.choices[0].message.content
            usage = self._extract_usage(response)

            return ModelResponse(
                content=content,
                usage=usage,
                model_name=model_name,
                friendly_name="OpenAI",
                provider=ProviderType.OPENAI,
                metadata={
                    "finish_reason": response.choices[0].finish_reason,
                    "model": response.model,  # Actual model used (in case of fallbacks)
                    "id": response.id,
                    "created": response.created,
                }
            )

        except Exception as e:
            # Log error and re-raise with more context
            error_msg = f"OpenAI API error for model {model_name}: {str(e)}"
            logging.error(error_msg)
            raise RuntimeError(error_msg) from e

    def count_tokens(self, text: str, model_name: str) -> int:
        """Count tokens for the given text.

        Note: For accurate token counting, we should use tiktoken library.
        This is a simplified estimation.
        """
        # TODO: Implement proper token counting with tiktoken
        # For now, use rough estimation
        # O3 models ~4 chars per token
        return len(text) // 4

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.OPENAI

    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is supported."""
        return model_name in self.SUPPORTED_MODELS

    def supports_thinking_mode(self, model_name: str) -> bool:
        """Check if the model supports extended thinking mode."""
        # Currently no OpenAI models support extended thinking
        # This may change with future O3 models
        return False

    def _extract_usage(self, response) -> Dict[str, int]:
        """Extract token usage from OpenAI response."""
        usage = {}

        if hasattr(response, "usage") and response.usage:
            usage["input_tokens"] = response.usage.prompt_tokens
            usage["output_tokens"] = response.usage.completion_tokens
            usage["total_tokens"] = response.usage.total_tokens

        return usage
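A similar sketch for the OpenAI provider (illustrative, not part of this commit), assuming OPENAI_API_KEY is set; the commented base_url shows where a custom OpenAI-compatible endpoint would be passed through to the client:

import os

from providers.openai import OpenAIModelProvider

provider = OpenAIModelProvider(
    api_key=os.environ["OPENAI_API_KEY"],
    # base_url="https://example.internal/v1",  # optional, forwarded to the OpenAI client
)

response = provider.generate_content(
    prompt="Reply with the single word: ready",
    model_name="o3-mini",
    max_output_tokens=16,
)
print(response.metadata["finish_reason"], response.total_tokens)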
providers/registry.py (new file, 136 lines)
@@ -0,0 +1,136 @@
"""Model provider registry for managing available providers."""

import os
from typing import Dict, Optional, Type, List
from .base import ModelProvider, ProviderType


class ModelProviderRegistry:
    """Registry for managing model providers."""

    _instance = None
    _providers: Dict[ProviderType, Type[ModelProvider]] = {}
    _initialized_providers: Dict[ProviderType, ModelProvider] = {}

    def __new__(cls):
        """Singleton pattern for registry."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def register_provider(cls, provider_type: ProviderType, provider_class: Type[ModelProvider]) -> None:
        """Register a new provider class.

        Args:
            provider_type: Type of the provider (e.g., ProviderType.GOOGLE)
            provider_class: Class that implements ModelProvider interface
        """
        cls._providers[provider_type] = provider_class

    @classmethod
    def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]:
        """Get an initialized provider instance.

        Args:
            provider_type: Type of provider to get
            force_new: Force creation of new instance instead of using cached

        Returns:
            Initialized ModelProvider instance or None if not available
        """
        # Return cached instance if available and not forcing new
        if not force_new and provider_type in cls._initialized_providers:
            return cls._initialized_providers[provider_type]

        # Check if provider class is registered
        if provider_type not in cls._providers:
            return None

        # Get API key from environment
        api_key = cls._get_api_key_for_provider(provider_type)
        if not api_key:
            return None

        # Initialize provider
        provider_class = cls._providers[provider_type]
        provider = provider_class(api_key=api_key)

        # Cache the instance
        cls._initialized_providers[provider_type] = provider

        return provider

    @classmethod
    def get_provider_for_model(cls, model_name: str) -> Optional[ModelProvider]:
        """Get provider instance for a specific model name.

        Args:
            model_name: Name of the model (e.g., "gemini-2.0-flash-exp", "o3-mini")

        Returns:
            ModelProvider instance that supports this model
        """
        # Check each registered provider
        for provider_type, provider_class in cls._providers.items():
            # Get or create provider instance
            provider = cls.get_provider(provider_type)
            if provider and provider.validate_model_name(model_name):
                return provider

        return None

    @classmethod
    def get_available_providers(cls) -> List[ProviderType]:
        """Get list of registered provider types."""
        return list(cls._providers.keys())

    @classmethod
    def get_available_models(cls) -> Dict[str, ProviderType]:
        """Get mapping of all available models to their providers.

        Returns:
            Dict mapping model names to provider types
        """
        models = {}

        for provider_type in cls._providers:
            provider = cls.get_provider(provider_type)
            if provider:
                # This assumes providers have a method to list supported models
                # We'll need to add this to the interface
                pass

        return models

    @classmethod
    def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]:
        """Get API key for a provider from environment variables.

        Args:
            provider_type: Provider type to get API key for

        Returns:
            API key string or None if not found
        """
        key_mapping = {
            ProviderType.GOOGLE: "GEMINI_API_KEY",
            ProviderType.OPENAI: "OPENAI_API_KEY",
        }

        env_var = key_mapping.get(provider_type)
        if not env_var:
            return None

        return os.getenv(env_var)

    @classmethod
    def clear_cache(cls) -> None:
        """Clear cached provider instances."""
        cls._initialized_providers.clear()

    @classmethod
    def unregister_provider(cls, provider_type: ProviderType) -> None:
        """Unregister a provider (mainly for testing)."""
        cls._providers.pop(provider_type, None)
        cls._initialized_providers.pop(provider_type, None)
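A wiring sketch (illustrative, not part of this commit) showing how a caller might register both providers and resolve a model; API keys are read from GEMINI_API_KEY / OPENAI_API_KEY by _get_api_key_for_provider:

from providers import GeminiModelProvider, ModelProviderRegistry, OpenAIModelProvider
from providers.base import ProviderType

ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

provider = ModelProviderRegistry.get_provider_for_model("o3-mini")
if provider is None:
    raise RuntimeError("No registered provider with an API key supports o3-mini")

print(provider.get_provider_type())                     # ProviderType.OPENAI
print(provider.get_capabilities("o3-mini").max_tokens)  # 200000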