refactor: cleanup provider base class; cleanup shared responsibilities; cleanup public contract
docs: document provider base class refactor: cleanup custom provider, it should only deal with `is_custom` model configurations fix: make sure openrouter provider does not load `is_custom` models fix: listmodels tool cleanup
This commit is contained in:
@@ -5,7 +5,6 @@ import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -61,6 +60,33 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
"This may be insecure. Consider setting an API key for authentication."
|
||||
)
|
||||
|
||||
def _ensure_model_allowed(
|
||||
self,
|
||||
capabilities: ModelCapabilities,
|
||||
canonical_name: str,
|
||||
requested_name: str,
|
||||
) -> None:
|
||||
"""Respect provider-specific allowlists before default restriction checks."""
|
||||
|
||||
super()._ensure_model_allowed(capabilities, canonical_name, requested_name)
|
||||
|
||||
if self.allowed_models is not None:
|
||||
requested = requested_name.lower()
|
||||
canonical = canonical_name.lower()
|
||||
|
||||
if requested not in self.allowed_models and canonical not in self.allowed_models:
|
||||
raise ValueError(
|
||||
f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
|
||||
)
|
||||
|
||||
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||
"""Return statically declared capabilities for OpenAI-compatible providers."""
|
||||
|
||||
model_map = getattr(self, "MODEL_CAPABILITIES", None)
|
||||
if isinstance(model_map, dict):
|
||||
return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)}
|
||||
return {}
|
||||
|
||||
def _parse_allowed_models(self) -> Optional[set[str]]:
|
||||
"""Parse allowed models from environment variable.
|
||||
|
||||
@@ -686,30 +712,6 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
|
||||
return super().count_tokens(text, model_name)
|
||||
|
||||
@abstractmethod
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||
"""Get capabilities for a specific model.
|
||||
|
||||
Must be implemented by subclasses.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Get the provider type.
|
||||
|
||||
Must be implemented by subclasses.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
"""Validate if the model name is supported.
|
||||
|
||||
Must be implemented by subclasses.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _is_error_retryable(self, error: Exception) -> bool:
|
||||
"""Determine if an error should be retried based on structured error codes.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user