refactor: remove subclass overrides now that the base class resolves the model name
refactor: always disable "stream"
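Below is a minimal, hypothetical sketch of the pattern these two changes move toward: the base class resolves a model alias once and forces streaming off, so subclasses no longer carry their own generate_content overrides. Class names and the alias table are illustrative, not the repository's real code.

# Hedged sketch: illustrative only; the classes and alias table below are hypothetical.
from abc import ABC, abstractmethod


class BaseProvider(ABC):
    ALIASES = {"flash": "gemini-2.5-flash"}  # hypothetical alias table

    def _resolve_model_name(self, model_name: str) -> str:
        # Aliases map to canonical names; unknown names pass through unchanged.
        return self.ALIASES.get(model_name, model_name)

    def generate_content(self, prompt: str, model_name: str, **kwargs) -> str:
        # Resolution and stream handling live here, so subclasses need no override.
        resolved = self._resolve_model_name(model_name)
        params = {"model": resolved, "stream": False, **kwargs}
        return self._call_api(prompt, params)

    @abstractmethod
    def _call_api(self, prompt: str, params: dict) -> str: ...


class ConcreteProvider(BaseProvider):
    # No generate_content override: the base class already resolves the alias.
    def _call_api(self, prompt: str, params: dict) -> str:
        return f"[{params['model']}] {prompt}"


print(ConcreteProvider().generate_content("hello", "flash"))  # [gemini-2.5-flash] hello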
@@ -60,16 +60,19 @@ class ModelProvider(ABC):
customise. Subclasses usually only override ``_lookup_capabilities`` to
integrate a registry or dynamic source, or ``_finalise_capabilities`` to
tweak the returned object.

Args:
model_name: Canonical model name or its alias
"""

resolved_name = self._resolve_model_name(model_name)
capabilities = self._lookup_capabilities(resolved_name, model_name)
resolved_model_name = self._resolve_model_name(model_name)
capabilities = self._lookup_capabilities(resolved_model_name, model_name)

if capabilities is None:
self._raise_unsupported_model(model_name)

self._ensure_model_allowed(capabilities, resolved_name, model_name)
return self._finalise_capabilities(capabilities, resolved_name, model_name)
self._ensure_model_allowed(capabilities, resolved_model_name, model_name)
return self._finalise_capabilities(capabilities, resolved_model_name, model_name)

def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
"""Return statically declared capabilities when available."""
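For reference, a self-contained sketch of the capability-lookup flow this hunk renames to resolved_model_name; the helper bodies are stand-ins, only the call order mirrors the diff.

# Hedged sketch of the get_capabilities flow shown above; helpers are stand-ins.
class _Capabilities:
    def __init__(self, name: str) -> None:
        self.name = name


class _SketchProvider:
    _KNOWN = {"gemini-2.5-pro"}  # illustrative known-model set

    def _resolve_model_name(self, model_name: str) -> str:
        return {"pro": "gemini-2.5-pro"}.get(model_name, model_name)

    def _lookup_capabilities(self, resolved_model_name, model_name):
        return _Capabilities(resolved_model_name) if resolved_model_name in self._KNOWN else None

    def _raise_unsupported_model(self, model_name):
        raise ValueError(f"Unsupported model: {model_name}")

    def _ensure_model_allowed(self, capabilities, resolved_model_name, model_name):
        pass  # restriction-policy hook; no-op in this sketch

    def _finalise_capabilities(self, capabilities, resolved_model_name, model_name):
        return capabilities

    def get_capabilities(self, model_name: str) -> _Capabilities:
        resolved_model_name = self._resolve_model_name(model_name)
        capabilities = self._lookup_capabilities(resolved_model_name, model_name)
        if capabilities is None:
            self._raise_unsupported_model(model_name)
        self._ensure_model_allowed(capabilities, resolved_model_name, model_name)
        return self._finalise_capabilities(capabilities, resolved_model_name, model_name)


print(_SketchProvider().get_capabilities("pro").name)  # gemini-2.5-pro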
@@ -150,7 +153,38 @@ class ModelProvider(ABC):
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using the model."""
"""Generate content using the model.

This is the core method that all providers must implement to generate responses
from their models. Providers should handle model-specific capabilities and
constraints appropriately.

Args:
prompt: The main user prompt/query to send to the model
model_name: Canonical model name or its alias that the provider supports
system_prompt: Optional system instructions to prepend to the prompt for
establishing context, behavior, or role
temperature: Controls randomness in generation (0.0=deterministic, 1.0=creative),
default 0.3. Some models may not support temperature control
max_output_tokens: Optional maximum number of tokens to generate in the response.
If not specified, uses the model's default limit
**kwargs: Additional provider-specific parameters that vary by implementation
(e.g., thinking_mode for Gemini, top_p for OpenAI, images for vision models)

Returns:
ModelResponse: Standardized response object containing:
- content: The generated text response
- usage: Token usage statistics (input/output/total)
- model_name: The model that was actually used
- friendly_name: Human-readable provider/model identifier
- provider: The ProviderType enum value
- metadata: Provider-specific metadata (finish_reason, safety info, etc.)

Raises:
ValueError: If the model is not supported, parameters are invalid,
or the model is restricted by policy
RuntimeError: If the API call fails after retries
"""

def count_tokens(self, text: str, model_name: str) -> int:
"""Estimate token usage for a piece of text."""
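A hedged usage sketch of the generate_content contract documented above; the provider object and model name are placeholders, and the accessed fields follow the docstring's Returns section.

# Hedged usage sketch; "provider" stands for any concrete ModelProvider instance.
def summarise_with(provider, model_name: str) -> None:
    response = provider.generate_content(
        prompt="Summarise the release notes in two sentences.",
        model_name=model_name,            # canonical name or alias
        system_prompt="You are a concise release-notes assistant.",
        temperature=0.3,                  # ignored by models without temperature support
        max_output_tokens=256,
    )
    # Fields listed in the docstring's Returns section.
    print(response.content)
    print(response.usage, response.model_name, response.provider)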
@@ -276,7 +310,12 @@ class ModelProvider(ABC):
# Validation hooks
# ------------------------------------------------------------------
def validate_model_name(self, model_name: str) -> bool:
"""Return ``True`` when the model resolves to an allowed capability."""
"""
Return ``True`` when the model resolves to an allowed capability.

Args:
model_name: Canonical model name or its alias
"""

try:
self.get_capabilities(model_name)
@@ -285,7 +324,12 @@ class ModelProvider(ABC):
return True

def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
"""Validate model parameters against capabilities."""
"""
Validate model parameters against capabilities.

Args:
model_name: Canonical model name or its alias
"""

capabilities = self.get_capabilities(model_name)

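A small sketch of how the boolean validator wraps get_capabilities, matching the try/except shape shown in this hunk; the known-model set is invented for illustration.

# Hedged sketch: validate_model_name as a thin wrapper over get_capabilities.
class _SketchValidator:
    def get_capabilities(self, model_name: str):
        if model_name not in {"o3", "gemini-2.5-pro"}:  # illustrative allow-list
            raise ValueError(f"Unsupported or restricted model: {model_name}")
        return object()

    def validate_model_name(self, model_name: str) -> bool:
        try:
            self.get_capabilities(model_name)
        except ValueError:
            return False
        return True


v = _SketchValidator()
print(v.validate_model_name("o3"), v.validate_model_name("made-up-model"))  # True False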
@@ -364,7 +408,7 @@ class ModelProvider(ABC):
model configuration sources.

Args:
model_name: Model name that may be an alias
model_name: Canonical model name or its alias

Returns:
Resolved model name

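An illustrative sketch of the alias resolution this docstring describes; the alias table is invented, and the point is that unknown names pass through unchanged.

# Hedged sketch of _resolve_model_name behaviour; the table is invented.
_ALIAS_TABLE = {
    "pro": "gemini-2.5-pro",
    "flash": "gemini-2.5-flash",
}


def resolve_model_name(model_name: str) -> str:
    # Aliases map to their canonical names; unknown names are returned unchanged.
    return _ALIAS_TABLE.get(model_name, model_name)


print(resolve_model_name("flash"))         # gemini-2.5-flash
print(resolve_model_name("custom-model"))  # custom-model (passes through)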
@@ -6,7 +6,7 @@ from typing import Optional

from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
from .shared import ModelCapabilities, ModelResponse, ProviderType
from .shared import ModelCapabilities, ProviderType


class CustomProvider(OpenAICompatibleProvider):
@@ -113,49 +113,6 @@ class CustomProvider(OpenAICompatibleProvider):

return ProviderType.CUSTOM

# ------------------------------------------------------------------
# Validation
# ------------------------------------------------------------------

# ------------------------------------------------------------------
# Request execution
# ------------------------------------------------------------------

def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using the custom API.

Args:
prompt: User prompt to send to the model
model_name: Name of the model to use
system_prompt: Optional system prompt for model behavior
temperature: Sampling temperature
max_output_tokens: Maximum tokens to generate
**kwargs: Additional provider-specific parameters

Returns:
ModelResponse with generated content and metadata
"""
# Resolve model alias to actual model name
resolved_model = self._resolve_model_name(model_name)

# Call parent method with resolved model name
return super().generate_content(
prompt=prompt,
model_name=resolved_model,
system_prompt=system_prompt,
temperature=temperature,
max_output_tokens=max_output_tokens,
**kwargs,
)

# ------------------------------------------------------------------
# Registry helpers
# ------------------------------------------------------------------

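With the override above deleted, roughly this much provider-specific code remains in the subclass; the sketch below is a hypothetical shape for illustration, not the actual remaining class.

# Hedged sketch of the shape left behind once the override is removed (names illustrative).
class _CustomProviderSketch:  # stands in for CustomProvider(OpenAICompatibleProvider)
    def get_provider_type(self) -> str:
        return "CUSTOM"

    # generate_content is intentionally absent: the base class resolves the
    # model alias and disables streaming, so no override is needed here.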
@@ -333,15 +333,17 @@ class DIALModelProvider(OpenAICompatibleProvider):
/openai/deployments/{deployment}/chat/completions

Args:
prompt: User prompt
model_name: Model name or alias
system_prompt: Optional system prompt
temperature: Sampling temperature
max_output_tokens: Maximum tokens to generate
**kwargs: Additional provider-specific parameters
prompt: The main user prompt/query to send to the model
model_name: Model name or alias (e.g., "o3", "sonnet-4.1", "gemini-2.5-pro")
system_prompt: Optional system instructions to prepend to the prompt for context/behavior
temperature: Sampling temperature for randomness (0.0=deterministic, 1.0=creative), default 0.3
Note: O3/O4 models don't support temperature and will ignore this parameter
max_output_tokens: Optional maximum number of tokens to generate in the response
images: Optional list of image paths or data URLs to include with the prompt (for vision-capable models)
**kwargs: Additional OpenAI-compatible parameters (top_p, frequency_penalty, presence_penalty, seed, stop)

Returns:
ModelResponse with generated content and metadata
ModelResponse: Contains the generated content, token usage stats, model metadata, and finish reason
"""
# Validate model name against allow-list
if not self.validate_model_name(model_name):
@@ -381,6 +383,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
completion_params = {
"model": resolved_model,
"messages": messages,
"stream": False,
}

# Determine temperature support from capabilities
@@ -397,7 +400,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
# Add additional parameters
for key, value in kwargs.items():
if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
if not supports_temperature and key in ["top_p", "frequency_penalty", "presence_penalty"]:
if not supports_temperature and key in ["top_p", "frequency_penalty", "presence_penalty", "stream"]:
continue
completion_params[key] = value

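A runnable sketch of the parameter filtering this hunk changes: "stream" joins the keys that are dropped for models without temperature support. The helper and its inputs are illustrative, not the provider's real code.

# Hedged sketch of the kwargs filtering changed above; values are illustrative.
ALLOWED_EXTRA_KEYS = ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]
SAMPLING_ONLY_KEYS = ["top_p", "frequency_penalty", "presence_penalty", "stream"]


def filter_completion_kwargs(kwargs: dict, supports_temperature: bool) -> dict:
    completion_params = {}
    for key, value in kwargs.items():
        if key in ALLOWED_EXTRA_KEYS:
            # Models without temperature support (e.g. reasoning models) also reject
            # the sampling knobs, and after this change "stream" is skipped for them too.
            if not supports_temperature and key in SAMPLING_ONLY_KEYS:
                continue
            completion_params[key] = value
    return completion_params


print(filter_completion_kwargs({"top_p": 0.9, "seed": 7, "stream": True}, supports_temperature=False))
# -> {'seed': 7}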
@@ -437,9 +440,9 @@ class DIALModelProvider(OpenAICompatibleProvider):
except Exception as exc:
attempts = max(attempt_counter["value"], 1)
if attempts == 1:
raise ValueError(f"DIAL API error for model {model_name}: {exc}") from exc
raise ValueError(f"DIAL API error for model {resolved_model}: {exc}") from exc

raise ValueError(f"DIAL API error for model {model_name} after {attempts} attempts: {exc}") from exc
raise ValueError(f"DIAL API error for model {resolved_model} after {attempts} attempts: {exc}") from exc

def close(self) -> None:
"""Clean up HTTP clients when provider is closed."""

@@ -172,12 +172,28 @@ class GeminiModelProvider(ModelProvider):
images: Optional[list[str]] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using Gemini model."""
"""
Generate content using Gemini model.

Args:
prompt: The main user prompt/query to send to the model
model_name: Canonical model name or its alias (e.g., "gemini-2.5-pro", "flash", "pro")
system_prompt: Optional system instructions to prepend to the prompt for context/behavior
temperature: Controls randomness in generation (0.0=deterministic, 1.0=creative), default 0.3
max_output_tokens: Optional maximum number of tokens to generate in the response
thinking_mode: Thinking budget level for models that support it ("minimal", "low", "medium", "high", "max"), default "medium"
images: Optional list of image paths or data URLs to include with the prompt (for vision models)
**kwargs: Additional keyword arguments (reserved for future use)

Returns:
ModelResponse: Contains the generated content, token usage stats, model metadata, and safety information
"""
# Validate parameters and fetch capabilities
resolved_name = self._resolve_model_name(model_name)
self.validate_parameters(model_name, temperature)
capabilities = self.get_capabilities(model_name)

resolved_model_name = self._resolve_model_name(model_name)

# Prepare content parts (text and potentially images)
parts = []

@@ -201,7 +217,7 @@ class GeminiModelProvider(ModelProvider):
# Continue with other images and text
continue
elif images and not capabilities.supports_images:
logger.warning(f"Model {resolved_name} does not support images, ignoring {len(images)} image(s)")
logger.warning(f"Model {resolved_model_name} does not support images, ignoring {len(images)} image(s)")

# Create contents structure
contents = [{"parts": parts}]
@@ -219,7 +235,7 @@ class GeminiModelProvider(ModelProvider):
# Add thinking configuration for models that support it
if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
# Get model's max thinking tokens and calculate actual budget
model_config = self.MODEL_CAPABILITIES.get(resolved_name)
model_config = self.MODEL_CAPABILITIES.get(resolved_model_name)
if model_config and model_config.max_thinking_tokens > 0:
max_thinking_tokens = model_config.max_thinking_tokens
actual_thinking_budget = int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
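A sketch of the thinking-budget arithmetic above; the mode-to-fraction table is invented, only the formula (a fraction of the model's max thinking tokens) mirrors the diff.

# Hedged sketch of the thinking-budget calculation; fractions are illustrative only.
THINKING_BUDGETS = {"minimal": 0.05, "low": 0.25, "medium": 0.5, "high": 0.8, "max": 1.0}


def thinking_budget(max_thinking_tokens: int, thinking_mode: str = "medium") -> int:
    if thinking_mode not in THINKING_BUDGETS or max_thinking_tokens <= 0:
        return 0  # model has no extended-thinking support, or unknown mode
    return int(max_thinking_tokens * THINKING_BUDGETS[thinking_mode])


print(thinking_budget(32768, "high"))  # 26214 with the illustrative fractions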
@@ -233,7 +249,7 @@ class GeminiModelProvider(ModelProvider):
def _attempt() -> ModelResponse:
attempt_counter["value"] += 1
response = self.client.models.generate_content(
model=resolved_name,
model=resolved_model_name,
contents=contents,
config=generation_config,
)
@@ -308,7 +324,7 @@ class GeminiModelProvider(ModelProvider):
return ModelResponse(
content=response.text,
usage=usage,
model_name=resolved_name,
model_name=resolved_model_name,
friendly_name="Gemini",
provider=ProviderType.GOOGLE,
metadata={
@@ -324,12 +340,12 @@ class GeminiModelProvider(ModelProvider):
operation=_attempt,
max_attempts=max_retries,
delays=retry_delays,
log_prefix=f"Gemini API ({resolved_name})",
log_prefix=f"Gemini API ({resolved_model_name})",
)
except Exception as exc:
attempts = max(attempt_counter["value"], 1)
error_msg = (
f"Gemini API error for model {resolved_name} after {attempts} attempt"
f"Gemini API error for model {resolved_model_name} after {attempts} attempt"
f"{'s' if attempts > 1 else ''}: {exc}"
)
raise RuntimeError(error_msg) from exc

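A simplified sketch of the retry and error-reporting pattern in this hunk: a mutable counter is incremented inside the attempt and reported in the final error. run_with_retries stands in for the project's retry helper and is not its real signature.

# Hedged sketch of the attempt-counting/retry pattern shown above.
import time


def run_with_retries(operation, max_attempts=3, delays=(0.5, 1.0)):
    last_exc = None
    for i in range(max_attempts):
        try:
            return operation()
        except Exception as exc:  # sketch only; real code is more selective
            last_exc = exc
            if i < max_attempts - 1:
                time.sleep(delays[min(i, len(delays) - 1)])
    raise last_exc


attempt_counter = {"value": 0}


def _attempt():
    attempt_counter["value"] += 1
    raise ConnectionError("transient upstream failure")  # simulated API error


try:
    run_with_retries(_attempt, max_attempts=2, delays=(0.01,))
except Exception as exc:
    attempts = max(attempt_counter["value"], 1)
    print(f"Gemini API error for model gemini-2.5-pro after {attempts} attempt{'s' if attempts > 1 else ''}: {exc}")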
@@ -462,10 +462,11 @@ class OpenAICompatibleProvider(ModelProvider):

Args:
prompt: User prompt to send to the model
model_name: Name of the model to use
model_name: Canonical model name or its alias
system_prompt: Optional system prompt for model behavior
temperature: Sampling temperature
max_output_tokens: Maximum tokens to generate
images: Optional list of image paths or data URLs to include with the prompt (for vision models)
**kwargs: Additional provider-specific parameters

Returns:
@@ -497,6 +498,9 @@ class OpenAICompatibleProvider(ModelProvider):
# Validate parameters with the effective temperature
self.validate_parameters(model_name, effective_temperature)

# Resolve to canonical model name
resolved_model = self._resolve_model_name(model_name)

# Prepare messages
messages = []
if system_prompt:
@@ -518,7 +522,7 @@ class OpenAICompatibleProvider(ModelProvider):
# Continue with other images and text
continue
elif images and (not capabilities or not capabilities.supports_images):
logging.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)")
logging.warning(f"Model {resolved_model} does not support images, ignoring {len(images)} image(s)")

# Add user message
if len(user_content) == 1:
@@ -529,14 +533,14 @@ class OpenAICompatibleProvider(ModelProvider):
messages.append({"role": "user", "content": user_content})

# Prepare completion parameters
# Always disable streaming for OpenRouter
# MCP doesn't use streaming, and this avoids issues with O3 model access
completion_params = {
"model": model_name,
"model": resolved_model,
"messages": messages,
"stream": False,
}

# Check model capabilities once to determine parameter support
resolved_model = self._resolve_model_name(model_name)

# Use the effective temperature we calculated earlier
supports_sampling = effective_temperature is not None

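A sketch of the supports_sampling gate referenced here: an effective temperature of None means no sampling parameters are sent at all. The reasoning-model set is illustrative; real capability data comes from the provider registry.

# Hedged sketch of the effective-temperature/supports_sampling gate.
from typing import Optional

REASONING_MODELS = {"o3", "o4-mini"}  # illustrative; real capabilities come from the registry


def effective_temperature(model_name: str, requested: float) -> Optional[float]:
    # None signals "this model takes no sampling parameters at all".
    return None if model_name in REASONING_MODELS else requested


temp = effective_temperature("o3", 0.3)
supports_sampling = temp is not None
print(supports_sampling)  # False: temperature, top_p, etc. will be omitted from the request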
@@ -553,7 +557,7 @@ class OpenAICompatibleProvider(ModelProvider):
for key, value in kwargs.items():
if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
# Reasoning models (those that don't support temperature) also don't support these parameters
if not supports_sampling and key in ["top_p", "frequency_penalty", "presence_penalty"]:
if not supports_sampling and key in ["top_p", "frequency_penalty", "presence_penalty", "stream"]:
continue # Skip unsupported parameters for reasoning models
completion_params[key] = value

@@ -585,7 +589,7 @@ class OpenAICompatibleProvider(ModelProvider):
return ModelResponse(
content=content,
usage=usage,
model_name=model_name,
model_name=resolved_model,
friendly_name=self.FRIENDLY_NAME,
provider=self.get_provider_type(),
metadata={
@@ -601,12 +605,12 @@ class OpenAICompatibleProvider(ModelProvider):
operation=_attempt,
max_attempts=max_retries,
delays=retry_delays,
log_prefix=f"{self.FRIENDLY_NAME} API ({model_name})",
log_prefix=f"{self.FRIENDLY_NAME} API ({resolved_model})",
)
except Exception as exc:
attempts = max(attempt_counter["value"], 1)
error_msg = (
f"{self.FRIENDLY_NAME} API error for model {model_name} after {attempts} attempt"
f"{self.FRIENDLY_NAME} API error for model {resolved_model} after {attempts} attempt"
f"{'s' if attempts > 1 else ''}: {exc}"
)
logging.error(error_msg)
@@ -618,7 +622,7 @@ class OpenAICompatibleProvider(ModelProvider):
For proxy providers, this may use generic capabilities.

Args:
model_name: Model to validate for
model_name: Canonical model name or its alias
temperature: Temperature to validate
**kwargs: Additional parameters to validate
"""

@@ -7,7 +7,7 @@ if TYPE_CHECKING:
from tools.models import ToolModelCategory

from .openai_compatible import OpenAICompatibleProvider
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
from .shared import ModelCapabilities, ProviderType, TemperatureConstraint

logger = logging.getLogger(__name__)

@@ -253,33 +253,6 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
"""Get the provider type."""
return ProviderType.OPENAI

# ------------------------------------------------------------------
# Request execution
# ------------------------------------------------------------------

def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using OpenAI API with proper model name resolution."""
# Resolve model alias before making API call
resolved_model_name = self._resolve_model_name(model_name)

# Call parent implementation with resolved model name
return super().generate_content(
prompt=prompt,
model_name=resolved_model_name,
system_prompt=system_prompt,
temperature=temperature,
max_output_tokens=max_output_tokens,
**kwargs,
)

# ------------------------------------------------------------------
# Provider preferences
# ------------------------------------------------------------------

@@ -8,7 +8,6 @@ from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
from .shared import (
ModelCapabilities,
ModelResponse,
ProviderType,
RangeTemperatureConstraint,
)
@@ -111,50 +110,6 @@ class OpenRouterProvider(OpenAICompatibleProvider):
"""Identify this provider for restrictions and logging."""
return ProviderType.OPENROUTER

# ------------------------------------------------------------------
# Request execution
# ------------------------------------------------------------------

def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using the OpenRouter API.

Args:
prompt: User prompt to send to the model
model_name: Name of the model (or alias) to use
system_prompt: Optional system prompt for model behavior
temperature: Sampling temperature
max_output_tokens: Maximum tokens to generate
**kwargs: Additional provider-specific parameters

Returns:
ModelResponse with generated content and metadata
"""
# Resolve model alias to actual OpenRouter model name
resolved_model = self._resolve_model_name(model_name)

# Always disable streaming for OpenRouter
# MCP doesn't use streaming, and this avoids issues with O3 model access
if "stream" not in kwargs:
kwargs["stream"] = False

# Call parent method with resolved model name
return super().generate_content(
prompt=prompt,
model_name=resolved_model,
system_prompt=system_prompt,
temperature=temperature,
max_output_tokens=max_output_tokens,
**kwargs,
)

# ------------------------------------------------------------------
# Registry helpers
# ------------------------------------------------------------------

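The override removed above only resolved the alias and defaulted "stream" to False before delegating; both now happen in OpenAICompatibleProvider when the request is built. A hedged sketch of that base-class behaviour (function and model names are illustrative, not the real code):

# Hedged sketch of the base-class request construction the removed override now relies on.
def build_completion_params(resolved_model: str, messages: list, **kwargs) -> dict:
    # Streaming is off by default: MCP doesn't use streaming, and this avoids
    # issues with O3 model access (per the comments in the base-class hunk).
    params = {"model": resolved_model, "messages": messages, "stream": False}
    for key, value in kwargs.items():
        # Only a small allow-list of OpenAI-compatible keys is forwarded.
        if key in ("top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"):
            params[key] = value
    return params


print(build_completion_params("openai/o3", [{"role": "user", "content": "hi"}]))
# -> {'model': 'openai/o3', 'messages': [...], 'stream': False}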
@@ -7,7 +7,7 @@ if TYPE_CHECKING:
from tools.models import ToolModelCategory

from .openai_compatible import OpenAICompatibleProvider
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
from .shared import ModelCapabilities, ProviderType, TemperatureConstraint

logger = logging.getLogger(__name__)

@@ -92,29 +92,6 @@ class XAIModelProvider(OpenAICompatibleProvider):
"""Get the provider type."""
return ProviderType.XAI

def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using X.AI API with proper model name resolution."""
# Resolve model alias before making API call
resolved_model_name = self._resolve_model_name(model_name)

# Call parent implementation with resolved model name
return super().generate_content(
prompt=prompt,
model_name=resolved_model_name,
system_prompt=system_prompt,
temperature=temperature,
max_output_tokens=max_output_tokens,
**kwargs,
)

def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
"""Get XAI's preferred model for a given category from allowed models.