refactor: rename `SUPPORTED_MODELS` to `MODEL_CAPABILITIES` to reflect the underlying type

docs: update to reflect the new `providers/shared` modules
Fahad
2025-10-02 09:07:40 +04:00
parent 2b10adcaf2
commit 1dc25f6c3d
18 changed files with 129 additions and 131 deletions


@@ -28,7 +28,7 @@ Each provider:
### 1. Add Provider Type
Add your provider to `ProviderType` enum in `providers/base.py`:
Add your provider to the `ProviderType` enum in `providers/shared/provider_type.py`:
```python
class ProviderType(Enum):
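    # Sketch only -- the existing members are outside this hunk and their string
    # values below are assumptions. Keep the real members and append yours:
    GOOGLE = "google"      # assumed existing member
    OPENAI = "openai"      # assumed existing member
    EXAMPLE = "example"    # add your provider here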
@@ -48,15 +48,23 @@ Create `providers/example.py`:
import logging
from typing import Optional
from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, RangeTemperatureConstraint
from .base import ModelProvider
from .shared import (
ModelCapabilities,
ModelResponse,
ProviderType,
RangeTemperatureConstraint,
)
logger = logging.getLogger(__name__)
class ExampleModelProvider(ModelProvider):
"""Example model provider implementation."""
# Define models using ModelCapabilities objects (like Gemini provider)
SUPPORTED_MODELS = {
MODEL_CAPABILITIES = {
"example-large": ModelCapabilities(
provider=ProviderType.EXAMPLE,
model_name="example-large",
@@ -87,7 +95,7 @@ class ExampleModelProvider(ModelProvider):
def get_capabilities(self, model_name: str) -> ModelCapabilities:
resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.SUPPORTED_MODELS:
if resolved_name not in self.MODEL_CAPABILITIES:
raise ValueError(f"Unsupported model: {model_name}")
# Apply restrictions if needed
@@ -96,7 +104,7 @@ class ExampleModelProvider(ModelProvider):
if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
raise ValueError(f"Model '{model_name}' is not allowed.")
return self.SUPPORTED_MODELS[resolved_name]
return self.MODEL_CAPABILITIES[resolved_name]
def generate_content(self, prompt: str, model_name: str, system_prompt: Optional[str] = None,
temperature: float = 0.7, max_output_tokens: Optional[int] = None, **kwargs) -> ModelResponse:
@@ -121,7 +129,7 @@ class ExampleModelProvider(ModelProvider):
def validate_model_name(self, model_name: str) -> bool:
resolved_name = self._resolve_model_name(model_name)
return resolved_name in self.SUPPORTED_MODELS
return resolved_name in self.MODEL_CAPABILITIES
def supports_thinking_mode(self, model_name: str) -> bool:
capabilities = self.get_capabilities(model_name)
@@ -136,8 +144,15 @@ For OpenAI-compatible APIs:
"""Example OpenAI-compatible provider."""
from typing import Optional
from .base import ModelCapabilities, ModelResponse, ProviderType, RangeTemperatureConstraint
from .openai_compatible import OpenAICompatibleProvider
from .shared import (
ModelCapabilities,
ModelResponse,
ProviderType,
RangeTemperatureConstraint,
)
class ExampleProvider(OpenAICompatibleProvider):
"""Example OpenAI-compatible provider."""
@@ -145,7 +160,7 @@ class ExampleProvider(OpenAICompatibleProvider):
FRIENDLY_NAME = "Example"
# Define models using ModelCapabilities (consistent with other providers)
SUPPORTED_MODELS = {
MODEL_CAPABILITIES = {
"example-model-large": ModelCapabilities(
provider=ProviderType.EXAMPLE,
model_name="example-model-large",
@@ -163,16 +178,16 @@ class ExampleProvider(OpenAICompatibleProvider):
def get_capabilities(self, model_name: str) -> ModelCapabilities:
resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.SUPPORTED_MODELS:
if resolved_name not in self.MODEL_CAPABILITIES:
raise ValueError(f"Unsupported model: {model_name}")
return self.SUPPORTED_MODELS[resolved_name]
return self.MODEL_CAPABILITIES[resolved_name]
def get_provider_type(self) -> ProviderType:
return ProviderType.EXAMPLE
def validate_model_name(self, model_name: str) -> bool:
resolved_name = self._resolve_model_name(model_name)
return resolved_name in self.SUPPORTED_MODELS
return resolved_name in self.MODEL_CAPABILITIES
def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
# IMPORTANT: Resolve aliases before API call
@@ -185,12 +200,8 @@ class ExampleProvider(OpenAICompatibleProvider):
Add environment variable mapping in `providers/registry.py`:
```python
# In _get_api_key_for_provider method:
key_mapping = {
ProviderType.GOOGLE: "GEMINI_API_KEY",
ProviderType.OPENAI: "OPENAI_API_KEY",
ProviderType.EXAMPLE: "EXAMPLE_API_KEY", # Add this
}
# In _get_api_key_for_provider (providers/registry.py), add:
ProviderType.EXAMPLE: "EXAMPLE_API_KEY",
```
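A minimal sketch of how that mapping is typically consumed inside `_get_api_key_for_provider` (the `os.getenv` lookup and the standalone function signature are assumptions; the real method lives on the registry class):

```python
import os
from typing import Optional

from providers.shared import ProviderType  # enum location per section 1 above


def _get_api_key_for_provider(provider_type: ProviderType) -> Optional[str]:
    # Mirrors the mapping shown above; extend it with your provider's entry.
    key_mapping = {
        ProviderType.GOOGLE: "GEMINI_API_KEY",
        ProviderType.OPENAI: "OPENAI_API_KEY",
        ProviderType.EXAMPLE: "EXAMPLE_API_KEY",
    }
    env_var = key_mapping.get(provider_type)
    return os.getenv(env_var) if env_var else None
```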
Add to `server.py`:
@@ -209,16 +220,7 @@ if example_key:
logger.info("Example API key found - Example models available")
```
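Only the `if example_key:` guard and the log line are visible in the hunk above; a fuller sketch of the `server.py` wiring might look like the following (the `register_provider` call name and signature are assumptions -- mirror however the existing providers are registered):

```python
import logging
import os

from providers.example import ExampleModelProvider
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType

logger = logging.getLogger(__name__)

example_key = os.getenv("EXAMPLE_API_KEY")
if example_key:
    # Hypothetical registration call -- follow the pattern used for Gemini/OpenAI.
    ModelProviderRegistry.register_provider(ProviderType.EXAMPLE, ExampleModelProvider)
    logger.info("Example API key found - Example models available")
```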
3. **Add to provider priority** (in `providers/registry.py`):
```python
PROVIDER_PRIORITY_ORDER = [
ProviderType.GOOGLE,
ProviderType.OPENAI,
ProviderType.EXAMPLE, # Add your provider here
ProviderType.CUSTOM, # Local models
ProviderType.OPENROUTER, # Catch-all (keep last)
]
```
3. **Add to provider priority** (edit `ModelProviderRegistry.PROVIDER_PRIORITY_ORDER` in `providers/registry.py`): insert your provider in the list at the appropriate point in the cascade of native → custom → catch-all providers.
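Concretely, the edited list might look like the sketch below (entries mirror the snippet removed above; keep the catch-all provider last):

```python
from providers.shared import ProviderType


class ModelProviderRegistry:  # excerpt of providers/registry.py, illustrative only
    PROVIDER_PRIORITY_ORDER = [
        ProviderType.GOOGLE,      # native APIs first
        ProviderType.OPENAI,
        ProviderType.EXAMPLE,     # add your provider here
        ProviderType.CUSTOM,      # local models
        ProviderType.OPENROUTER,  # catch-all (keep last)
    ]
```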
### 4. Environment Configuration
@@ -265,7 +267,7 @@ Your `validate_model_name()` should **only** return `True` for models you explic
```python
def validate_model_name(self, model_name: str) -> bool:
resolved_name = self._resolve_model_name(model_name)
return resolved_name in self.SUPPORTED_MODELS # Be specific!
return resolved_name in self.MODEL_CAPABILITIES # Be specific!
```
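For contrast, a permissive implementation (hypothetical) lets this provider claim model names it cannot actually serve once the registry routes requests by priority -- avoid this:

```python
def validate_model_name(self, model_name: str) -> bool:
    # Anti-pattern: returning True unconditionally makes the registry treat
    # every model name as belonging to this provider.
    return True
```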
### Model Aliases
@@ -296,7 +298,7 @@ Without this, API calls with aliases like `"large"` will fail because your API d
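A minimal sketch of the alias-resolution step referenced above (the `"large"` → `"example-model-large"` mapping is illustrative, and `_call_api` is a hypothetical placeholder for your request code; the key point is to call `_resolve_model_name()` before building the API request):

```python
def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
    # Resolve the alias first, e.g. "large" -> "example-model-large",
    # and send only the resolved name to the API.
    resolved_name = self._resolve_model_name(model_name)
    return self._call_api(prompt=prompt, model=resolved_name, **kwargs)  # hypothetical helper
```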
## Quick Checklist
- [ ] Added to `ProviderType` enum in `providers/base.py`
- [ ] Added to `ProviderType` enum in `providers/shared/provider_type.py`
- [ ] Created provider class with all required methods
- [ ] Added API key mapping in `providers/registry.py`
- [ ] Added to provider priority order in `registry.py`
@@ -307,8 +309,6 @@ Without this, API calls with aliases like `"large"` will fail because your API d
## Examples
See existing implementations:
- **Full provider**: `providers/gemini.py`
- **OpenAI-compatible**: `providers/custom.py`
- **Base classes**: `providers/base.py`
The modern approach uses `ModelCapabilities` objects directly in `SUPPORTED_MODELS`, making the implementation much cleaner and more consistent.