Rebranding, refactoring, renaming, cleanup, updated docs

Fahad
2025-06-12 10:40:43 +04:00
parent 9a55ca8898
commit fb66825bf6
55 changed files with 1048 additions and 1474 deletions
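
The diffs below share one recurring pattern: typing-module aliases (Dict, List, Tuple, Type) are replaced by PEP 585 builtin generics, imports get isort ordering, and black-style formatting joins wrapped signatures and adds trailing commas. A minimal sketch of the typing change, using illustrative names that are not from the diff (requires Python 3.9+):

# Before the cleanup (typing aliases):
#   from typing import Dict, List, Optional
#   def summarize(xs: List[float], label: Optional[str] = None) -> Dict[str, int]: ...

# After the cleanup (builtin generics; Optional still comes from typing):
from typing import Optional

def summarize(xs: list[float], label: Optional[str] = None) -> dict[str, int]:
    # Illustrative helper, not part of the commit.
    return {label or "count": len(xs)}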

View File

@@ -1,9 +1,9 @@
"""Model provider abstractions for supporting multiple AI providers."""
-from .base import ModelProvider, ModelResponse, ModelCapabilities
-from .registry import ModelProviderRegistry
+from .base import ModelCapabilities, ModelProvider, ModelResponse
from .gemini import GeminiModelProvider
from .openai import OpenAIModelProvider
+from .registry import ModelProviderRegistry
__all__ = [
"ModelProvider",
@@ -12,4 +12,4 @@ __all__ = [
"ModelProviderRegistry",
"GeminiModelProvider",
"OpenAIModelProvider",
-]
\ No newline at end of file
+]
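
The __init__ hunks are ordering only (isort); the exported names are unchanged. A usage sketch of the resulting public surface, assuming the package directory is named providers, a name inferred from the relative imports rather than stated in the diff:

# "providers" as the package name is an assumption.
from providers import (
    GeminiModelProvider,
    ModelCapabilities,
    ModelProvider,
    ModelProviderRegistry,
    ModelResponse,
    OpenAIModelProvider,
)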

View File

@@ -2,34 +2,35 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Any, Tuple
from enum import Enum
+from typing import Any, Optional
class ProviderType(Enum):
"""Supported model provider types."""
GOOGLE = "google"
OPENAI = "openai"
class TemperatureConstraint(ABC):
"""Abstract base class for temperature constraints."""
@abstractmethod
def validate(self, temperature: float) -> bool:
"""Check if temperature is valid."""
pass
@abstractmethod
def get_corrected_value(self, temperature: float) -> float:
"""Get nearest valid temperature."""
pass
@abstractmethod
def get_description(self) -> str:
"""Get human-readable description of constraint."""
pass
@abstractmethod
def get_default(self) -> float:
"""Get model's default temperature."""
@@ -38,60 +39,60 @@ class TemperatureConstraint(ABC):
class FixedTemperatureConstraint(TemperatureConstraint):
"""For models that only support one temperature value (e.g., O3)."""
def __init__(self, value: float):
self.value = value
def validate(self, temperature: float) -> bool:
return abs(temperature - self.value) < 1e-6 # Handle floating point precision
def get_corrected_value(self, temperature: float) -> float:
return self.value
def get_description(self) -> str:
return f"Only supports temperature={self.value}"
def get_default(self) -> float:
return self.value
class RangeTemperatureConstraint(TemperatureConstraint):
"""For models supporting continuous temperature ranges."""
def __init__(self, min_temp: float, max_temp: float, default: float = None):
self.min_temp = min_temp
self.max_temp = max_temp
self.default_temp = default or (min_temp + max_temp) / 2
def validate(self, temperature: float) -> bool:
return self.min_temp <= temperature <= self.max_temp
def get_corrected_value(self, temperature: float) -> float:
return max(self.min_temp, min(self.max_temp, temperature))
def get_description(self) -> str:
return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"
def get_default(self) -> float:
return self.default_temp
class DiscreteTemperatureConstraint(TemperatureConstraint):
"""For models supporting only specific temperature values."""
-def __init__(self, allowed_values: List[float], default: float = None):
+def __init__(self, allowed_values: list[float], default: float = None):
self.allowed_values = sorted(allowed_values)
-self.default_temp = default or allowed_values[len(allowed_values)//2]
+self.default_temp = default or allowed_values[len(allowed_values) // 2]
def validate(self, temperature: float) -> bool:
return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)
def get_corrected_value(self, temperature: float) -> float:
return min(self.allowed_values, key=lambda x: abs(x - temperature))
def get_description(self) -> str:
return f"Supports temperatures: {self.allowed_values}"
def get_default(self) -> float:
return self.default_temp
@@ -99,6 +100,7 @@ class DiscreteTemperatureConstraint(TemperatureConstraint):
@dataclass
class ModelCapabilities:
"""Capabilities and constraints for a specific model."""
provider: ProviderType
model_name: str
friendly_name: str # Human-friendly name like "Gemini" or "OpenAI"
@@ -107,15 +109,15 @@ class ModelCapabilities:
supports_system_prompts: bool = True
supports_streaming: bool = True
supports_function_calling: bool = False
# Temperature constraint object - preferred way to define temperature limits
temperature_constraint: TemperatureConstraint = field(
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.7)
)
# Backward compatibility property for existing code
@property
-def temperature_range(self) -> Tuple[float, float]:
+def temperature_range(self) -> tuple[float, float]:
"""Backward compatibility for existing code that uses temperature_range."""
if isinstance(self.temperature_constraint, RangeTemperatureConstraint):
return (self.temperature_constraint.min_temp, self.temperature_constraint.max_temp)
@@ -130,13 +132,14 @@ class ModelCapabilities:
@dataclass
class ModelResponse:
"""Response from a model provider."""
content: str
-usage: Dict[str, int] = field(default_factory=dict) # input_tokens, output_tokens, total_tokens
+usage: dict[str, int] = field(default_factory=dict) # input_tokens, output_tokens, total_tokens
model_name: str = ""
friendly_name: str = "" # Human-friendly name like "Gemini" or "OpenAI"
provider: ProviderType = ProviderType.GOOGLE
-metadata: Dict[str, Any] = field(default_factory=dict) # Provider-specific metadata
+metadata: dict[str, Any] = field(default_factory=dict) # Provider-specific metadata
@property
def total_tokens(self) -> int:
"""Get total tokens used."""
@@ -145,17 +148,17 @@ class ModelResponse:
class ModelProvider(ABC):
"""Abstract base class for model providers."""
def __init__(self, api_key: str, **kwargs):
"""Initialize the provider with API key and optional configuration."""
self.api_key = api_key
self.config = kwargs
@abstractmethod
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific model."""
pass
@abstractmethod
def generate_content(
self,
@@ -164,10 +167,10 @@ class ModelProvider(ABC):
system_prompt: Optional[str] = None,
temperature: float = 0.7,
max_output_tokens: Optional[int] = None,
-**kwargs
+**kwargs,
) -> ModelResponse:
"""Generate content using the model.
Args:
prompt: User prompt to send to the model
model_name: Name of the model to use
@@ -175,49 +178,43 @@ class ModelProvider(ABC):
temperature: Sampling temperature (0-2)
max_output_tokens: Maximum tokens to generate
**kwargs: Provider-specific parameters
Returns:
ModelResponse with generated content and metadata
"""
pass
@abstractmethod
def count_tokens(self, text: str, model_name: str) -> int:
"""Count tokens for the given text using the specified model's tokenizer."""
pass
@abstractmethod
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
pass
@abstractmethod
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported by this provider."""
pass
-def validate_parameters(
-self,
-model_name: str,
-temperature: float,
-**kwargs
-) -> None:
+def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
"""Validate model parameters against capabilities.
Raises:
ValueError: If parameters are invalid
"""
capabilities = self.get_capabilities(model_name)
# Validate temperature
min_temp, max_temp = capabilities.temperature_range
if not min_temp <= temperature <= max_temp:
raise ValueError(
f"Temperature {temperature} out of range [{min_temp}, {max_temp}] "
f"for model {model_name}"
f"Temperature {temperature} out of range [{min_temp}, {max_temp}] " f"for model {model_name}"
)
@abstractmethod
def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode."""
-pass
\ No newline at end of file
+pass
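
For reference, a short behavioral sketch of the three constraint classes defined above. The import path assumes this file is providers/base.py, which the sibling files' from .base import ... lines suggest:

from providers.base import (
    DiscreteTemperatureConstraint,
    FixedTemperatureConstraint,
    RangeTemperatureConstraint,
)

fixed = FixedTemperatureConstraint(1.0)          # O3-style single-value models
assert not fixed.validate(0.7)
assert fixed.get_corrected_value(0.7) == 1.0     # always snaps to the one value

ranged = RangeTemperatureConstraint(0.0, 2.0, 0.7)
assert ranged.validate(1.5)
assert ranged.get_corrected_value(2.5) == 2.0    # clamped into [min, max]

discrete = DiscreteTemperatureConstraint([0.0, 0.5, 1.0])
assert not discrete.validate(0.4)
assert discrete.get_corrected_value(0.4) == 0.5  # nearest allowed value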

View File

@@ -1,22 +1,16 @@
"""Gemini model provider implementation."""
import os
-from typing import Dict, Optional, List
+from typing import Optional
from google import genai
from google.genai import types
-from .base import (
-ModelProvider,
-ModelResponse,
-ModelCapabilities,
-ProviderType,
-RangeTemperatureConstraint
-)
+from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, RangeTemperatureConstraint
class GeminiModelProvider(ModelProvider):
"""Google Gemini model provider implementation."""
# Model configurations
SUPPORTED_MODELS = {
"gemini-2.0-flash-exp": {
@@ -31,42 +25,42 @@ class GeminiModelProvider(ModelProvider):
"flash": "gemini-2.0-flash-exp",
"pro": "gemini-2.5-pro-preview-06-05",
}
# Thinking mode configurations for models that support it
THINKING_BUDGETS = {
"minimal": 128, # Minimum for 2.5 Pro - fast responses
"low": 2048, # Light reasoning tasks
"medium": 8192, # Balanced reasoning (default)
"high": 16384, # Complex analysis
"max": 32768, # Maximum reasoning depth
"minimal": 128, # Minimum for 2.5 Pro - fast responses
"low": 2048, # Light reasoning tasks
"medium": 8192, # Balanced reasoning (default)
"high": 16384, # Complex analysis
"max": 32768, # Maximum reasoning depth
}
def __init__(self, api_key: str, **kwargs):
"""Initialize Gemini provider with API key."""
super().__init__(api_key, **kwargs)
self._client = None
self._token_counters = {} # Cache for token counting
@property
def client(self):
"""Lazy initialization of Gemini client."""
if self._client is None:
self._client = genai.Client(api_key=self.api_key)
return self._client
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific Gemini model."""
# Resolve shorthand
resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.SUPPORTED_MODELS:
raise ValueError(f"Unsupported Gemini model: {model_name}")
config = self.SUPPORTED_MODELS[resolved_name]
# Gemini models support 0.0-2.0 temperature range
temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
return ModelCapabilities(
provider=ProviderType.GOOGLE,
model_name=resolved_name,
@@ -78,7 +72,7 @@ class GeminiModelProvider(ModelProvider):
supports_function_calling=True,
temperature_constraint=temp_constraint,
)
def generate_content(
self,
prompt: str,
@@ -87,36 +81,36 @@ class GeminiModelProvider(ModelProvider):
temperature: float = 0.7,
max_output_tokens: Optional[int] = None,
thinking_mode: str = "medium",
-**kwargs
+**kwargs,
) -> ModelResponse:
"""Generate content using Gemini model."""
# Validate parameters
resolved_name = self._resolve_model_name(model_name)
self.validate_parameters(resolved_name, temperature)
# Combine system prompt with user prompt if provided
if system_prompt:
full_prompt = f"{system_prompt}\n\n{prompt}"
else:
full_prompt = prompt
# Prepare generation config
generation_config = types.GenerateContentConfig(
temperature=temperature,
candidate_count=1,
)
# Add max output tokens if specified
if max_output_tokens:
generation_config.max_output_tokens = max_output_tokens
# Add thinking configuration for models that support it
capabilities = self.get_capabilities(resolved_name)
if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
generation_config.thinking_config = types.ThinkingConfig(
thinking_budget=self.THINKING_BUDGETS[thinking_mode]
)
try:
# Generate content
response = self.client.models.generate_content(
@@ -124,10 +118,10 @@ class GeminiModelProvider(ModelProvider):
contents=full_prompt,
config=generation_config,
)
# Extract usage information if available
usage = self._extract_usage(response)
return ModelResponse(
content=response.text,
usage=usage,
@@ -136,38 +130,40 @@ class GeminiModelProvider(ModelProvider):
provider=ProviderType.GOOGLE,
metadata={
"thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
"finish_reason": getattr(response.candidates[0], "finish_reason", "STOP") if response.candidates else "STOP",
}
"finish_reason": (
getattr(response.candidates[0], "finish_reason", "STOP") if response.candidates else "STOP"
),
},
)
except Exception as e:
# Log error and re-raise with more context
error_msg = f"Gemini API error for model {resolved_name}: {str(e)}"
raise RuntimeError(error_msg) from e
def count_tokens(self, text: str, model_name: str) -> int:
"""Count tokens for the given text using Gemini's tokenizer."""
-resolved_name = self._resolve_model_name(model_name)
+self._resolve_model_name(model_name)
# For now, use a simple estimation
# TODO: Use actual Gemini tokenizer when available in SDK
# Rough estimation: ~4 characters per token for English text
return len(text) // 4
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.GOOGLE
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported."""
resolved_name = self._resolve_model_name(model_name)
return resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict)
def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode."""
capabilities = self.get_capabilities(model_name)
return capabilities.supports_extended_thinking
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model shorthand to full name."""
# Check if it's a shorthand
@@ -175,11 +171,11 @@ class GeminiModelProvider(ModelProvider):
if isinstance(shorthand_value, str):
return shorthand_value
return model_name
-def _extract_usage(self, response) -> Dict[str, int]:
+def _extract_usage(self, response) -> dict[str, int]:
"""Extract token usage from Gemini response."""
usage = {}
# Try to extract usage metadata from response
# Note: The actual structure depends on the SDK version and response format
if hasattr(response, "usage_metadata"):
@@ -190,5 +186,5 @@ class GeminiModelProvider(ModelProvider):
usage["output_tokens"] = metadata.candidates_token_count
if "input_tokens" in usage and "output_tokens" in usage:
usage["total_tokens"] = usage["input_tokens"] + usage["output_tokens"]
-return usage
\ No newline at end of file
+return usage
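
A hedged usage sketch for the provider above (evidently providers/gemini.py): it assumes a valid GEMINI_API_KEY and that the pinned preview model names are still served. Note that count_tokens is a chars/4 estimate, per the TODO in the diff:

import os

from providers.gemini import GeminiModelProvider

provider = GeminiModelProvider(api_key=os.environ["GEMINI_API_KEY"])

# "pro" resolves through SUPPORTED_MODELS to gemini-2.5-pro-preview-06-05.
response = provider.generate_content(
    prompt="Summarize the tradeoffs of lazy client initialization.",
    model_name="pro",
    temperature=0.7,
    thinking_mode="high",  # 16384-token thinking budget per THINKING_BUDGETS
)
print(response.content)
print(response.usage.get("total_tokens"))  # set only when usage_metadata is returned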

View File

@@ -1,24 +1,23 @@
"""OpenAI model provider implementation."""
-import os
-from typing import Dict, Optional, List, Any
import logging
+from typing import Optional
from openai import OpenAI
from .base import (
-ModelProvider,
-ModelResponse,
-ModelCapabilities,
-ProviderType,
-FixedTemperatureConstraint,
-RangeTemperatureConstraint
+ModelCapabilities,
+ModelProvider,
+ModelResponse,
+ProviderType,
+RangeTemperatureConstraint,
)
class OpenAIModelProvider(ModelProvider):
"""OpenAI model provider implementation."""
# Model configurations
SUPPORTED_MODELS = {
"o3": {
@@ -30,14 +29,14 @@ class OpenAIModelProvider(ModelProvider):
"supports_extended_thinking": False,
},
}
def __init__(self, api_key: str, **kwargs):
"""Initialize OpenAI provider with API key."""
super().__init__(api_key, **kwargs)
self._client = None
self.base_url = kwargs.get("base_url") # Support custom endpoints
self.organization = kwargs.get("organization")
@property
def client(self):
"""Lazy initialization of OpenAI client."""
@@ -47,17 +46,17 @@ class OpenAIModelProvider(ModelProvider):
client_kwargs["base_url"] = self.base_url
if self.organization:
client_kwargs["organization"] = self.organization
self._client = OpenAI(**client_kwargs)
return self._client
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific OpenAI model."""
if model_name not in self.SUPPORTED_MODELS:
raise ValueError(f"Unsupported OpenAI model: {model_name}")
config = self.SUPPORTED_MODELS[model_name]
# Define temperature constraints per model
if model_name in ["o3", "o3-mini"]:
# O3 models only support temperature=1.0
@@ -65,7 +64,7 @@ class OpenAIModelProvider(ModelProvider):
else:
# Other OpenAI models support 0.0-2.0 range
temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
return ModelCapabilities(
provider=ProviderType.OPENAI,
model_name=model_name,
@@ -77,7 +76,7 @@ class OpenAIModelProvider(ModelProvider):
supports_function_calling=True,
temperature_constraint=temp_constraint,
)
def generate_content(
self,
prompt: str,
@@ -85,42 +84,42 @@ class OpenAIModelProvider(ModelProvider):
system_prompt: Optional[str] = None,
temperature: float = 0.7,
max_output_tokens: Optional[int] = None,
-**kwargs
+**kwargs,
) -> ModelResponse:
"""Generate content using OpenAI model."""
# Validate parameters
self.validate_parameters(model_name, temperature)
# Prepare messages
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
# Prepare completion parameters
completion_params = {
"model": model_name,
"messages": messages,
"temperature": temperature,
}
# Add max tokens if specified
if max_output_tokens:
completion_params["max_tokens"] = max_output_tokens
# Add any additional OpenAI-specific parameters
for key, value in kwargs.items():
if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop"]:
completion_params[key] = value
try:
# Generate completion
response = self.client.chat.completions.create(**completion_params)
# Extract content and usage
content = response.choices[0].message.content
usage = self._extract_usage(response)
return ModelResponse(
content=content,
usage=usage,
@@ -132,18 +131,18 @@ class OpenAIModelProvider(ModelProvider):
"model": response.model, # Actual model used (in case of fallbacks)
"id": response.id,
"created": response.created,
-}
+},
)
except Exception as e:
# Log error and re-raise with more context
error_msg = f"OpenAI API error for model {model_name}: {str(e)}"
logging.error(error_msg)
raise RuntimeError(error_msg) from e
def count_tokens(self, text: str, model_name: str) -> int:
"""Count tokens for the given text.
Note: For accurate token counting, we should use tiktoken library.
This is a simplified estimation.
"""
@@ -151,28 +150,28 @@ class OpenAIModelProvider(ModelProvider):
# For now, use rough estimation
# O3 models ~4 chars per token
return len(text) // 4
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.OPENAI
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported."""
return model_name in self.SUPPORTED_MODELS
def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode."""
# Currently no OpenAI models support extended thinking
# This may change with future O3 models
return False
def _extract_usage(self, response) -> Dict[str, int]:
def _extract_usage(self, response) -> dict[str, int]:
"""Extract token usage from OpenAI response."""
usage = {}
if hasattr(response, "usage") and response.usage:
usage["input_tokens"] = response.usage.prompt_tokens
usage["output_tokens"] = response.usage.completion_tokens
usage["total_tokens"] = response.usage.total_tokens
-return usage
\ No newline at end of file
+return usage
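
A similar sketch for the OpenAI provider (evidently providers/openai.py), assuming OPENAI_API_KEY is set. One thing worth flagging: FixedTemperatureConstraint disappears from this file's imports while the "O3 models only support temperature=1.0" comment survives, so the sketch assumes get_capabilities still pins o3/o3-mini to 1.0 via whatever constraint the elided line now builds:

import os

from providers.openai import OpenAIModelProvider

provider = OpenAIModelProvider(api_key=os.environ["OPENAI_API_KEY"])

caps = provider.get_capabilities("o3-mini")
print(caps.temperature_constraint.get_description())

response = provider.generate_content(
    prompt="One sentence on why provider registries cache instances.",
    model_name="o3-mini",
    temperature=1.0,  # anything else should fail validate_parameters
    seed=42,          # forwarded: "seed" is in the provider's allow-list
)
print(response.content, response.usage)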

View File

@@ -1,115 +1,116 @@
"""Model provider registry for managing available providers."""
import os
-from typing import Dict, Optional, Type, List
+from typing import Optional
from .base import ModelProvider, ProviderType
class ModelProviderRegistry:
"""Registry for managing model providers."""
_instance = None
-_providers: Dict[ProviderType, Type[ModelProvider]] = {}
-_initialized_providers: Dict[ProviderType, ModelProvider] = {}
+_providers: dict[ProviderType, type[ModelProvider]] = {}
+_initialized_providers: dict[ProviderType, ModelProvider] = {}
def __new__(cls):
"""Singleton pattern for registry."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
@classmethod
-def register_provider(cls, provider_type: ProviderType, provider_class: Type[ModelProvider]) -> None:
+def register_provider(cls, provider_type: ProviderType, provider_class: type[ModelProvider]) -> None:
"""Register a new provider class.
Args:
provider_type: Type of the provider (e.g., ProviderType.GOOGLE)
provider_class: Class that implements ModelProvider interface
"""
cls._providers[provider_type] = provider_class
@classmethod
def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]:
"""Get an initialized provider instance.
Args:
provider_type: Type of provider to get
force_new: Force creation of new instance instead of using cached
Returns:
Initialized ModelProvider instance or None if not available
"""
# Return cached instance if available and not forcing new
if not force_new and provider_type in cls._initialized_providers:
return cls._initialized_providers[provider_type]
# Check if provider class is registered
if provider_type not in cls._providers:
return None
# Get API key from environment
api_key = cls._get_api_key_for_provider(provider_type)
if not api_key:
return None
# Initialize provider
provider_class = cls._providers[provider_type]
provider = provider_class(api_key=api_key)
# Cache the instance
cls._initialized_providers[provider_type] = provider
return provider
@classmethod
def get_provider_for_model(cls, model_name: str) -> Optional[ModelProvider]:
"""Get provider instance for a specific model name.
Args:
model_name: Name of the model (e.g., "gemini-2.0-flash-exp", "o3-mini")
Returns:
ModelProvider instance that supports this model
"""
# Check each registered provider
-for provider_type, provider_class in cls._providers.items():
+for provider_type, _provider_class in cls._providers.items():
# Get or create provider instance
provider = cls.get_provider(provider_type)
if provider and provider.validate_model_name(model_name):
return provider
return None
@classmethod
-def get_available_providers(cls) -> List[ProviderType]:
+def get_available_providers(cls) -> list[ProviderType]:
"""Get list of registered provider types."""
return list(cls._providers.keys())
@classmethod
-def get_available_models(cls) -> Dict[str, ProviderType]:
+def get_available_models(cls) -> dict[str, ProviderType]:
"""Get mapping of all available models to their providers.
Returns:
Dict mapping model names to provider types
"""
models = {}
for provider_type in cls._providers:
provider = cls.get_provider(provider_type)
if provider:
# This assumes providers have a method to list supported models
# We'll need to add this to the interface
pass
return models
@classmethod
def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]:
"""Get API key for a provider from environment variables.
Args:
provider_type: Provider type to get API key for
Returns:
API key string or None if not found
"""
@@ -117,20 +118,20 @@ class ModelProviderRegistry:
ProviderType.GOOGLE: "GEMINI_API_KEY",
ProviderType.OPENAI: "OPENAI_API_KEY",
}
env_var = key_mapping.get(provider_type)
if not env_var:
return None
return os.getenv(env_var)
@classmethod
def clear_cache(cls) -> None:
"""Clear cached provider instances."""
cls._initialized_providers.clear()
@classmethod
def unregister_provider(cls, provider_type: ProviderType) -> None:
"""Unregister a provider (mainly for testing)."""
cls._providers.pop(provider_type, None)
-cls._initialized_providers.pop(provider_type, None)
\ No newline at end of file
+cls._initialized_providers.pop(provider_type, None)
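
Finally, a wiring sketch for the registry (evidently providers/registry.py): classes are registered once at startup, instances are created lazily and cached, and API keys come from the environment per the GEMINI_API_KEY / OPENAI_API_KEY mapping above:

from providers.base import ProviderType
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider
from providers.registry import ModelProviderRegistry

ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

# Returns the first registered provider whose validate_model_name() accepts
# the name ("flash" is a Gemini shorthand), or None when no provider matches
# or the matching provider's API key is unset.
provider = ModelProviderRegistry.get_provider_for_model("flash")
if provider is not None:
    print(provider.get_provider_type())  # ProviderType.GOOGLE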