Rebranding, refactoring, renaming, cleanup, updated docs
This commit is contained in:
@@ -2,34 +2,35 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class ProviderType(Enum):
|
||||
"""Supported model provider types."""
|
||||
|
||||
GOOGLE = "google"
|
||||
OPENAI = "openai"
|
||||
|
||||
|
||||
class TemperatureConstraint(ABC):
|
||||
"""Abstract base class for temperature constraints."""
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def validate(self, temperature: float) -> bool:
|
||||
"""Check if temperature is valid."""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_corrected_value(self, temperature: float) -> float:
|
||||
"""Get nearest valid temperature."""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_description(self) -> str:
|
||||
"""Get human-readable description of constraint."""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_default(self) -> float:
|
||||
"""Get model's default temperature."""
|
||||
@@ -38,60 +39,60 @@ class TemperatureConstraint(ABC):
|
||||
|
||||
class FixedTemperatureConstraint(TemperatureConstraint):
|
||||
"""For models that only support one temperature value (e.g., O3)."""
|
||||
|
||||
|
||||
def __init__(self, value: float):
|
||||
self.value = value
|
||||
|
||||
|
||||
def validate(self, temperature: float) -> bool:
|
||||
return abs(temperature - self.value) < 1e-6 # Handle floating point precision
|
||||
|
||||
|
||||
def get_corrected_value(self, temperature: float) -> float:
|
||||
return self.value
|
||||
|
||||
|
||||
def get_description(self) -> str:
|
||||
return f"Only supports temperature={self.value}"
|
||||
|
||||
|
||||
def get_default(self) -> float:
|
||||
return self.value
|
||||
|
||||
|
||||
class RangeTemperatureConstraint(TemperatureConstraint):
|
||||
"""For models supporting continuous temperature ranges."""
|
||||
|
||||
|
||||
def __init__(self, min_temp: float, max_temp: float, default: float = None):
|
||||
self.min_temp = min_temp
|
||||
self.max_temp = max_temp
|
||||
self.default_temp = default or (min_temp + max_temp) / 2
|
||||
|
||||
|
||||
def validate(self, temperature: float) -> bool:
|
||||
return self.min_temp <= temperature <= self.max_temp
|
||||
|
||||
|
||||
def get_corrected_value(self, temperature: float) -> float:
|
||||
return max(self.min_temp, min(self.max_temp, temperature))
|
||||
|
||||
|
||||
def get_description(self) -> str:
|
||||
return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"
|
||||
|
||||
|
||||
def get_default(self) -> float:
|
||||
return self.default_temp
|
||||
|
||||
|
||||
class DiscreteTemperatureConstraint(TemperatureConstraint):
|
||||
"""For models supporting only specific temperature values."""
|
||||
|
||||
def __init__(self, allowed_values: List[float], default: float = None):
|
||||
|
||||
def __init__(self, allowed_values: list[float], default: float = None):
|
||||
self.allowed_values = sorted(allowed_values)
|
||||
self.default_temp = default or allowed_values[len(allowed_values)//2]
|
||||
|
||||
self.default_temp = default or allowed_values[len(allowed_values) // 2]
|
||||
|
||||
def validate(self, temperature: float) -> bool:
|
||||
return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)
|
||||
|
||||
|
||||
def get_corrected_value(self, temperature: float) -> float:
|
||||
return min(self.allowed_values, key=lambda x: abs(x - temperature))
|
||||
|
||||
|
||||
def get_description(self) -> str:
|
||||
return f"Supports temperatures: {self.allowed_values}"
|
||||
|
||||
|
||||
def get_default(self) -> float:
|
||||
return self.default_temp
|
||||
|
||||
@@ -99,6 +100,7 @@ class DiscreteTemperatureConstraint(TemperatureConstraint):
|
||||
@dataclass
|
||||
class ModelCapabilities:
|
||||
"""Capabilities and constraints for a specific model."""
|
||||
|
||||
provider: ProviderType
|
||||
model_name: str
|
||||
friendly_name: str # Human-friendly name like "Gemini" or "OpenAI"
|
||||
@@ -107,15 +109,15 @@ class ModelCapabilities:
|
||||
supports_system_prompts: bool = True
|
||||
supports_streaming: bool = True
|
||||
supports_function_calling: bool = False
|
||||
|
||||
|
||||
# Temperature constraint object - preferred way to define temperature limits
|
||||
temperature_constraint: TemperatureConstraint = field(
|
||||
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.7)
|
||||
)
|
||||
|
||||
|
||||
# Backward compatibility property for existing code
|
||||
@property
|
||||
def temperature_range(self) -> Tuple[float, float]:
|
||||
def temperature_range(self) -> tuple[float, float]:
|
||||
"""Backward compatibility for existing code that uses temperature_range."""
|
||||
if isinstance(self.temperature_constraint, RangeTemperatureConstraint):
|
||||
return (self.temperature_constraint.min_temp, self.temperature_constraint.max_temp)
|
||||
@@ -130,13 +132,14 @@ class ModelCapabilities:
|
||||
@dataclass
|
||||
class ModelResponse:
|
||||
"""Response from a model provider."""
|
||||
|
||||
content: str
|
||||
usage: Dict[str, int] = field(default_factory=dict) # input_tokens, output_tokens, total_tokens
|
||||
usage: dict[str, int] = field(default_factory=dict) # input_tokens, output_tokens, total_tokens
|
||||
model_name: str = ""
|
||||
friendly_name: str = "" # Human-friendly name like "Gemini" or "OpenAI"
|
||||
provider: ProviderType = ProviderType.GOOGLE
|
||||
metadata: Dict[str, Any] = field(default_factory=dict) # Provider-specific metadata
|
||||
|
||||
metadata: dict[str, Any] = field(default_factory=dict) # Provider-specific metadata
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
"""Get total tokens used."""
|
||||
@@ -145,17 +148,17 @@ class ModelResponse:
|
||||
|
||||
class ModelProvider(ABC):
|
||||
"""Abstract base class for model providers."""
|
||||
|
||||
|
||||
def __init__(self, api_key: str, **kwargs):
|
||||
"""Initialize the provider with API key and optional configuration."""
|
||||
self.api_key = api_key
|
||||
self.config = kwargs
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||
"""Get capabilities for a specific model."""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def generate_content(
|
||||
self,
|
||||
@@ -164,10 +167,10 @@ class ModelProvider(ABC):
|
||||
system_prompt: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> ModelResponse:
|
||||
"""Generate content using the model.
|
||||
|
||||
|
||||
Args:
|
||||
prompt: User prompt to send to the model
|
||||
model_name: Name of the model to use
|
||||
@@ -175,49 +178,43 @@ class ModelProvider(ABC):
|
||||
temperature: Sampling temperature (0-2)
|
||||
max_output_tokens: Maximum tokens to generate
|
||||
**kwargs: Provider-specific parameters
|
||||
|
||||
|
||||
Returns:
|
||||
ModelResponse with generated content and metadata
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def count_tokens(self, text: str, model_name: str) -> int:
|
||||
"""Count tokens for the given text using the specified model's tokenizer."""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Get the provider type."""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
"""Validate if the model name is supported by this provider."""
|
||||
pass
|
||||
|
||||
def validate_parameters(
|
||||
self,
|
||||
model_name: str,
|
||||
temperature: float,
|
||||
**kwargs
|
||||
) -> None:
|
||||
|
||||
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
|
||||
"""Validate model parameters against capabilities.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If parameters are invalid
|
||||
"""
|
||||
capabilities = self.get_capabilities(model_name)
|
||||
|
||||
|
||||
# Validate temperature
|
||||
min_temp, max_temp = capabilities.temperature_range
|
||||
if not min_temp <= temperature <= max_temp:
|
||||
raise ValueError(
|
||||
f"Temperature {temperature} out of range [{min_temp}, {max_temp}] "
|
||||
f"for model {model_name}"
|
||||
f"Temperature {temperature} out of range [{min_temp}, {max_temp}] " f"for model {model_name}"
|
||||
)
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||
"""Check if the model supports extended thinking mode."""
|
||||
pass
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user