refactor: cleanup provider base class; cleanup shared responsibilities; cleanup public contract
docs: document provider base class
refactor: cleanup custom provider, it should only deal with `is_custom` model configurations
fix: make sure openrouter provider does not load `is_custom` models
fix: listmodels tool cleanup
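The commit message's `is_custom` split is easiest to see in isolation. Below is a minimal, hypothetical sketch (not code from this commit; `ModelEntry` and `split_registry` are illustrative names only) of partitioning registry entries so the custom provider keeps only `is_custom` configurations while the OpenRouter provider skips them:

# Hypothetical illustration of the `is_custom` split described in the commit
# message; the dataclass and helper below are not part of this repository.
from dataclasses import dataclass


@dataclass
class ModelEntry:
    model_name: str
    is_custom: bool = False


def split_registry(entries: list[ModelEntry]) -> tuple[list[ModelEntry], list[ModelEntry]]:
    """Return (custom_provider_models, openrouter_models) partitions."""
    custom_provider_models = [e for e in entries if e.is_custom]
    openrouter_models = [e for e in entries if not e.is_custom]
    return custom_provider_models, openrouter_models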
@@ -174,106 +174,61 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
        kwargs.setdefault("base_url", "https://api.openai.com/v1")
        super().__init__(api_key, **kwargs)

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific OpenAI model."""
        # First check if it's a key in MODEL_CAPABILITIES
        if model_name in self.MODEL_CAPABILITIES:
            self._check_model_restrictions(model_name, model_name)
            return self.MODEL_CAPABILITIES[model_name]
    # ------------------------------------------------------------------
    # Capability surface
    # ------------------------------------------------------------------

        # Try resolving as alias
        resolved_name = self._resolve_model_name(model_name)
    def _lookup_capabilities(
        self,
        canonical_name: str,
        requested_name: Optional[str] = None,
    ) -> Optional[ModelCapabilities]:
        """Look up OpenAI capabilities from built-ins or the custom registry."""

        # Check if resolved name is a key
        if resolved_name in self.MODEL_CAPABILITIES:
            self._check_model_restrictions(resolved_name, model_name)
            return self.MODEL_CAPABILITIES[resolved_name]
        builtin = super()._lookup_capabilities(canonical_name, requested_name)
        if builtin is not None:
            return builtin

        # Finally check if resolved name matches any API model name
        for key, capabilities in self.MODEL_CAPABILITIES.items():
            if resolved_name == capabilities.model_name:
                self._check_model_restrictions(key, model_name)
                return capabilities

        # Check custom models registry for user-configured OpenAI models
        try:
            from .openrouter_registry import OpenRouterModelRegistry

            registry = OpenRouterModelRegistry()
            config = registry.get_model_config(resolved_name)
            config = registry.get_model_config(canonical_name)

            if config and config.provider == ProviderType.OPENAI:
                self._check_model_restrictions(config.model_name, model_name)

                # Update provider type to ensure consistency
                config.provider = ProviderType.OPENAI
                return config

        except Exception as e:
            # Log but don't fail - registry might not be available
            logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}")
        except Exception as exc:  # pragma: no cover - registry failures are non-critical
            logger.debug(f"Could not resolve custom OpenAI model '{canonical_name}': {exc}")

        return None

    def _finalise_capabilities(
        self,
        capabilities: ModelCapabilities,
        canonical_name: str,
        requested_name: str,
    ) -> ModelCapabilities:
        """Ensure registry-sourced models report the correct provider type."""

        if capabilities.provider != ProviderType.OPENAI:
            capabilities.provider = ProviderType.OPENAI
        return capabilities

    def _raise_unsupported_model(self, model_name: str) -> None:
        raise ValueError(f"Unsupported OpenAI model: {model_name}")

    def _check_model_restrictions(self, provider_model_name: str, user_model_name: str) -> None:
        """Check if a model is allowed by restriction policy.

        Args:
            provider_model_name: The model name used by the provider
            user_model_name: The model name requested by the user

        Raises:
            ValueError: If the model is not allowed by restriction policy
        """
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.OPENAI, provider_model_name, user_model_name):
            raise ValueError(f"OpenAI model '{user_model_name}' is not allowed by restriction policy.")
    # ------------------------------------------------------------------
    # Provider identity
    # ------------------------------------------------------------------

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.OPENAI

    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is supported and allowed."""
        resolved_name = self._resolve_model_name(model_name)

        # First, determine which model name to check against restrictions.
        model_to_check = None
        is_custom_model = False

        if resolved_name in self.MODEL_CAPABILITIES:
            model_to_check = resolved_name
        else:
            # If not a built-in model, check the custom models registry.
            try:
                from .openrouter_registry import OpenRouterModelRegistry

                registry = OpenRouterModelRegistry()
                config = registry.get_model_config(resolved_name)

                if config and config.provider == ProviderType.OPENAI:
                    model_to_check = config.model_name
                    is_custom_model = True
            except Exception as e:
                # Log but don't fail - registry might not be available.
                logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}")

        # If no model was found (neither built-in nor custom), it's invalid.
        if not model_to_check:
            return False

        # Now, perform the restriction check once.
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.OPENAI, model_to_check, model_name):
            model_type = "custom " if is_custom_model else ""
            logger.debug(f"OpenAI {model_type}model '{model_name}' -> '{resolved_name}' blocked by restrictions")
            return False

        return True
    # ------------------------------------------------------------------
    # Request execution
    # ------------------------------------------------------------------

    def generate_content(
        self,
@@ -298,6 +253,10 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
            **kwargs,
        )

    # ------------------------------------------------------------------
    # Provider preferences
    # ------------------------------------------------------------------

    def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
        """Get OpenAI's preferred model for a given category from allowed models.