refactor: cleanup provider base class; cleanup shared responsibilities; cleanup public contract
docs: document provider base class refactor: cleanup custom provider, it should only deal with `is_custom` model configurations fix: make sure openrouter provider does not load `is_custom` models fix: listmodels tool cleanup
This commit is contained in:
@@ -261,68 +261,10 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
|
||||
logger.info(f"Initialized DIAL provider with host: {dial_host} and api-version: {self.api_version}")
|
||||
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||
"""Get capabilities for a specific model.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model (can be shorthand)
|
||||
|
||||
Returns:
|
||||
ModelCapabilities object
|
||||
|
||||
Raises:
|
||||
ValueError: If model is not supported or not allowed
|
||||
"""
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
if resolved_name not in self.MODEL_CAPABILITIES:
|
||||
raise ValueError(f"Unsupported DIAL model: {model_name}")
|
||||
|
||||
# Check restrictions
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
|
||||
raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")
|
||||
|
||||
# Return the ModelCapabilities object directly from MODEL_CAPABILITIES
|
||||
return self.MODEL_CAPABILITIES[resolved_name]
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Get the provider type."""
|
||||
return ProviderType.DIAL
|
||||
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
"""Validate if the model name is supported.
|
||||
|
||||
Args:
|
||||
model_name: Model name to validate
|
||||
|
||||
Returns:
|
||||
True if model is supported and allowed, False otherwise
|
||||
"""
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
if resolved_name not in self.MODEL_CAPABILITIES:
|
||||
return False
|
||||
|
||||
# Check against base class allowed_models if configured
|
||||
if self.allowed_models is not None:
|
||||
# Check both original and resolved names (case-insensitive)
|
||||
if model_name.lower() not in self.allowed_models and resolved_name.lower() not in self.allowed_models:
|
||||
logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' not in allowed_models list")
|
||||
return False
|
||||
|
||||
# Also check restrictions via ModelRestrictionService
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
|
||||
logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' blocked by restrictions")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _get_deployment_client(self, deployment: str):
|
||||
"""Get or create a cached client for a specific deployment.
|
||||
|
||||
@@ -504,7 +446,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
|
||||
)
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Clean up HTTP clients when provider is closed."""
|
||||
logger.info("Closing DIAL provider HTTP clients...")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user