refactor: moved temperature method from base provider to model capabilities

refactor: model listing cleanup, moved logic to model_capabilities.py
docs: added AGENTS.md for onboarding Codex
Fahad
2025-10-02 10:25:41 +04:00
parent f461cb4519
commit 6d237d0970
14 changed files with 460 additions and 512 deletions

View File

@@ -18,13 +18,26 @@ logger = logging.getLogger(__name__)
 class ModelProvider(ABC):
-    """Defines the contract implemented by every model provider backend.
-
-    Subclasses adapt third-party SDKs into the MCP server by exposing
-    capability metadata, request execution, and token counting through a
-    consistent interface. Shared helper methods (temperature validation,
-    alias resolution, image handling, etc.) live here so individual providers
-    only need to focus on provider-specific details.
-    """
+    """Abstract base class for all model backends in the MCP server.
+
+    Role
+        Defines the interface every provider must implement so the registry,
+        restriction service, and tools have a uniform surface for listing
+        models, resolving aliases, and executing requests.
+
+    Responsibilities
+        * expose static capability metadata for each supported model via
+          :class:`ModelCapabilities`
+        * accept user prompts, forward them to the underlying SDK, and wrap
+          responses in :class:`ModelResponse`
+        * report tokenizer counts for budgeting and validation logic
+        * advertise provider identity (``ProviderType``) so restriction
+          policies can map environment configuration onto providers
+        * validate whether a model name or alias is recognised by the provider
+
+    Shared helpers like temperature validation, alias resolution, and
+    restriction-aware ``list_models`` live here so concrete subclasses only
+    need to supply their catalogue and wire up SDK-specific behaviour.
+    """

     # All concrete providers must define their supported models
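For orientation, here is a minimal sketch of what a concrete subclass looks like under this contract. `ExampleProvider`, its one-model catalogue, and the module paths are invented for illustration; only the hook names (`get_model_configurations`, `get_provider_type`) and the `MODEL_CAPABILITIES` convention come from this commit:

```python
# Illustrative sketch only, assuming the module layout suggested by the diff
# (providers.base, providers.shared). ExampleProvider and its catalogue are
# hypothetical; request-execution and token-counting hooks are omitted.
from providers.base import ModelProvider
from providers.shared import ModelCapabilities, ProviderType


class ExampleProvider(ModelProvider):
    # Static catalogue, as the ModelCapabilities docstring below describes.
    MODEL_CAPABILITIES = {
        "example-large": ModelCapabilities(
            provider=ProviderType.CUSTOM,  # assumed enum member
            model_name="example-large",
            friendly_name="Example Large",
            context_window=128_000,
            max_output_tokens=8_192,
            aliases=["large"],
        ),
    }

    def get_provider_type(self) -> ProviderType:
        return ProviderType.CUSTOM

    def get_model_configurations(self) -> dict[str, ModelCapabilities]:
        # Hook consumed by the shared list_models() implementation below.
        return self.MODEL_CAPABILITIES
```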
@@ -151,67 +164,52 @@ class ModelProvider(ABC):
         # If not found, return as-is
         return model_name

-    def list_models(self, respect_restrictions: bool = True) -> list[str]:
-        """Return a list of model names supported by this provider.
-
-        This implementation uses the get_model_configurations() hook
-        to support different model configuration sources.
-
-        Args:
-            respect_restrictions: Whether to apply provider-specific restriction logic.
-
-        Returns:
-            List of model names available from this provider
-        """
-        from utils.model_restrictions import get_restriction_service
-
-        restriction_service = get_restriction_service() if respect_restrictions else None
-        models = []
-
-        # Get model configurations from the hook method
-        model_configs = self.get_model_configurations()
-
-        for model_name in model_configs:
-            # Check restrictions if enabled
-            if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
-                continue
-
-            # Add the base model
-            models.append(model_name)
-
-        # Add aliases derived from the model configurations
-        alias_map = ModelCapabilities.collect_aliases(model_configs)
-        for model_name, aliases in alias_map.items():
-            # Only add aliases for models that passed restriction check
-            if model_name in models:
-                models.extend(aliases)
-
-        return models
-
-    def list_all_known_models(self) -> list[str]:
-        """Return all model names known by this provider, including alias targets.
-
-        This is used for validation purposes to ensure restriction policies
-        can validate against both aliases and their target model names.
-
-        Returns:
-            List of all model names and alias targets known by this provider
-        """
-        all_models = set()
-
-        # Get model configurations from the hook method
-        model_configs = self.get_model_configurations()
-
-        # Add all base model names
-        for model_name in model_configs:
-            all_models.add(model_name.lower())
-
-        # Add aliases derived from the model configurations
-        for aliases in ModelCapabilities.collect_aliases(model_configs).values():
-            for alias in aliases:
-                all_models.add(alias.lower())
-
-        return list(all_models)
+    def list_models(
+        self,
+        *,
+        respect_restrictions: bool = True,
+        include_aliases: bool = True,
+        lowercase: bool = False,
+        unique: bool = False,
+    ) -> list[str]:
+        """Return formatted model names supported by this provider.
+
+        Args:
+            respect_restrictions: Apply provider restriction policy.
+            include_aliases: Include aliases alongside canonical model names.
+            lowercase: Normalize returned names to lowercase.
+            unique: Deduplicate names after formatting.
+
+        Returns:
+            List of model names formatted according to the provided options.
+        """
+        model_configs = self.get_model_configurations()
+        if not model_configs:
+            return []
+
+        restriction_service = None
+        if respect_restrictions:
+            from utils.model_restrictions import get_restriction_service
+
+            restriction_service = get_restriction_service()
+
+        if restriction_service:
+            allowed_configs = {}
+            for model_name, config in model_configs.items():
+                if restriction_service.is_allowed(self.get_provider_type(), model_name):
+                    allowed_configs[model_name] = config
+            model_configs = allowed_configs
+
+        if not model_configs:
+            return []
+
+        return ModelCapabilities.collect_model_names(
+            model_configs,
+            include_aliases=include_aliases,
+            lowercase=lowercase,
+            unique=unique,
+        )

     def validate_image(self, image_path: str, max_size_mb: float = None) -> tuple[bytes, str]:
         """Provider-independent image validation.

View File

@@ -32,11 +32,20 @@ _TEMP_UNSUPPORTED_KEYWORDS = [
 class CustomProvider(OpenAICompatibleProvider):
     """Adapter for self-hosted or local OpenAI-compatible endpoints.

-    The provider reuses the :mod:`providers.shared` registry to surface
-    user-defined aliases and capability metadata. It also normalises
-    Ollama-style version tags (``model:latest``) and enforces the same
-    restriction policies used by cloud providers, ensuring consistent
-    behaviour regardless of where the model is hosted.
+    Role
+        Provide a uniform bridge between the MCP server and user-managed
+        OpenAI-compatible services (Ollama, vLLM, LM Studio, bespoke gateways).
+        By subclassing :class:`OpenAICompatibleProvider` it inherits request
+        and token handling, while the custom registry exposes locally defined
+        model metadata.
+
+    Notable behaviour
+        * Uses :class:`OpenRouterModelRegistry` to load model definitions and
+          aliases so custom deployments share the same metadata pipeline as
+          OpenRouter itself.
+        * Normalises version-tagged model names (``model:latest``) and applies
+          restriction policies just like cloud providers, ensuring consistent
+          behaviour across environments.
     """

     FRIENDLY_NAME = "Custom API"
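The version-tag normalisation mentioned above amounts to something like the following (a standalone sketch; the actual helper in the codebase may be named differently):

```python
def strip_version_tag(model_name: str) -> str:
    """Reduce an Ollama-style 'model:tag' name to its base name.

    'llama3.2:latest' -> 'llama3.2'; names without a tag pass through.
    """
    base, _, tag = model_name.partition(":")
    return base if tag else model_name
```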

View File

@@ -17,9 +17,19 @@ from .shared import (
 class OpenRouterProvider(OpenAICompatibleProvider):
     """Client for OpenRouter's multi-model aggregation service.

-    OpenRouter surfaces dozens of upstream vendors. This provider layers alias
-    resolution, restriction-aware filtering, and sensible capability defaults
-    on top of the generic OpenAI-compatible plumbing.
+    Role
+        Surface OpenRouter's dynamic catalogue through the same interface as
+        native providers so tools can reference OpenRouter models and aliases
+        without special cases.
+
+    Characteristics
+        * Pulls live model definitions from :class:`OpenRouterModelRegistry`
+          (aliases, provider-specific metadata, capability hints)
+        * Applies alias-aware restriction checks before exposing models to the
+          registry or tooling
+        * Reuses :class:`OpenAICompatibleProvider` infrastructure for request
+          execution so OpenRouter endpoints behave like standard OpenAI-style
+          APIs.
     """

     FRIENDLY_NAME = "OpenRouter"
@@ -208,75 +218,56 @@ class OpenRouterProvider(OpenAICompatibleProvider):
"""
return False
def list_models(self, respect_restrictions: bool = True) -> list[str]:
"""Return a list of model names supported by this provider.
def list_models(
self,
*,
respect_restrictions: bool = True,
include_aliases: bool = True,
lowercase: bool = False,
unique: bool = False,
) -> list[str]:
"""Return formatted OpenRouter model names, respecting alias-aware restrictions."""
Args:
respect_restrictions: Whether to apply provider-specific restriction logic.
if not self._registry:
return []
Returns:
List of model names available from this provider
"""
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service() if respect_restrictions else None
models = []
allowed_configs: dict[str, ModelCapabilities] = {}
if self._registry:
for model_name in self._registry.list_models():
# =====================================================================================
# CRITICAL ALIAS-AWARE RESTRICTION CHECKING (Fixed Issue #98)
# =====================================================================================
# Previously, restrictions only checked full model names (e.g., "google/gemini-2.5-pro")
# but users specify aliases in OPENROUTER_ALLOWED_MODELS (e.g., "pro").
# This caused "no models available" error even with valid restrictions.
#
# Fix: Check both model name AND all aliases against restrictions
# TEST COVERAGE: tests/test_provider_routing_bugs.py::TestOpenRouterAliasRestrictions
# =====================================================================================
if restriction_service:
# Get model config to check aliases as well
model_config = self._registry.resolve(model_name)
allowed = False
for model_name in self._registry.list_models():
config = self._registry.resolve(model_name)
if not config:
continue
# Check if model name itself is allowed
if restriction_service.is_allowed(self.get_provider_type(), model_name):
allowed = True
if restriction_service:
allowed = restriction_service.is_allowed(self.get_provider_type(), model_name)
# CRITICAL: Also check aliases - this fixes the alias restriction bug
if not allowed and model_config and model_config.aliases:
for alias in model_config.aliases:
if restriction_service.is_allowed(self.get_provider_type(), alias):
allowed = True
break
if not allowed and config.aliases:
for alias in config.aliases:
if restriction_service.is_allowed(self.get_provider_type(), alias):
allowed = True
break
if not allowed:
continue
if not allowed:
continue
models.append(model_name)
allowed_configs[model_name] = config
return models
if not allowed_configs:
return []
def list_all_known_models(self) -> list[str]:
"""Return all model names known by this provider, including alias targets.
# When restrictions are in place, don't include aliases to avoid confusion
# Only return the canonical model names that are actually allowed
actual_include_aliases = include_aliases and not respect_restrictions
Returns:
List of all model names and alias targets known by this provider
"""
all_models = set()
if self._registry:
# Get all models and aliases from the registry
all_models.update(model.lower() for model in self._registry.list_models())
all_models.update(alias.lower() for alias in self._registry.list_aliases())
# For each alias, also add its target
for alias in self._registry.list_aliases():
config = self._registry.resolve(alias)
if config:
all_models.add(config.model_name.lower())
return list(all_models)
return ModelCapabilities.collect_model_names(
allowed_configs,
include_aliases=actual_include_aliases,
lowercase=lowercase,
unique=unique,
)
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
"""Get model configurations from the registry.

View File

@@ -17,12 +17,21 @@ from .shared import (
 class OpenRouterModelRegistry:
-    """Loads and validates the OpenRouter/custom model catalogue.
-
-    The registry parses ``conf/custom_models.json`` (or an override supplied
-    via environment variable), builds case-insensitive alias maps, and exposes
-    :class:`~providers.shared.ModelCapabilities` objects used by several
-    providers.
-    """
+    """In-memory view of OpenRouter and custom model metadata.
+
+    Role
+        Parse the packaged ``conf/custom_models.json`` (or user-specified
+        overrides), construct alias and capability maps, and serve those
+        structures to providers that rely on OpenRouter semantics (both the
+        OpenRouter provider itself and the Custom provider).
+
+    Key duties
+        * Load :class:`ModelCapabilities` definitions from configuration files
+        * Maintain a case-insensitive alias → canonical name map for fast
+          resolution
+        * Provide helpers to list models, list aliases, and resolve an
+          arbitrary name to its capability object without repeatedly touching
+          the file system.
+    """

     def __init__(self, config_path: Optional[str] = None):
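A sketch of the lookup flow these helpers support. The helper names (`list_models`, `list_aliases`, `resolve`) appear in this diff; the alias "pro" is illustrative, and actual output depends on the contents of `conf/custom_models.json`:

```python
registry = OpenRouterModelRegistry()

registry.list_models()    # canonical model names from the catalogue
registry.list_aliases()   # every registered alias

caps = registry.resolve("PRO")   # alias lookup is case-insensitive
if caps:
    print(caps.model_name, caps.context_window)
```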

View File

@@ -12,11 +12,22 @@ if TYPE_CHECKING:
 class ModelProviderRegistry:
-    """Singleton that caches provider instances and coordinates priority order.
-
-    Responsibilities include resolving API keys from the environment, lazily
-    instantiating providers, and choosing the best provider for a model based
-    on restriction policies and provider priority.
-    """
+    """Central catalogue of provider implementations used by the MCP server.
+
+    Role
+        Holds the mapping between :class:`ProviderType` values and concrete
+        :class:`ModelProvider` subclasses/factories. At runtime the registry
+        is responsible for instantiating providers, caching them for reuse,
+        and mediating lookup of providers and model names in provider
+        priority order.
+
+    Core responsibilities
+        * Resolve API keys and other runtime configuration for each provider
+        * Lazily create provider instances so unused backends incur no cost
+        * Expose convenience methods for enumerating available models and
+          locating which provider can service a requested model name or alias
+        * Honour the project-wide provider priority policy so namespaces (or
+          alias collisions) are resolved deterministically.
+    """

     _instance = None
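The lazy create-and-cache behaviour described here boils down to a pattern like the following (a schematic sketch; `RegistrySketch`, `_PROVIDER_FACTORIES`, `_instances`, and `get_provider` are illustrative names, not the registry's actual API):

```python
from providers.base import ModelProvider
from providers.shared import ProviderType


class RegistrySketch:
    # Schematic stand-in for the real singleton's internals.
    _PROVIDER_FACTORIES: dict[ProviderType, type[ModelProvider]] = {}
    _instances: dict[ProviderType, ModelProvider] = {}

    @classmethod
    def get_provider(cls, provider_type: ProviderType) -> ModelProvider:
        # Unused backends are never constructed; first use pays the cost once.
        if provider_type not in cls._instances:
            factory = cls._PROVIDER_FACTORIES[provider_type]
            cls._instances[provider_type] = factory()
        return cls._instances[provider_type]
```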

View File

@@ -11,24 +11,46 @@ __all__ = ["ModelCapabilities"]
 @dataclass
 class ModelCapabilities:
-    """Static capabilities and constraints for a provider-managed model."""
+    """Static description of what a model can do within a provider.
+
+    Role
+        Acts as the canonical record for everything the server needs to know
+        about a model—its provider, token limits, feature switches, aliases,
+        and temperature rules. Providers populate these objects so tools and
+        higher-level services can rely on a consistent schema.
+
+    Typical usage
+        * Provider subclasses declare ``MODEL_CAPABILITIES`` maps containing
+          these objects (for example ``OpenAIModelProvider``)
+        * Helper utilities (e.g. restriction validation, alias expansion) read
+          these objects to build model lists for tooling and policy enforcement
+        * Tool selection logic inspects attributes such as
+          ``supports_extended_thinking`` or ``context_window`` to choose an
+          appropriate model for a task.
+    """

     provider: ProviderType
     model_name: str
     friendly_name: str
-    context_window: int
-    max_output_tokens: int
+
+    description: str = ""
+    aliases: list[str] = field(default_factory=list)
+
+    # Capacity limits / resource budgets
+    context_window: int = 0
+    max_output_tokens: int = 0
+    max_thinking_tokens: int = 0
+
+    # Capability flags
     supports_extended_thinking: bool = False
     supports_system_prompts: bool = True
     supports_streaming: bool = True
     supports_function_calling: bool = False
     supports_images: bool = False
-    max_image_size_mb: float = 0.0
-    supports_temperature: bool = True
-    description: str = ""
-    aliases: list[str] = field(default_factory=list)
     supports_json_mode: bool = False
-    max_thinking_tokens: int = 0
+    supports_temperature: bool = True
+
+    # Additional attributes
+    max_image_size_mb: float = 0.0
     is_custom: bool = False
+
     temperature_constraint: TemperatureConstraint = field(
         default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
@@ -56,3 +78,45 @@ class ModelCapabilities:
             for base_model, capabilities in model_configs.items()
             if capabilities.aliases
         }
+
+    @staticmethod
+    def collect_model_names(
+        model_configs: dict[str, "ModelCapabilities"],
+        *,
+        include_aliases: bool = True,
+        lowercase: bool = False,
+        unique: bool = False,
+    ) -> list[str]:
+        """Build an ordered list of model names and aliases.
+
+        Args:
+            model_configs: Mapping of canonical model names to capabilities.
+            include_aliases: When True, include aliases for each model.
+            lowercase: When True, normalize names to lowercase.
+            unique: When True, ensure each returned name appears once (after
+                formatting).
+
+        Returns:
+            Ordered list of model names (and optionally aliases) formatted
+            per options.
+        """
+        formatted_names: list[str] = []
+        seen: set[str] | None = set() if unique else None
+
+        def append_name(name: str) -> None:
+            formatted = name.lower() if lowercase else name
+            if seen is not None:
+                if formatted in seen:
+                    return
+                seen.add(formatted)
+            formatted_names.append(formatted)
+
+        for base_model, capabilities in model_configs.items():
+            append_name(base_model)
+            if include_aliases and capabilities.aliases:
+                for alias in capabilities.aliases:
+                    append_name(alias)
+
+        return formatted_names
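For concreteness, here is what the new helper returns under each flag combination, using an invented one-model catalogue (`ProviderType.GOOGLE` is an assumed enum member):

```python
# Invented catalogue to exercise collect_model_names().
configs = {
    "Gemini-2.5-Pro": ModelCapabilities(
        provider=ProviderType.GOOGLE,  # assumed enum member
        model_name="Gemini-2.5-Pro",
        friendly_name="Gemini Pro",
        aliases=["pro", "Pro"],
    ),
}

ModelCapabilities.collect_model_names(configs)
# -> ["Gemini-2.5-Pro", "pro", "Pro"]

ModelCapabilities.collect_model_names(configs, lowercase=True, unique=True)
# -> ["gemini-2.5-pro", "pro"]  ("Pro" deduplicates after lowercasing)

ModelCapabilities.collect_model_names(configs, include_aliases=False)
# -> ["Gemini-2.5-Pro"]
```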