refactor: renaming to reflect underlying type
docs: updated to reflect new modules
This commit is contained in:
@@ -28,7 +28,7 @@ Each provider:
|
||||
|
||||
### 1. Add Provider Type
|
||||
|
||||
Add your provider to `ProviderType` enum in `providers/base.py`:
|
||||
Add your provider to the `ProviderType` enum in `providers/shared/provider_type.py`:
|
||||
|
||||
```python
|
||||
class ProviderType(Enum):
|
||||
@@ -48,15 +48,23 @@ Create `providers/example.py`:
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, RangeTemperatureConstraint
|
||||
|
||||
from .base import ModelProvider
|
||||
from .shared import (
|
||||
ModelCapabilities,
|
||||
ModelResponse,
|
||||
ProviderType,
|
||||
RangeTemperatureConstraint,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExampleModelProvider(ModelProvider):
|
||||
"""Example model provider implementation."""
|
||||
|
||||
# Define models using ModelCapabilities objects (like Gemini provider)
|
||||
SUPPORTED_MODELS = {
|
||||
MODEL_CAPABILITIES = {
|
||||
"example-large": ModelCapabilities(
|
||||
provider=ProviderType.EXAMPLE,
|
||||
model_name="example-large",
|
||||
@@ -87,7 +95,7 @@ class ExampleModelProvider(ModelProvider):
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
if resolved_name not in self.SUPPORTED_MODELS:
|
||||
if resolved_name not in self.MODEL_CAPABILITIES:
|
||||
raise ValueError(f"Unsupported model: {model_name}")
|
||||
|
||||
# Apply restrictions if needed
|
||||
@@ -96,7 +104,7 @@ class ExampleModelProvider(ModelProvider):
|
||||
if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
|
||||
raise ValueError(f"Model '{model_name}' is not allowed.")
|
||||
|
||||
return self.SUPPORTED_MODELS[resolved_name]
|
||||
return self.MODEL_CAPABILITIES[resolved_name]
|
||||
|
||||
def generate_content(self, prompt: str, model_name: str, system_prompt: Optional[str] = None,
|
||||
temperature: float = 0.7, max_output_tokens: Optional[int] = None, **kwargs) -> ModelResponse:
|
||||
@@ -121,7 +129,7 @@ class ExampleModelProvider(ModelProvider):
|
||||
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
return resolved_name in self.SUPPORTED_MODELS
|
||||
return resolved_name in self.MODEL_CAPABILITIES
|
||||
|
||||
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||
capabilities = self.get_capabilities(model_name)
|
||||
@@ -136,8 +144,15 @@ For OpenAI-compatible APIs:
|
||||
"""Example OpenAI-compatible provider."""
|
||||
|
||||
from typing import Optional
|
||||
from .base import ModelCapabilities, ModelResponse, ProviderType, RangeTemperatureConstraint
|
||||
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
from .shared import (
|
||||
ModelCapabilities,
|
||||
ModelResponse,
|
||||
ProviderType,
|
||||
RangeTemperatureConstraint,
|
||||
)
|
||||
|
||||
|
||||
class ExampleProvider(OpenAICompatibleProvider):
|
||||
"""Example OpenAI-compatible provider."""
|
||||
@@ -145,7 +160,7 @@ class ExampleProvider(OpenAICompatibleProvider):
|
||||
FRIENDLY_NAME = "Example"
|
||||
|
||||
# Define models using ModelCapabilities (consistent with other providers)
|
||||
SUPPORTED_MODELS = {
|
||||
MODEL_CAPABILITIES = {
|
||||
"example-model-large": ModelCapabilities(
|
||||
provider=ProviderType.EXAMPLE,
|
||||
model_name="example-model-large",
|
||||
@@ -163,16 +178,16 @@ class ExampleProvider(OpenAICompatibleProvider):
|
||||
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
if resolved_name not in self.SUPPORTED_MODELS:
|
||||
if resolved_name not in self.MODEL_CAPABILITIES:
|
||||
raise ValueError(f"Unsupported model: {model_name}")
|
||||
return self.SUPPORTED_MODELS[resolved_name]
|
||||
return self.MODEL_CAPABILITIES[resolved_name]
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
return ProviderType.EXAMPLE
|
||||
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
return resolved_name in self.SUPPORTED_MODELS
|
||||
return resolved_name in self.MODEL_CAPABILITIES
|
||||
|
||||
def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
|
||||
# IMPORTANT: Resolve aliases before API call
|
||||
@@ -185,12 +200,8 @@ class ExampleProvider(OpenAICompatibleProvider):
|
||||
Add environment variable mapping in `providers/registry.py`:
|
||||
|
||||
```python
|
||||
# In _get_api_key_for_provider method:
|
||||
key_mapping = {
|
||||
ProviderType.GOOGLE: "GEMINI_API_KEY",
|
||||
ProviderType.OPENAI: "OPENAI_API_KEY",
|
||||
ProviderType.EXAMPLE: "EXAMPLE_API_KEY", # Add this
|
||||
}
|
||||
# In _get_api_key_for_provider (providers/registry.py), add:
|
||||
ProviderType.EXAMPLE: "EXAMPLE_API_KEY",
|
||||
```
|
||||
|
||||
Add to `server.py`:
|
||||
@@ -209,16 +220,7 @@ if example_key:
|
||||
logger.info("Example API key found - Example models available")
|
||||
```
|
||||
|
||||
3. **Add to provider priority** (in `providers/registry.py`):
|
||||
```python
|
||||
PROVIDER_PRIORITY_ORDER = [
|
||||
ProviderType.GOOGLE,
|
||||
ProviderType.OPENAI,
|
||||
ProviderType.EXAMPLE, # Add your provider here
|
||||
ProviderType.CUSTOM, # Local models
|
||||
ProviderType.OPENROUTER, # Catch-all (keep last)
|
||||
]
|
||||
```
|
||||
3. **Add to provider priority** (edit `ModelProviderRegistry.PROVIDER_PRIORITY_ORDER` in `providers/registry.py`): insert your provider in the list at the appropriate point in the cascade of native → custom → catch-all providers.
|
||||
|
||||
### 4. Environment Configuration
|
||||
|
||||
@@ -265,7 +267,7 @@ Your `validate_model_name()` should **only** return `True` for models you explic
|
||||
```python
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
return resolved_name in self.SUPPORTED_MODELS # Be specific!
|
||||
return resolved_name in self.MODEL_CAPABILITIES # Be specific!
|
||||
```
|
||||
|
||||
### Model Aliases
|
||||
@@ -296,7 +298,7 @@ Without this, API calls with aliases like `"large"` will fail because your API d
|
||||
|
||||
## Quick Checklist
|
||||
|
||||
- [ ] Added to `ProviderType` enum in `providers/base.py`
|
||||
- [ ] Added to `ProviderType` enum in `providers/shared/provider_type.py`
|
||||
- [ ] Created provider class with all required methods
|
||||
- [ ] Added API key mapping in `providers/registry.py`
|
||||
- [ ] Added to provider priority order in `registry.py`
|
||||
@@ -307,8 +309,6 @@ Without this, API calls with aliases like `"large"` will fail because your API d
|
||||
## Examples
|
||||
|
||||
See existing implementations:
|
||||
- **Full provider**: `providers/gemini.py`
|
||||
- **Full provider**: `providers/gemini.py`
|
||||
- **OpenAI-compatible**: `providers/custom.py`
|
||||
- **Base classes**: `providers/base.py`
|
||||
|
||||
The modern approach uses `ModelCapabilities` objects directly in `SUPPORTED_MODELS`, making the implementation much cleaner and more consistent.
|
||||
Reference in New Issue
Block a user