refactor: removed subclass override when the base class should be resolving the model name

refactor: always disable "stream"
Fahad
2025-10-04 10:35:32 +04:00
parent d184024820
commit 06d7701cc3
17 changed files with 210 additions and 260 deletions
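Both changes follow a single pattern. A minimal sketch of the post-refactor flow is below; the class and method names match the diffs that follow, but the alias mapping and the trimmed-down signature are assumptions made purely for illustration.

class OpenAICompatibleProvider:
    """Sketch only: alias resolution and the streaming policy now live in the base class."""

    _aliases = {"o3-mini": "o3-mini-2025-01-31"}  # hypothetical alias mapping

    def _resolve_model_name(self, model_name: str) -> str:
        return self._aliases.get(model_name, model_name)

    def generate_content(self, prompt: str, model_name: str, **kwargs) -> dict:
        resolved_model = self._resolve_model_name(model_name)  # resolved once, here
        return {
            "model": resolved_model,
            "messages": [{"role": "user", "content": prompt}],
            "stream": False,  # MCP doesn't stream; also avoids O3 access issues
        }

class OpenAIModelProvider(OpenAICompatibleProvider):
    pass  # no generate_content override any more; the base class handles resolution

print(OpenAIModelProvider().generate_content("hello", "o3-mini")["model"])  # o3-mini-2025-01-31

The overrides deleted below in CustomProvider, OpenAIModelProvider, OpenRouterProvider, and XAIModelProvider did nothing beyond this resolution step (plus, in OpenRouter's case, forcing stream to False), which is why they could be removed wholesale.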

View File

@@ -60,16 +60,19 @@ class ModelProvider(ABC):
customise. Subclasses usually only override ``_lookup_capabilities`` to
integrate a registry or dynamic source, or ``_finalise_capabilities`` to
tweak the returned object.
+ Args:
+ model_name: Canonical model name or its alias
"""
- resolved_name = self._resolve_model_name(model_name)
- capabilities = self._lookup_capabilities(resolved_name, model_name)
+ resolved_model_name = self._resolve_model_name(model_name)
+ capabilities = self._lookup_capabilities(resolved_model_name, model_name)
if capabilities is None:
self._raise_unsupported_model(model_name)
- self._ensure_model_allowed(capabilities, resolved_name, model_name)
- return self._finalise_capabilities(capabilities, resolved_name, model_name)
+ self._ensure_model_allowed(capabilities, resolved_model_name, model_name)
+ return self._finalise_capabilities(capabilities, resolved_model_name, model_name)
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
"""Return statically declared capabilities when available."""
@@ -150,7 +153,38 @@ class ModelProvider(ABC):
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using the model."""
"""Generate content using the model.
This is the core method that all providers must implement to generate responses
from their models. Providers should handle model-specific capabilities and
constraints appropriately.
Args:
prompt: The main user prompt/query to send to the model
model_name: Canonical model name or its alias that the provider supports
system_prompt: Optional system instructions to prepend to the prompt for
establishing context, behavior, or role
temperature: Controls randomness in generation (0.0=deterministic, 1.0=creative),
default 0.3. Some models may not support temperature control
max_output_tokens: Optional maximum number of tokens to generate in the response.
If not specified, uses the model's default limit
**kwargs: Additional provider-specific parameters that vary by implementation
(e.g., thinking_mode for Gemini, top_p for OpenAI, images for vision models)
Returns:
ModelResponse: Standardized response object containing:
- content: The generated text response
- usage: Token usage statistics (input/output/total)
- model_name: The model that was actually used
- friendly_name: Human-readable provider/model identifier
- provider: The ProviderType enum value
- metadata: Provider-specific metadata (finish_reason, safety info, etc.)
Raises:
ValueError: If the model is not supported, parameters are invalid,
or the model is restricted by policy
RuntimeError: If the API call fails after retries
"""
def count_tokens(self, text: str, model_name: str) -> int:
"""Estimate token usage for a piece of text."""
@@ -276,7 +310,12 @@ class ModelProvider(ABC):
# Validation hooks
# ------------------------------------------------------------------
def validate_model_name(self, model_name: str) -> bool:
"""Return ``True`` when the model resolves to an allowed capability."""
"""
Return ``True`` when the model resolves to an allowed capability.
Args:
model_name: Canonical model name or its alias
"""
try:
self.get_capabilities(model_name)
@@ -285,7 +324,12 @@ class ModelProvider(ABC):
return True
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
"""Validate model parameters against capabilities."""
"""
Validate model parameters against capabilities.
Args:
model_name: Canonical model name or its alias
"""
capabilities = self.get_capabilities(model_name)
@@ -364,7 +408,7 @@ class ModelProvider(ABC):
model configuration sources.
Args:
- model_name: Model name that may be an alias
+ model_name: Canonical model name or its alias
Returns:
Resolved model name
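The hunk above frames get_capabilities as a template method: resolve the name, look up capabilities, enforce the allow-list, then finalise. A minimal sketch of a subclass that only overrides the lookup hook follows; the dataclass fields, alias map, and ExampleProvider name are all invented for illustration.

from dataclasses import dataclass
from typing import Optional

@dataclass
class Capabilities:
    model_name: str
    supports_images: bool = False

class ExampleProvider:
    _aliases = {"large": "example-large-1"}  # hypothetical alias
    _registry = {"example-large-1": Capabilities("example-large-1", supports_images=True)}

    def _resolve_model_name(self, model_name: str) -> str:
        return self._aliases.get(model_name, model_name)

    def _lookup_capabilities(self, resolved_model_name: str, requested_name: str) -> Optional[Capabilities]:
        # The hook most subclasses override, e.g. to consult a registry or dynamic source.
        return self._registry.get(resolved_model_name)

    def get_capabilities(self, model_name: str) -> Capabilities:
        resolved_model_name = self._resolve_model_name(model_name)
        capabilities = self._lookup_capabilities(resolved_model_name, model_name)
        if capabilities is None:
            raise ValueError(f"Unsupported model: {model_name}")
        # The real base class also checks the allow-list and finalises the object here.
        return capabilities

print(ExampleProvider().get_capabilities("large").model_name)  # example-large-1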

View File

@@ -6,7 +6,7 @@ from typing import Optional
from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
- from .shared import ModelCapabilities, ModelResponse, ProviderType
+ from .shared import ModelCapabilities, ProviderType
class CustomProvider(OpenAICompatibleProvider):
@@ -113,49 +113,6 @@ class CustomProvider(OpenAICompatibleProvider):
return ProviderType.CUSTOM
- # ------------------------------------------------------------------
- # Validation
- # ------------------------------------------------------------------
- # ------------------------------------------------------------------
- # Request execution
- # ------------------------------------------------------------------
- def generate_content(
- self,
- prompt: str,
- model_name: str,
- system_prompt: Optional[str] = None,
- temperature: float = 0.3,
- max_output_tokens: Optional[int] = None,
- **kwargs,
- ) -> ModelResponse:
- """Generate content using the custom API.
- Args:
- prompt: User prompt to send to the model
- model_name: Name of the model to use
- system_prompt: Optional system prompt for model behavior
- temperature: Sampling temperature
- max_output_tokens: Maximum tokens to generate
- **kwargs: Additional provider-specific parameters
- Returns:
- ModelResponse with generated content and metadata
- """
- # Resolve model alias to actual model name
- resolved_model = self._resolve_model_name(model_name)
- # Call parent method with resolved model name
- return super().generate_content(
- prompt=prompt,
- model_name=resolved_model,
- system_prompt=system_prompt,
- temperature=temperature,
- max_output_tokens=max_output_tokens,
- **kwargs,
- )
# ------------------------------------------------------------------
# Registry helpers
# ------------------------------------------------------------------

View File

@@ -333,15 +333,17 @@ class DIALModelProvider(OpenAICompatibleProvider):
/openai/deployments/{deployment}/chat/completions
Args:
- prompt: User prompt
- model_name: Model name or alias
- system_prompt: Optional system prompt
- temperature: Sampling temperature
- max_output_tokens: Maximum tokens to generate
- **kwargs: Additional provider-specific parameters
+ prompt: The main user prompt/query to send to the model
+ model_name: Model name or alias (e.g., "o3", "sonnet-4.1", "gemini-2.5-pro")
+ system_prompt: Optional system instructions to prepend to the prompt for context/behavior
+ temperature: Sampling temperature for randomness (0.0=deterministic, 1.0=creative), default 0.3
+ Note: O3/O4 models don't support temperature and will ignore this parameter
+ max_output_tokens: Optional maximum number of tokens to generate in the response
+ images: Optional list of image paths or data URLs to include with the prompt (for vision-capable models)
+ **kwargs: Additional OpenAI-compatible parameters (top_p, frequency_penalty, presence_penalty, seed, stop)
Returns:
- ModelResponse with generated content and metadata
+ ModelResponse: Contains the generated content, token usage stats, model metadata, and finish reason
"""
# Validate model name against allow-list
if not self.validate_model_name(model_name):
@@ -381,6 +383,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
completion_params = {
"model": resolved_model,
"messages": messages,
"stream": False,
}
# Determine temperature support from capabilities
@@ -397,7 +400,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
# Add additional parameters
for key, value in kwargs.items():
if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
- if not supports_temperature and key in ["top_p", "frequency_penalty", "presence_penalty"]:
+ if not supports_temperature and key in ["top_p", "frequency_penalty", "presence_penalty", "stream"]:
continue
completion_params[key] = value
@@ -437,9 +440,9 @@ class DIALModelProvider(OpenAICompatibleProvider):
except Exception as exc:
attempts = max(attempt_counter["value"], 1)
if attempts == 1:
- raise ValueError(f"DIAL API error for model {model_name}: {exc}") from exc
+ raise ValueError(f"DIAL API error for model {resolved_model}: {exc}") from exc
- raise ValueError(f"DIAL API error for model {model_name} after {attempts} attempts: {exc}") from exc
+ raise ValueError(f"DIAL API error for model {resolved_model} after {attempts} attempts: {exc}") from exc
def close(self) -> None:
"""Clean up HTTP clients when provider is closed."""

View File

@@ -172,12 +172,28 @@ class GeminiModelProvider(ModelProvider):
images: Optional[list[str]] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using Gemini model."""
"""
Generate content using Gemini model.
Args:
prompt: The main user prompt/query to send to the model
model_name: Canonical model name or its alias (e.g., "gemini-2.5-pro", "flash", "pro")
system_prompt: Optional system instructions to prepend to the prompt for context/behavior
temperature: Controls randomness in generation (0.0=deterministic, 1.0=creative), default 0.3
max_output_tokens: Optional maximum number of tokens to generate in the response
thinking_mode: Thinking budget level for models that support it ("minimal", "low", "medium", "high", "max"), default "medium"
images: Optional list of image paths or data URLs to include with the prompt (for vision models)
**kwargs: Additional keyword arguments (reserved for future use)
Returns:
ModelResponse: Contains the generated content, token usage stats, model metadata, and safety information
"""
# Validate parameters and fetch capabilities
- resolved_name = self._resolve_model_name(model_name)
self.validate_parameters(model_name, temperature)
capabilities = self.get_capabilities(model_name)
+ resolved_model_name = self._resolve_model_name(model_name)
# Prepare content parts (text and potentially images)
parts = []
@@ -201,7 +217,7 @@ class GeminiModelProvider(ModelProvider):
# Continue with other images and text
continue
elif images and not capabilities.supports_images:
- logger.warning(f"Model {resolved_name} does not support images, ignoring {len(images)} image(s)")
+ logger.warning(f"Model {resolved_model_name} does not support images, ignoring {len(images)} image(s)")
# Create contents structure
contents = [{"parts": parts}]
@@ -219,7 +235,7 @@ class GeminiModelProvider(ModelProvider):
# Add thinking configuration for models that support it
if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
# Get model's max thinking tokens and calculate actual budget
- model_config = self.MODEL_CAPABILITIES.get(resolved_name)
+ model_config = self.MODEL_CAPABILITIES.get(resolved_model_name)
if model_config and model_config.max_thinking_tokens > 0:
max_thinking_tokens = model_config.max_thinking_tokens
actual_thinking_budget = int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
@@ -233,7 +249,7 @@ class GeminiModelProvider(ModelProvider):
def _attempt() -> ModelResponse:
attempt_counter["value"] += 1
response = self.client.models.generate_content(
- model=resolved_name,
+ model=resolved_model_name,
contents=contents,
config=generation_config,
)
@@ -308,7 +324,7 @@ class GeminiModelProvider(ModelProvider):
return ModelResponse(
content=response.text,
usage=usage,
- model_name=resolved_name,
+ model_name=resolved_model_name,
friendly_name="Gemini",
provider=ProviderType.GOOGLE,
metadata={
@@ -324,12 +340,12 @@ class GeminiModelProvider(ModelProvider):
operation=_attempt,
max_attempts=max_retries,
delays=retry_delays,
- log_prefix=f"Gemini API ({resolved_name})",
+ log_prefix=f"Gemini API ({resolved_model_name})",
)
except Exception as exc:
attempts = max(attempt_counter["value"], 1)
error_msg = (
f"Gemini API error for model {resolved_name} after {attempts} attempt"
f"Gemini API error for model {resolved_model_name} after {attempts} attempt"
f"{'s' if attempts > 1 else ''}: {exc}"
)
raise RuntimeError(error_msg) from exc
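One detail in the Gemini hunks worth spelling out: the thinking budget is a fraction of the model's max_thinking_tokens, selected by thinking_mode. Below is a worked example with assumed numbers; the real fractions and per-model ceilings come from THINKING_BUDGETS and MODEL_CAPABILITIES, which this commit does not show.

THINKING_BUDGETS = {"minimal": 0.005, "low": 0.08, "medium": 0.33, "high": 0.67, "max": 1.0}  # assumed fractions

max_thinking_tokens = 32768   # assumed per-model ceiling
thinking_mode = "medium"      # the documented default

actual_thinking_budget = int(max_thinking_tokens * THINKING_BUDGETS[thinking_mode])
print(actual_thinking_budget)  # 10813 with these assumed values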

View File

@@ -462,10 +462,11 @@ class OpenAICompatibleProvider(ModelProvider):
Args:
prompt: User prompt to send to the model
- model_name: Name of the model to use
+ model_name: Canonical model name or its alias
system_prompt: Optional system prompt for model behavior
temperature: Sampling temperature
max_output_tokens: Maximum tokens to generate
+ images: Optional list of image paths or data URLs to include with the prompt (for vision models)
**kwargs: Additional provider-specific parameters
Returns:
@@ -497,6 +498,9 @@ class OpenAICompatibleProvider(ModelProvider):
# Validate parameters with the effective temperature
self.validate_parameters(model_name, effective_temperature)
+ # Resolve to canonical model name
+ resolved_model = self._resolve_model_name(model_name)
# Prepare messages
messages = []
if system_prompt:
@@ -518,7 +522,7 @@ class OpenAICompatibleProvider(ModelProvider):
# Continue with other images and text
continue
elif images and (not capabilities or not capabilities.supports_images):
- logging.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)")
+ logging.warning(f"Model {resolved_model} does not support images, ignoring {len(images)} image(s)")
# Add user message
if len(user_content) == 1:
@@ -529,14 +533,14 @@ class OpenAICompatibleProvider(ModelProvider):
messages.append({"role": "user", "content": user_content})
# Prepare completion parameters
+ # Always disable streaming for OpenRouter
+ # MCP doesn't use streaming, and this avoids issues with O3 model access
completion_params = {
"model": model_name,
"model": resolved_model,
"messages": messages,
"stream": False,
}
# Check model capabilities once to determine parameter support
- resolved_model = self._resolve_model_name(model_name)
# Use the effective temperature we calculated earlier
supports_sampling = effective_temperature is not None
@@ -553,7 +557,7 @@ class OpenAICompatibleProvider(ModelProvider):
for key, value in kwargs.items():
if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
# Reasoning models (those that don't support temperature) also don't support these parameters
- if not supports_sampling and key in ["top_p", "frequency_penalty", "presence_penalty"]:
+ if not supports_sampling and key in ["top_p", "frequency_penalty", "presence_penalty", "stream"]:
continue # Skip unsupported parameters for reasoning models
completion_params[key] = value
@@ -585,7 +589,7 @@ class OpenAICompatibleProvider(ModelProvider):
return ModelResponse(
content=content,
usage=usage,
- model_name=model_name,
+ model_name=resolved_model,
friendly_name=self.FRIENDLY_NAME,
provider=self.get_provider_type(),
metadata={
@@ -601,12 +605,12 @@ class OpenAICompatibleProvider(ModelProvider):
operation=_attempt,
max_attempts=max_retries,
delays=retry_delays,
- log_prefix=f"{self.FRIENDLY_NAME} API ({model_name})",
+ log_prefix=f"{self.FRIENDLY_NAME} API ({resolved_model})",
)
except Exception as exc:
attempts = max(attempt_counter["value"], 1)
error_msg = (
f"{self.FRIENDLY_NAME} API error for model {model_name} after {attempts} attempt"
f"{self.FRIENDLY_NAME} API error for model {resolved_model} after {attempts} attempt"
f"{'s' if attempts > 1 else ''}: {exc}"
)
logging.error(error_msg)
@@ -618,7 +622,7 @@ class OpenAICompatibleProvider(ModelProvider):
For proxy providers, this may use generic capabilities.
Args:
- model_name: Model to validate for
+ model_name: Canonical model name or its alias
temperature: Temperature to validate
**kwargs: Additional parameters to validate
"""

View File

@@ -7,7 +7,7 @@ if TYPE_CHECKING:
from tools.models import ToolModelCategory
from .openai_compatible import OpenAICompatibleProvider
- from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
+ from .shared import ModelCapabilities, ProviderType, TemperatureConstraint
logger = logging.getLogger(__name__)
@@ -253,33 +253,6 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
"""Get the provider type."""
return ProviderType.OPENAI
- # ------------------------------------------------------------------
- # Request execution
- # ------------------------------------------------------------------
- def generate_content(
- self,
- prompt: str,
- model_name: str,
- system_prompt: Optional[str] = None,
- temperature: float = 0.3,
- max_output_tokens: Optional[int] = None,
- **kwargs,
- ) -> ModelResponse:
- """Generate content using OpenAI API with proper model name resolution."""
- # Resolve model alias before making API call
- resolved_model_name = self._resolve_model_name(model_name)
- # Call parent implementation with resolved model name
- return super().generate_content(
- prompt=prompt,
- model_name=resolved_model_name,
- system_prompt=system_prompt,
- temperature=temperature,
- max_output_tokens=max_output_tokens,
- **kwargs,
- )
# ------------------------------------------------------------------
# Provider preferences
# ------------------------------------------------------------------

View File

@@ -8,7 +8,6 @@ from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
from .shared import (
ModelCapabilities,
- ModelResponse,
ProviderType,
RangeTemperatureConstraint,
)
@@ -111,50 +110,6 @@ class OpenRouterProvider(OpenAICompatibleProvider):
"""Identify this provider for restrictions and logging."""
return ProviderType.OPENROUTER
- # ------------------------------------------------------------------
- # Request execution
- # ------------------------------------------------------------------
- def generate_content(
- self,
- prompt: str,
- model_name: str,
- system_prompt: Optional[str] = None,
- temperature: float = 0.3,
- max_output_tokens: Optional[int] = None,
- **kwargs,
- ) -> ModelResponse:
- """Generate content using the OpenRouter API.
- Args:
- prompt: User prompt to send to the model
- model_name: Name of the model (or alias) to use
- system_prompt: Optional system prompt for model behavior
- temperature: Sampling temperature
- max_output_tokens: Maximum tokens to generate
- **kwargs: Additional provider-specific parameters
- Returns:
- ModelResponse with generated content and metadata
- """
- # Resolve model alias to actual OpenRouter model name
- resolved_model = self._resolve_model_name(model_name)
- # Always disable streaming for OpenRouter
- # MCP doesn't use streaming, and this avoids issues with O3 model access
- if "stream" not in kwargs:
- kwargs["stream"] = False
- # Call parent method with resolved model name
- return super().generate_content(
- prompt=prompt,
- model_name=resolved_model,
- system_prompt=system_prompt,
- temperature=temperature,
- max_output_tokens=max_output_tokens,
- **kwargs,
- )
# ------------------------------------------------------------------
# Registry helpers
# ------------------------------------------------------------------

View File

@@ -7,7 +7,7 @@ if TYPE_CHECKING:
from tools.models import ToolModelCategory
from .openai_compatible import OpenAICompatibleProvider
- from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
+ from .shared import ModelCapabilities, ProviderType, TemperatureConstraint
logger = logging.getLogger(__name__)
@@ -92,29 +92,6 @@ class XAIModelProvider(OpenAICompatibleProvider):
"""Get the provider type."""
return ProviderType.XAI
- def generate_content(
- self,
- prompt: str,
- model_name: str,
- system_prompt: Optional[str] = None,
- temperature: float = 0.3,
- max_output_tokens: Optional[int] = None,
- **kwargs,
- ) -> ModelResponse:
- """Generate content using X.AI API with proper model name resolution."""
- # Resolve model alias before making API call
- resolved_model_name = self._resolve_model_name(model_name)
- # Call parent implementation with resolved model name
- return super().generate_content(
- prompt=prompt,
- model_name=resolved_model_name,
- system_prompt=system_prompt,
- temperature=temperature,
- max_output_tokens=max_output_tokens,
- **kwargs,
- )
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
"""Get XAI's preferred model for a given category from allowed models.