refactor: cleanup provider base class; cleanup shared responsibilities; cleanup public contract

docs: document provider base class
refactor: cleanup custom provider, it should only deal with `is_custom` model configurations
fix: make sure openrouter provider does not load `is_custom` models
fix: listmodels tool cleanup
This commit is contained in:
Fahad
2025-10-02 12:59:45 +04:00
parent 6ec2033f34
commit 693b84db2b
15 changed files with 509 additions and 751 deletions

View File

@@ -5,7 +5,6 @@ import ipaddress
import logging
import os
import time
from abc import abstractmethod
from typing import Optional
from urllib.parse import urlparse
@@ -61,6 +60,33 @@ class OpenAICompatibleProvider(ModelProvider):
"This may be insecure. Consider setting an API key for authentication."
)
def _ensure_model_allowed(
self,
capabilities: ModelCapabilities,
canonical_name: str,
requested_name: str,
) -> None:
"""Respect provider-specific allowlists before default restriction checks."""
super()._ensure_model_allowed(capabilities, canonical_name, requested_name)
if self.allowed_models is not None:
requested = requested_name.lower()
canonical = canonical_name.lower()
if requested not in self.allowed_models and canonical not in self.allowed_models:
raise ValueError(
f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
)
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
"""Return statically declared capabilities for OpenAI-compatible providers."""
model_map = getattr(self, "MODEL_CAPABILITIES", None)
if isinstance(model_map, dict):
return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)}
return {}
def _parse_allowed_models(self) -> Optional[set[str]]:
"""Parse allowed models from environment variable.
@@ -686,30 +712,6 @@ class OpenAICompatibleProvider(ModelProvider):
return super().count_tokens(text, model_name)
@abstractmethod
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific model.
Must be implemented by subclasses.
"""
pass
@abstractmethod
def get_provider_type(self) -> ProviderType:
"""Get the provider type.
Must be implemented by subclasses.
"""
pass
@abstractmethod
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported.
Must be implemented by subclasses.
"""
pass
def _is_error_retryable(self, error: Exception) -> bool:
"""Determine if an error should be retried based on structured error codes.