refactor: moved temperature method from base provider to model capabilities
refactor: model listing cleanup, moved logic to model_capabilities.py docs: added AGENTS.md for onboarding Codex
This commit is contained in:
@@ -18,13 +18,26 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelProvider(ABC):
|
||||
"""Defines the contract implemented by every model provider backend.
|
||||
"""Abstract base class for all model backends in the MCP server.
|
||||
|
||||
Subclasses adapt third-party SDKs into the MCP server by exposing
|
||||
capability metadata, request execution, and token counting through a
|
||||
consistent interface. Shared helper methods (temperature validation,
|
||||
alias resolution, image handling, etc.) live here so individual providers
|
||||
only need to focus on provider-specific details.
|
||||
Role
|
||||
Defines the interface every provider must implement so the registry,
|
||||
restriction service, and tools have a uniform surface for listing
|
||||
models, resolving aliases, and executing requests.
|
||||
|
||||
Responsibilities
|
||||
* expose static capability metadata for each supported model via
|
||||
:class:`ModelCapabilities`
|
||||
* accept user prompts, forward them to the underlying SDK, and wrap
|
||||
responses in :class:`ModelResponse`
|
||||
* report tokenizer counts for budgeting and validation logic
|
||||
* advertise provider identity (``ProviderType``) so restriction
|
||||
policies can map environment configuration onto providers
|
||||
* validate whether a model name or alias is recognised by the provider
|
||||
|
||||
Shared helpers like temperature validation, alias resolution, and
|
||||
restriction-aware ``list_models`` live here so concrete subclasses only
|
||||
need to supply their catalogue and wire up SDK-specific behaviour.
|
||||
"""
|
||||
|
||||
# All concrete providers must define their supported models
|
||||
@@ -151,67 +164,52 @@ class ModelProvider(ABC):
|
||||
# If not found, return as-is
|
||||
return model_name
|
||||
|
||||
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||
"""Return a list of model names supported by this provider.
|
||||
|
||||
This implementation uses the get_model_configurations() hook
|
||||
to support different model configuration sources.
|
||||
def list_models(
|
||||
self,
|
||||
*,
|
||||
respect_restrictions: bool = True,
|
||||
include_aliases: bool = True,
|
||||
lowercase: bool = False,
|
||||
unique: bool = False,
|
||||
) -> list[str]:
|
||||
"""Return formatted model names supported by this provider.
|
||||
|
||||
Args:
|
||||
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||
respect_restrictions: Apply provider restriction policy.
|
||||
include_aliases: Include aliases alongside canonical model names.
|
||||
lowercase: Normalize returned names to lowercase.
|
||||
unique: Deduplicate names after formatting.
|
||||
|
||||
Returns:
|
||||
List of model names available from this provider
|
||||
List of model names formatted according to the provided options.
|
||||
"""
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||
models = []
|
||||
|
||||
# Get model configurations from the hook method
|
||||
model_configs = self.get_model_configurations()
|
||||
if not model_configs:
|
||||
return []
|
||||
|
||||
for model_name in model_configs:
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||
continue
|
||||
restriction_service = None
|
||||
if respect_restrictions:
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
# Add the base model
|
||||
models.append(model_name)
|
||||
restriction_service = get_restriction_service()
|
||||
|
||||
# Add aliases derived from the model configurations
|
||||
alias_map = ModelCapabilities.collect_aliases(model_configs)
|
||||
for model_name, aliases in alias_map.items():
|
||||
# Only add aliases for models that passed restriction check
|
||||
if model_name in models:
|
||||
models.extend(aliases)
|
||||
if restriction_service:
|
||||
allowed_configs = {}
|
||||
for model_name, config in model_configs.items():
|
||||
if restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||
allowed_configs[model_name] = config
|
||||
model_configs = allowed_configs
|
||||
|
||||
return models
|
||||
if not model_configs:
|
||||
return []
|
||||
|
||||
def list_all_known_models(self) -> list[str]:
|
||||
"""Return all model names known by this provider, including alias targets.
|
||||
|
||||
This is used for validation purposes to ensure restriction policies
|
||||
can validate against both aliases and their target model names.
|
||||
|
||||
Returns:
|
||||
List of all model names and alias targets known by this provider
|
||||
"""
|
||||
all_models = set()
|
||||
|
||||
# Get model configurations from the hook method
|
||||
model_configs = self.get_model_configurations()
|
||||
|
||||
# Add all base model names
|
||||
for model_name in model_configs:
|
||||
all_models.add(model_name.lower())
|
||||
|
||||
# Add aliases derived from the model configurations
|
||||
for aliases in ModelCapabilities.collect_aliases(model_configs).values():
|
||||
for alias in aliases:
|
||||
all_models.add(alias.lower())
|
||||
|
||||
return list(all_models)
|
||||
return ModelCapabilities.collect_model_names(
|
||||
model_configs,
|
||||
include_aliases=include_aliases,
|
||||
lowercase=lowercase,
|
||||
unique=unique,
|
||||
)
|
||||
|
||||
def validate_image(self, image_path: str, max_size_mb: float = None) -> tuple[bytes, str]:
|
||||
"""Provider-independent image validation.
|
||||
|
||||
Reference in New Issue
Block a user