refactor: moved registries into a separate module and code cleanup
fix: refactored dial provider to follow the same pattern
This commit is contained in:
169
conf/dial_models.json
Normal file
169
conf/dial_models.json
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
{
|
||||||
|
"_README": {
|
||||||
|
"description": "Model metadata for the DIAL (Data & AI Layer) aggregation provider.",
|
||||||
|
"documentation": "https://github.com/BeehiveInnovations/zen-mcp-server/blob/main/docs/configuration.md",
|
||||||
|
"usage": "Models listed here are exposed through the DIAL provider. Aliases are case-insensitive.",
|
||||||
|
"field_notes": "Matches providers/shared/model_capabilities.py.",
|
||||||
|
"field_descriptions": {
|
||||||
|
"model_name": "The model identifier as exposed by DIAL (typically deployment name)",
|
||||||
|
"aliases": "Array of shorthand names users can type instead of the full model name",
|
||||||
|
"context_window": "Total number of tokens the model can process (input + output combined)",
|
||||||
|
"max_output_tokens": "Maximum number of tokens the model can generate in a single response",
|
||||||
|
"supports_extended_thinking": "Whether the model supports extended reasoning tokens",
|
||||||
|
"supports_json_mode": "Whether the model can guarantee valid JSON output",
|
||||||
|
"supports_function_calling": "Whether the model supports function/tool calling",
|
||||||
|
"supports_images": "Whether the model can process images/visual input",
|
||||||
|
"max_image_size_mb": "Maximum total size in MB for all images combined",
|
||||||
|
"supports_temperature": "Whether the model accepts the temperature parameter",
|
||||||
|
"temperature_constraint": "Temperature constraint hint: 'fixed', 'range', or 'discrete'",
|
||||||
|
"description": "Human-readable description of the model",
|
||||||
|
"intelligence_score": "1-20 human rating used as the primary signal for auto-mode ordering"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"models": [
|
||||||
|
{
|
||||||
|
"model_name": "o3-2025-04-16",
|
||||||
|
"friendly_name": "DIAL (O3)",
|
||||||
|
"aliases": ["o3"],
|
||||||
|
"intelligence_score": 14,
|
||||||
|
"description": "OpenAI O3 via DIAL - Strong reasoning model",
|
||||||
|
"context_window": 200000,
|
||||||
|
"max_output_tokens": 100000,
|
||||||
|
"supports_extended_thinking": false,
|
||||||
|
"supports_function_calling": false,
|
||||||
|
"supports_json_mode": true,
|
||||||
|
"supports_images": true,
|
||||||
|
"max_image_size_mb": 20.0,
|
||||||
|
"supports_temperature": false,
|
||||||
|
"temperature_constraint": "fixed"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "o4-mini-2025-04-16",
|
||||||
|
"friendly_name": "DIAL (O4-mini)",
|
||||||
|
"aliases": ["o4-mini"],
|
||||||
|
"intelligence_score": 11,
|
||||||
|
"description": "OpenAI O4-mini via DIAL - Fast reasoning model",
|
||||||
|
"context_window": 200000,
|
||||||
|
"max_output_tokens": 100000,
|
||||||
|
"supports_extended_thinking": false,
|
||||||
|
"supports_function_calling": false,
|
||||||
|
"supports_json_mode": true,
|
||||||
|
"supports_images": true,
|
||||||
|
"max_image_size_mb": 20.0,
|
||||||
|
"supports_temperature": false,
|
||||||
|
"temperature_constraint": "fixed"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "anthropic.claude-sonnet-4.1-20250805-v1:0",
|
||||||
|
"friendly_name": "DIAL (Sonnet 4.1)",
|
||||||
|
"aliases": ["sonnet-4.1", "sonnet-4"],
|
||||||
|
"intelligence_score": 10,
|
||||||
|
"description": "Claude Sonnet 4.1 via DIAL - Balanced performance",
|
||||||
|
"context_window": 200000,
|
||||||
|
"max_output_tokens": 64000,
|
||||||
|
"supports_extended_thinking": false,
|
||||||
|
"supports_function_calling": false,
|
||||||
|
"supports_json_mode": true,
|
||||||
|
"supports_images": true,
|
||||||
|
"max_image_size_mb": 5.0,
|
||||||
|
"supports_temperature": true,
|
||||||
|
"temperature_constraint": "range"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "anthropic.claude-sonnet-4.1-20250805-v1:0-with-thinking",
|
||||||
|
"friendly_name": "DIAL (Sonnet 4.1 Thinking)",
|
||||||
|
"aliases": ["sonnet-4.1-thinking", "sonnet-4-thinking"],
|
||||||
|
"intelligence_score": 11,
|
||||||
|
"description": "Claude Sonnet 4.1 with thinking mode via DIAL",
|
||||||
|
"context_window": 200000,
|
||||||
|
"max_output_tokens": 64000,
|
||||||
|
"supports_extended_thinking": true,
|
||||||
|
"supports_function_calling": false,
|
||||||
|
"supports_json_mode": true,
|
||||||
|
"supports_images": true,
|
||||||
|
"max_image_size_mb": 5.0,
|
||||||
|
"supports_temperature": true,
|
||||||
|
"temperature_constraint": "range"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "anthropic.claude-opus-4.1-20250805-v1:0",
|
||||||
|
"friendly_name": "DIAL (Opus 4.1)",
|
||||||
|
"aliases": ["opus-4.1", "opus-4"],
|
||||||
|
"intelligence_score": 14,
|
||||||
|
"description": "Claude Opus 4.1 via DIAL - Most capable Claude model",
|
||||||
|
"context_window": 200000,
|
||||||
|
"max_output_tokens": 64000,
|
||||||
|
"supports_extended_thinking": false,
|
||||||
|
"supports_function_calling": false,
|
||||||
|
"supports_json_mode": true,
|
||||||
|
"supports_images": true,
|
||||||
|
"max_image_size_mb": 5.0,
|
||||||
|
"supports_temperature": true,
|
||||||
|
"temperature_constraint": "range"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "anthropic.claude-opus-4.1-20250805-v1:0-with-thinking",
|
||||||
|
"friendly_name": "DIAL (Opus 4.1 Thinking)",
|
||||||
|
"aliases": ["opus-4.1-thinking", "opus-4-thinking"],
|
||||||
|
"intelligence_score": 15,
|
||||||
|
"description": "Claude Opus 4.1 with thinking mode via DIAL",
|
||||||
|
"context_window": 200000,
|
||||||
|
"max_output_tokens": 64000,
|
||||||
|
"supports_extended_thinking": true,
|
||||||
|
"supports_function_calling": false,
|
||||||
|
"supports_json_mode": true,
|
||||||
|
"supports_images": true,
|
||||||
|
"max_image_size_mb": 5.0,
|
||||||
|
"supports_temperature": true,
|
||||||
|
"temperature_constraint": "range"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "gemini-2.5-pro-preview-03-25-google-search",
|
||||||
|
"friendly_name": "DIAL (Gemini 2.5 Pro Search)",
|
||||||
|
"aliases": ["gemini-2.5-pro-search"],
|
||||||
|
"intelligence_score": 17,
|
||||||
|
"description": "Gemini 2.5 Pro with Google Search via DIAL",
|
||||||
|
"context_window": 1000000,
|
||||||
|
"max_output_tokens": 65536,
|
||||||
|
"supports_extended_thinking": false,
|
||||||
|
"supports_function_calling": false,
|
||||||
|
"supports_json_mode": true,
|
||||||
|
"supports_images": true,
|
||||||
|
"max_image_size_mb": 20.0,
|
||||||
|
"supports_temperature": true,
|
||||||
|
"temperature_constraint": "range"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "gemini-2.5-pro-preview-05-06",
|
||||||
|
"friendly_name": "DIAL (Gemini 2.5 Pro)",
|
||||||
|
"aliases": ["gemini-2.5-pro"],
|
||||||
|
"intelligence_score": 18,
|
||||||
|
"description": "Gemini 2.5 Pro via DIAL - Deep reasoning",
|
||||||
|
"context_window": 1000000,
|
||||||
|
"max_output_tokens": 65536,
|
||||||
|
"supports_extended_thinking": false,
|
||||||
|
"supports_function_calling": false,
|
||||||
|
"supports_json_mode": true,
|
||||||
|
"supports_images": true,
|
||||||
|
"max_image_size_mb": 20.0,
|
||||||
|
"supports_temperature": true,
|
||||||
|
"temperature_constraint": "range"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "gemini-2.5-flash-preview-05-20",
|
||||||
|
"friendly_name": "DIAL (Gemini Flash 2.5)",
|
||||||
|
"aliases": ["gemini-2.5-flash"],
|
||||||
|
"intelligence_score": 10,
|
||||||
|
"description": "Gemini 2.5 Flash via DIAL - Ultra-fast",
|
||||||
|
"context_window": 1000000,
|
||||||
|
"max_output_tokens": 65536,
|
||||||
|
"supports_extended_thinking": false,
|
||||||
|
"supports_function_calling": false,
|
||||||
|
"supports_json_mode": true,
|
||||||
|
"supports_images": true,
|
||||||
|
"max_image_size_mb": 20.0,
|
||||||
|
"supports_temperature": true,
|
||||||
|
"temperature_constraint": "range"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -53,7 +53,7 @@
|
|||||||
"gpt5-pro"
|
"gpt5-pro"
|
||||||
],
|
],
|
||||||
"intelligence_score": 18,
|
"intelligence_score": 18,
|
||||||
"description": "GPT-5 Pro (400K context, 272K output) - Advanced model with reasoning support",
|
"description": "GPT-5 Pro (400K context, 272K output) - Very advanced reasoning model",
|
||||||
"context_window": 400000,
|
"context_window": 400000,
|
||||||
"max_output_tokens": 272000,
|
"max_output_tokens": 272000,
|
||||||
"supports_extended_thinking": true,
|
"supports_extended_thinking": true,
|
||||||
@@ -156,7 +156,7 @@
|
|||||||
"o3pro"
|
"o3pro"
|
||||||
],
|
],
|
||||||
"intelligence_score": 15,
|
"intelligence_score": 15,
|
||||||
"description": "Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.",
|
"description": "Professional-grade reasoning (200K context)",
|
||||||
"context_window": 200000,
|
"context_window": 200000,
|
||||||
"max_output_tokens": 65536,
|
"max_output_tokens": 65536,
|
||||||
"supports_extended_thinking": false,
|
"supports_extended_thinking": false,
|
||||||
|
|||||||
@@ -30,7 +30,8 @@ DEFAULT_MODEL = get_env("DEFAULT_MODEL", "auto") or "auto"
|
|||||||
# Auto mode detection - when DEFAULT_MODEL is "auto", Claude picks the model
|
# Auto mode detection - when DEFAULT_MODEL is "auto", Claude picks the model
|
||||||
IS_AUTO_MODE = DEFAULT_MODEL.lower() == "auto"
|
IS_AUTO_MODE = DEFAULT_MODEL.lower() == "auto"
|
||||||
|
|
||||||
# Each provider (gemini.py, openai_provider.py, xai.py) defines its own MODEL_CAPABILITIES
|
# Each provider (gemini.py, openai.py, xai.py, dial.py, openrouter.py, custom.py, azure_openai.py)
|
||||||
|
# defines its own MODEL_CAPABILITIES
|
||||||
# with detailed descriptions. Tools use ModelProviderRegistry.get_available_model_names()
|
# with detailed descriptions. Tools use ModelProviderRegistry.get_available_model_names()
|
||||||
# to get models only from enabled providers (those with valid API keys).
|
# to get models only from enabled providers (those with valid API keys).
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ DEFAULT_MODEL=auto # Claude picks best model for each task (recommended)
|
|||||||
- `conf/gemini_models.json` – Gemini catalogue (`GEMINI_MODELS_CONFIG_PATH`)
|
- `conf/gemini_models.json` – Gemini catalogue (`GEMINI_MODELS_CONFIG_PATH`)
|
||||||
- `conf/xai_models.json` – X.AI / GROK catalogue (`XAI_MODELS_CONFIG_PATH`)
|
- `conf/xai_models.json` – X.AI / GROK catalogue (`XAI_MODELS_CONFIG_PATH`)
|
||||||
- `conf/openrouter_models.json` – OpenRouter catalogue (`OPENROUTER_MODELS_CONFIG_PATH`)
|
- `conf/openrouter_models.json` – OpenRouter catalogue (`OPENROUTER_MODELS_CONFIG_PATH`)
|
||||||
|
- `conf/dial_models.json` – DIAL aggregation catalogue (`DIAL_MODELS_CONFIG_PATH`)
|
||||||
- `conf/custom_models.json` – Custom/OpenAI-compatible endpoints (`CUSTOM_MODELS_CONFIG_PATH`)
|
- `conf/custom_models.json` – Custom/OpenAI-compatible endpoints (`CUSTOM_MODELS_CONFIG_PATH`)
|
||||||
|
|
||||||
Each JSON file documents the allowed fields via its `_README` block and controls model aliases, capability limits, and feature flags. Edit these files (or point the matching `*_MODELS_CONFIG_PATH` variable to your own copy) when you want to adjust context windows, enable JSON mode, or expose additional aliases without touching Python code.
|
Each JSON file documents the allowed fields via its `_README` block and controls model aliases, capability limits, and feature flags. Edit these files (or point the matching `*_MODELS_CONFIG_PATH` variable to your own copy) when you want to adjust context windows, enable JSON mode, or expose additional aliases without touching Python code.
|
||||||
@@ -154,6 +155,7 @@ OPENAI_MODELS_CONFIG_PATH=/path/to/openai_models.json
|
|||||||
GEMINI_MODELS_CONFIG_PATH=/path/to/gemini_models.json
|
GEMINI_MODELS_CONFIG_PATH=/path/to/gemini_models.json
|
||||||
XAI_MODELS_CONFIG_PATH=/path/to/xai_models.json
|
XAI_MODELS_CONFIG_PATH=/path/to/xai_models.json
|
||||||
OPENROUTER_MODELS_CONFIG_PATH=/path/to/openrouter_models.json
|
OPENROUTER_MODELS_CONFIG_PATH=/path/to/openrouter_models.json
|
||||||
|
DIAL_MODELS_CONFIG_PATH=/path/to/dial_models.json
|
||||||
CUSTOM_MODELS_CONFIG_PATH=/path/to/custom_models.json
|
CUSTOM_MODELS_CONFIG_PATH=/path/to/custom_models.json
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ Zen ships multiple registries:
|
|||||||
- `conf/gemini_models.json` – native Google Gemini catalogue (`GEMINI_MODELS_CONFIG_PATH`)
|
- `conf/gemini_models.json` – native Google Gemini catalogue (`GEMINI_MODELS_CONFIG_PATH`)
|
||||||
- `conf/xai_models.json` – native X.AI / GROK catalogue (`XAI_MODELS_CONFIG_PATH`)
|
- `conf/xai_models.json` – native X.AI / GROK catalogue (`XAI_MODELS_CONFIG_PATH`)
|
||||||
- `conf/openrouter_models.json` – OpenRouter catalogue (`OPENROUTER_MODELS_CONFIG_PATH`)
|
- `conf/openrouter_models.json` – OpenRouter catalogue (`OPENROUTER_MODELS_CONFIG_PATH`)
|
||||||
|
- `conf/dial_models.json` – DIAL aggregation catalogue (`DIAL_MODELS_CONFIG_PATH`)
|
||||||
- `conf/custom_models.json` – local/self-hosted OpenAI-compatible catalogue (`CUSTOM_MODELS_CONFIG_PATH`)
|
- `conf/custom_models.json` – local/self-hosted OpenAI-compatible catalogue (`CUSTOM_MODELS_CONFIG_PATH`)
|
||||||
|
|
||||||
Copy whichever file you need into your project (or point the corresponding `*_MODELS_CONFIG_PATH` env var at your own copy) and edit it to advertise the models you want.
|
Copy whichever file you need into your project (or point the corresponding `*_MODELS_CONFIG_PATH` env var at your own copy) and edit it to advertise the models you want.
|
||||||
@@ -71,7 +72,7 @@ Consult the JSON file for the full list, aliases, and capability flags. Add new
|
|||||||
|
|
||||||
View the baseline OpenRouter catalogue in [`conf/openrouter_models.json`](conf/openrouter_models.json) and populate [`conf/custom_models.json`](conf/custom_models.json) with your local models.
|
View the baseline OpenRouter catalogue in [`conf/openrouter_models.json`](conf/openrouter_models.json) and populate [`conf/custom_models.json`](conf/custom_models.json) with your local models.
|
||||||
|
|
||||||
Native catalogues (`conf/openai_models.json`, `conf/gemini_models.json`, `conf/xai_models.json`) follow the same schema. Updating those files lets you:
|
Native catalogues (`conf/openai_models.json`, `conf/gemini_models.json`, `conf/xai_models.json`, `conf/dial_models.json`) follow the same schema. Updating those files lets you:
|
||||||
|
|
||||||
- Expose new aliases (e.g., map `enterprise-pro` to `gpt-5-pro`)
|
- Expose new aliases (e.g., map `enterprise-pro` to `gpt-5-pro`)
|
||||||
- Advertise support for JSON mode or vision if the upstream provider adds it
|
- Advertise support for JSON mode or vision if the upstream provider adds it
|
||||||
|
|||||||
@@ -3,8 +3,8 @@
|
|||||||
from .azure_openai import AzureOpenAIProvider
|
from .azure_openai import AzureOpenAIProvider
|
||||||
from .base import ModelProvider
|
from .base import ModelProvider
|
||||||
from .gemini import GeminiModelProvider
|
from .gemini import GeminiModelProvider
|
||||||
|
from .openai import OpenAIModelProvider
|
||||||
from .openai_compatible import OpenAICompatibleProvider
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
from .openai_provider import OpenAIModelProvider
|
|
||||||
from .openrouter import OpenRouterProvider
|
from .openrouter import OpenRouterProvider
|
||||||
from .registry import ModelProviderRegistry
|
from .registry import ModelProviderRegistry
|
||||||
from .shared import ModelCapabilities, ModelResponse
|
from .shared import ModelCapabilities, ModelResponse
|
||||||
|
|||||||
@@ -12,9 +12,9 @@ except ImportError: # pragma: no cover
|
|||||||
|
|
||||||
from utils.env import get_env, suppress_env_vars
|
from utils.env import get_env, suppress_env_vars
|
||||||
|
|
||||||
from .azure_registry import AzureModelRegistry
|
from .openai import OpenAIModelProvider
|
||||||
from .openai_compatible import OpenAICompatibleProvider
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
from .openai_provider import OpenAIModelProvider
|
from .registries.azure import AzureModelRegistry
|
||||||
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
|
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -4,11 +4,12 @@ import logging
|
|||||||
|
|
||||||
from utils.env import get_env
|
from utils.env import get_env
|
||||||
|
|
||||||
from .custom_registry import CustomEndpointModelRegistry
|
|
||||||
from .openai_compatible import OpenAICompatibleProvider
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
from .openrouter_registry import OpenRouterModelRegistry
|
from .registries.custom import CustomEndpointModelRegistry
|
||||||
|
from .registries.openrouter import OpenRouterModelRegistry
|
||||||
from .shared import ModelCapabilities, ProviderType
|
from .shared import ModelCapabilities, ProviderType
|
||||||
|
|
||||||
|
|
||||||
class CustomProvider(OpenAICompatibleProvider):
|
class CustomProvider(OpenAICompatibleProvider):
|
||||||
"""Adapter for self-hosted or local OpenAI-compatible endpoints.
|
"""Adapter for self-hosted or local OpenAI-compatible endpoints.
|
||||||
|
|
||||||
|
|||||||
@@ -2,17 +2,19 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from typing import Optional
|
from typing import ClassVar, Optional
|
||||||
|
|
||||||
from utils.env import get_env
|
from utils.env import get_env
|
||||||
|
|
||||||
from .openai_compatible import OpenAICompatibleProvider
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
|
from .registries.dial import DialModelRegistry
|
||||||
|
from .registry_provider_mixin import RegistryBackedProviderMixin
|
||||||
|
from .shared import ModelCapabilities, ModelResponse, ProviderType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DIALModelProvider(OpenAICompatibleProvider):
|
class DIALModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider):
|
||||||
"""Client for the DIAL (Data & AI Layer) aggregation service.
|
"""Client for the DIAL (Data & AI Layer) aggregation service.
|
||||||
|
|
||||||
DIAL exposes several third-party models behind a single OpenAI-compatible
|
DIAL exposes several third-party models behind a single OpenAI-compatible
|
||||||
@@ -23,185 +25,13 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
|||||||
|
|
||||||
FRIENDLY_NAME = "DIAL"
|
FRIENDLY_NAME = "DIAL"
|
||||||
|
|
||||||
|
REGISTRY_CLASS = DialModelRegistry
|
||||||
|
MODEL_CAPABILITIES: ClassVar[dict[str, ModelCapabilities]] = {}
|
||||||
|
|
||||||
# Retry configuration for API calls
|
# Retry configuration for API calls
|
||||||
MAX_RETRIES = 4
|
MAX_RETRIES = 4
|
||||||
RETRY_DELAYS = [1, 3, 5, 8] # seconds
|
RETRY_DELAYS = [1, 3, 5, 8] # seconds
|
||||||
|
|
||||||
# Model configurations using ModelCapabilities objects
|
|
||||||
MODEL_CAPABILITIES = {
|
|
||||||
"o3-2025-04-16": ModelCapabilities(
|
|
||||||
provider=ProviderType.DIAL,
|
|
||||||
model_name="o3-2025-04-16",
|
|
||||||
friendly_name="DIAL (O3)",
|
|
||||||
intelligence_score=14,
|
|
||||||
context_window=200_000,
|
|
||||||
max_output_tokens=100_000,
|
|
||||||
supports_extended_thinking=False,
|
|
||||||
supports_system_prompts=True,
|
|
||||||
supports_streaming=True,
|
|
||||||
supports_function_calling=False, # DIAL may not expose function calling
|
|
||||||
supports_json_mode=True,
|
|
||||||
supports_images=True,
|
|
||||||
max_image_size_mb=20.0,
|
|
||||||
supports_temperature=False, # O3 models don't accept temperature
|
|
||||||
temperature_constraint=TemperatureConstraint.create("fixed"),
|
|
||||||
description="OpenAI O3 via DIAL - Strong reasoning model",
|
|
||||||
aliases=["o3"],
|
|
||||||
),
|
|
||||||
"o4-mini-2025-04-16": ModelCapabilities(
|
|
||||||
provider=ProviderType.DIAL,
|
|
||||||
model_name="o4-mini-2025-04-16",
|
|
||||||
friendly_name="DIAL (O4-mini)",
|
|
||||||
intelligence_score=11,
|
|
||||||
context_window=200_000,
|
|
||||||
max_output_tokens=100_000,
|
|
||||||
supports_extended_thinking=False,
|
|
||||||
supports_system_prompts=True,
|
|
||||||
supports_streaming=True,
|
|
||||||
supports_function_calling=False, # DIAL may not expose function calling
|
|
||||||
supports_json_mode=True,
|
|
||||||
supports_images=True,
|
|
||||||
max_image_size_mb=20.0,
|
|
||||||
supports_temperature=False, # O4 models don't accept temperature
|
|
||||||
temperature_constraint=TemperatureConstraint.create("fixed"),
|
|
||||||
description="OpenAI O4-mini via DIAL - Fast reasoning model",
|
|
||||||
aliases=["o4-mini"],
|
|
||||||
),
|
|
||||||
"anthropic.claude-sonnet-4.1-20250805-v1:0": ModelCapabilities(
|
|
||||||
provider=ProviderType.DIAL,
|
|
||||||
model_name="anthropic.claude-sonnet-4.1-20250805-v1:0",
|
|
||||||
friendly_name="DIAL (Sonnet 4.1)",
|
|
||||||
intelligence_score=10,
|
|
||||||
context_window=200_000,
|
|
||||||
max_output_tokens=64_000,
|
|
||||||
supports_extended_thinking=False,
|
|
||||||
supports_system_prompts=True,
|
|
||||||
supports_streaming=True,
|
|
||||||
supports_function_calling=False,
|
|
||||||
supports_json_mode=True,
|
|
||||||
supports_images=True,
|
|
||||||
max_image_size_mb=5.0,
|
|
||||||
supports_temperature=True,
|
|
||||||
temperature_constraint=TemperatureConstraint.create("range"),
|
|
||||||
description="Claude Sonnet 4.1 via DIAL - Balanced performance",
|
|
||||||
aliases=["sonnet-4.1", "sonnet-4"],
|
|
||||||
),
|
|
||||||
"anthropic.claude-sonnet-4.1-20250805-v1:0-with-thinking": ModelCapabilities(
|
|
||||||
provider=ProviderType.DIAL,
|
|
||||||
model_name="anthropic.claude-sonnet-4.1-20250805-v1:0-with-thinking",
|
|
||||||
friendly_name="DIAL (Sonnet 4.1 Thinking)",
|
|
||||||
intelligence_score=11,
|
|
||||||
context_window=200_000,
|
|
||||||
max_output_tokens=64_000,
|
|
||||||
supports_extended_thinking=True,
|
|
||||||
supports_system_prompts=True,
|
|
||||||
supports_streaming=True,
|
|
||||||
supports_function_calling=False,
|
|
||||||
supports_json_mode=True,
|
|
||||||
supports_images=True,
|
|
||||||
max_image_size_mb=5.0,
|
|
||||||
supports_temperature=True,
|
|
||||||
temperature_constraint=TemperatureConstraint.create("range"),
|
|
||||||
description="Claude Sonnet 4.1 with thinking mode via DIAL",
|
|
||||||
aliases=["sonnet-4.1-thinking", "sonnet-4-thinking"],
|
|
||||||
),
|
|
||||||
"anthropic.claude-opus-4.1-20250805-v1:0": ModelCapabilities(
|
|
||||||
provider=ProviderType.DIAL,
|
|
||||||
model_name="anthropic.claude-opus-4.1-20250805-v1:0",
|
|
||||||
friendly_name="DIAL (Opus 4.1)",
|
|
||||||
intelligence_score=14,
|
|
||||||
context_window=200_000,
|
|
||||||
max_output_tokens=64_000,
|
|
||||||
supports_extended_thinking=False,
|
|
||||||
supports_system_prompts=True,
|
|
||||||
supports_streaming=True,
|
|
||||||
supports_function_calling=False,
|
|
||||||
supports_json_mode=True,
|
|
||||||
supports_images=True,
|
|
||||||
max_image_size_mb=5.0,
|
|
||||||
supports_temperature=True,
|
|
||||||
temperature_constraint=TemperatureConstraint.create("range"),
|
|
||||||
description="Claude Opus 4.1 via DIAL - Most capable Claude model",
|
|
||||||
aliases=["opus-4.1", "opus-4"],
|
|
||||||
),
|
|
||||||
"anthropic.claude-opus-4.1-20250805-v1:0-with-thinking": ModelCapabilities(
|
|
||||||
provider=ProviderType.DIAL,
|
|
||||||
model_name="anthropic.claude-opus-4.1-20250805-v1:0-with-thinking",
|
|
||||||
friendly_name="DIAL (Opus 4.1 Thinking)",
|
|
||||||
intelligence_score=15,
|
|
||||||
context_window=200_000,
|
|
||||||
max_output_tokens=64_000,
|
|
||||||
supports_extended_thinking=True,
|
|
||||||
supports_system_prompts=True,
|
|
||||||
supports_streaming=True,
|
|
||||||
supports_function_calling=False,
|
|
||||||
supports_json_mode=True,
|
|
||||||
supports_images=True,
|
|
||||||
max_image_size_mb=5.0,
|
|
||||||
supports_temperature=True,
|
|
||||||
temperature_constraint=TemperatureConstraint.create("range"),
|
|
||||||
description="Claude Opus 4.1 with thinking mode via DIAL",
|
|
||||||
aliases=["opus-4.1-thinking", "opus-4-thinking"],
|
|
||||||
),
|
|
||||||
"gemini-2.5-pro-preview-03-25-google-search": ModelCapabilities(
|
|
||||||
provider=ProviderType.DIAL,
|
|
||||||
model_name="gemini-2.5-pro-preview-03-25-google-search",
|
|
||||||
friendly_name="DIAL (Gemini 2.5 Pro Search)",
|
|
||||||
intelligence_score=17,
|
|
||||||
context_window=1_000_000,
|
|
||||||
max_output_tokens=65_536,
|
|
||||||
supports_extended_thinking=False,
|
|
||||||
supports_system_prompts=True,
|
|
||||||
supports_streaming=True,
|
|
||||||
supports_function_calling=False,
|
|
||||||
supports_json_mode=True,
|
|
||||||
supports_images=True,
|
|
||||||
max_image_size_mb=20.0,
|
|
||||||
supports_temperature=True,
|
|
||||||
temperature_constraint=TemperatureConstraint.create("range"),
|
|
||||||
description="Gemini 2.5 Pro with Google Search via DIAL",
|
|
||||||
aliases=["gemini-2.5-pro-search"],
|
|
||||||
),
|
|
||||||
"gemini-2.5-pro-preview-05-06": ModelCapabilities(
|
|
||||||
provider=ProviderType.DIAL,
|
|
||||||
model_name="gemini-2.5-pro-preview-05-06",
|
|
||||||
friendly_name="DIAL (Gemini 2.5 Pro)",
|
|
||||||
intelligence_score=18,
|
|
||||||
context_window=1_000_000,
|
|
||||||
max_output_tokens=65_536,
|
|
||||||
supports_extended_thinking=False,
|
|
||||||
supports_system_prompts=True,
|
|
||||||
supports_streaming=True,
|
|
||||||
supports_function_calling=False,
|
|
||||||
supports_json_mode=True,
|
|
||||||
supports_images=True,
|
|
||||||
max_image_size_mb=20.0,
|
|
||||||
supports_temperature=True,
|
|
||||||
temperature_constraint=TemperatureConstraint.create("range"),
|
|
||||||
description="Gemini 2.5 Pro via DIAL - Deep reasoning",
|
|
||||||
aliases=["gemini-2.5-pro"],
|
|
||||||
),
|
|
||||||
"gemini-2.5-flash-preview-05-20": ModelCapabilities(
|
|
||||||
provider=ProviderType.DIAL,
|
|
||||||
model_name="gemini-2.5-flash-preview-05-20",
|
|
||||||
friendly_name="DIAL (Gemini Flash 2.5)",
|
|
||||||
intelligence_score=10,
|
|
||||||
context_window=1_000_000,
|
|
||||||
max_output_tokens=65_536,
|
|
||||||
supports_extended_thinking=False,
|
|
||||||
supports_system_prompts=True,
|
|
||||||
supports_streaming=True,
|
|
||||||
supports_function_calling=False,
|
|
||||||
supports_json_mode=True,
|
|
||||||
supports_images=True,
|
|
||||||
max_image_size_mb=20.0,
|
|
||||||
supports_temperature=True,
|
|
||||||
temperature_constraint=TemperatureConstraint.create("range"),
|
|
||||||
description="Gemini 2.5 Flash via DIAL - Ultra-fast",
|
|
||||||
aliases=["gemini-2.5-flash"],
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self, api_key: str, **kwargs):
|
def __init__(self, api_key: str, **kwargs):
|
||||||
"""Initialize DIAL provider with API key and host.
|
"""Initialize DIAL provider with API key and host.
|
||||||
|
|
||||||
@@ -209,6 +39,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
|||||||
api_key: DIAL API key for authentication
|
api_key: DIAL API key for authentication
|
||||||
**kwargs: Additional configuration options
|
**kwargs: Additional configuration options
|
||||||
"""
|
"""
|
||||||
|
self._ensure_registry()
|
||||||
# Get DIAL API host from environment or kwargs
|
# Get DIAL API host from environment or kwargs
|
||||||
dial_host = kwargs.get("base_url") or get_env("DIAL_API_HOST") or "https://core.dialx.ai"
|
dial_host = kwargs.get("base_url") or get_env("DIAL_API_HOST") or "https://core.dialx.ai"
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, ClassVar, Optional
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tools.models import ToolModelCategory
|
from tools.models import ToolModelCategory
|
||||||
@@ -14,7 +14,7 @@ from utils.env import get_env
|
|||||||
from utils.image_utils import validate_image
|
from utils.image_utils import validate_image
|
||||||
|
|
||||||
from .base import ModelProvider
|
from .base import ModelProvider
|
||||||
from .gemini_registry import GeminiModelRegistry
|
from .registries.gemini import GeminiModelRegistry
|
||||||
from .registry_provider_mixin import RegistryBackedProviderMixin
|
from .registry_provider_mixin import RegistryBackedProviderMixin
|
||||||
from .shared import ModelCapabilities, ModelResponse, ProviderType
|
from .shared import ModelCapabilities, ModelResponse, ProviderType
|
||||||
|
|
||||||
@@ -30,7 +30,7 @@ class GeminiModelProvider(RegistryBackedProviderMixin, ModelProvider):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
REGISTRY_CLASS = GeminiModelRegistry
|
REGISTRY_CLASS = GeminiModelRegistry
|
||||||
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}
|
MODEL_CAPABILITIES: ClassVar[dict[str, ModelCapabilities]] = {}
|
||||||
|
|
||||||
# Thinking mode configurations - percentages of model's max_thinking_tokens
|
# Thinking mode configurations - percentages of model's max_thinking_tokens
|
||||||
# These percentages work across all models that support thinking
|
# These percentages work across all models that support thinking
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
"""OpenAI model provider implementation."""
|
"""OpenAI model provider implementation."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, ClassVar, Optional
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tools.models import ToolModelCategory
|
from tools.models import ToolModelCategory
|
||||||
|
|
||||||
from .openai_compatible import OpenAICompatibleProvider
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
from .openai_registry import OpenAIModelRegistry
|
from .registries.openai import OpenAIModelRegistry
|
||||||
from .registry_provider_mixin import RegistryBackedProviderMixin
|
from .registry_provider_mixin import RegistryBackedProviderMixin
|
||||||
from .shared import ModelCapabilities, ProviderType
|
from .shared import ModelCapabilities, ProviderType
|
||||||
|
|
||||||
@@ -23,7 +23,7 @@ class OpenAIModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider)
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
REGISTRY_CLASS = OpenAIModelRegistry
|
REGISTRY_CLASS = OpenAIModelRegistry
|
||||||
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}
|
MODEL_CAPABILITIES: ClassVar[dict[str, ModelCapabilities]] = {}
|
||||||
|
|
||||||
def __init__(self, api_key: str, **kwargs):
|
def __init__(self, api_key: str, **kwargs):
|
||||||
"""Initialize OpenAI provider with API key."""
|
"""Initialize OpenAI provider with API key."""
|
||||||
@@ -50,7 +50,7 @@ class OpenAIModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider)
|
|||||||
return builtin
|
return builtin
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .openrouter_registry import OpenRouterModelRegistry
|
from .registries.openrouter import OpenRouterModelRegistry
|
||||||
|
|
||||||
registry = OpenRouterModelRegistry()
|
registry = OpenRouterModelRegistry()
|
||||||
config = registry.get_model_config(canonical_name)
|
config = registry.get_model_config(canonical_name)
|
||||||
@@ -5,7 +5,7 @@ import logging
|
|||||||
from utils.env import get_env
|
from utils.env import get_env
|
||||||
|
|
||||||
from .openai_compatible import OpenAICompatibleProvider
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
from .openrouter_registry import OpenRouterModelRegistry
|
from .registries.openrouter import OpenRouterModelRegistry
|
||||||
from .shared import (
|
from .shared import (
|
||||||
ModelCapabilities,
|
ModelCapabilities,
|
||||||
ProviderType,
|
ProviderType,
|
||||||
|
|||||||
19
providers/registries/__init__.py
Normal file
19
providers/registries/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
"""Registry implementations for provider capability manifests."""
|
||||||
|
|
||||||
|
from .azure import AzureModelRegistry
|
||||||
|
from .custom import CustomEndpointModelRegistry
|
||||||
|
from .dial import DialModelRegistry
|
||||||
|
from .gemini import GeminiModelRegistry
|
||||||
|
from .openai import OpenAIModelRegistry
|
||||||
|
from .openrouter import OpenRouterModelRegistry
|
||||||
|
from .xai import XAIModelRegistry
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AzureModelRegistry",
|
||||||
|
"CustomEndpointModelRegistry",
|
||||||
|
"DialModelRegistry",
|
||||||
|
"GeminiModelRegistry",
|
||||||
|
"OpenAIModelRegistry",
|
||||||
|
"OpenRouterModelRegistry",
|
||||||
|
"XAIModelRegistry",
|
||||||
|
]
|
||||||
@@ -4,8 +4,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from .model_registry_base import CAPABILITY_FIELD_NAMES, CustomModelRegistryBase
|
from ..shared import ModelCapabilities, ProviderType, TemperatureConstraint
|
||||||
from .shared import ModelCapabilities, ProviderType, TemperatureConstraint
|
from .base import CAPABILITY_FIELD_NAMES, CustomModelRegistryBase
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -12,7 +12,7 @@ from pathlib import Path
|
|||||||
from utils.env import get_env
|
from utils.env import get_env
|
||||||
from utils.file_utils import read_json_file
|
from utils.file_utils import read_json_file
|
||||||
|
|
||||||
from .shared import ModelCapabilities, ProviderType, TemperatureConstraint
|
from ..shared import ModelCapabilities, ProviderType, TemperatureConstraint
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -34,7 +34,7 @@ class CustomModelRegistryBase:
|
|||||||
self._default_filename = default_filename
|
self._default_filename = default_filename
|
||||||
self._use_resources = False
|
self._use_resources = False
|
||||||
self._resource_package = "conf"
|
self._resource_package = "conf"
|
||||||
self._default_path = Path(__file__).parent.parent / "conf" / default_filename
|
self._default_path = Path(__file__).resolve().parents[3] / "conf" / default_filename
|
||||||
|
|
||||||
if config_path:
|
if config_path:
|
||||||
self.config_path = Path(config_path)
|
self.config_path = Path(config_path)
|
||||||
@@ -51,7 +51,7 @@ class CustomModelRegistryBase:
|
|||||||
else:
|
else:
|
||||||
raise AttributeError("resource accessor not available")
|
raise AttributeError("resource accessor not available")
|
||||||
except Exception:
|
except Exception:
|
||||||
self.config_path = Path(__file__).parent.parent / "conf" / default_filename
|
self.config_path = Path(__file__).resolve().parents[3] / "conf" / default_filename
|
||||||
|
|
||||||
self.alias_map: dict[str, str] = {}
|
self.alias_map: dict[str, str] = {}
|
||||||
self.model_map: dict[str, ModelCapabilities] = {}
|
self.model_map: dict[str, ModelCapabilities] = {}
|
||||||
@@ -213,7 +213,7 @@ class CustomModelRegistryBase:
|
|||||||
|
|
||||||
|
|
||||||
class CapabilityModelRegistry(CustomModelRegistryBase):
|
class CapabilityModelRegistry(CustomModelRegistryBase):
|
||||||
"""Registry that returns `ModelCapabilities` objects with alias support."""
|
"""Registry that returns :class:`ModelCapabilities` objects with alias support."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -1,12 +1,14 @@
|
|||||||
"""Registry for models exposed via custom (local) OpenAI-compatible endpoints."""
|
"""Registry loader for custom OpenAI-compatible endpoints."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .model_registry_base import CAPABILITY_FIELD_NAMES, CapabilityModelRegistry
|
from ..shared import ModelCapabilities, ProviderType
|
||||||
from .shared import ModelCapabilities, ProviderType
|
from .base import CAPABILITY_FIELD_NAMES, CapabilityModelRegistry
|
||||||
|
|
||||||
|
|
||||||
class CustomEndpointModelRegistry(CapabilityModelRegistry):
|
class CustomEndpointModelRegistry(CapabilityModelRegistry):
|
||||||
|
"""Capability registry backed by ``conf/custom_models.json``."""
|
||||||
|
|
||||||
def __init__(self, config_path: str | None = None) -> None:
|
def __init__(self, config_path: str | None = None) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
env_var_name="CUSTOM_MODELS_CONFIG_PATH",
|
env_var_name="CUSTOM_MODELS_CONFIG_PATH",
|
||||||
@@ -15,11 +17,8 @@ class CustomEndpointModelRegistry(CapabilityModelRegistry):
|
|||||||
friendly_prefix="Custom ({model})",
|
friendly_prefix="Custom ({model})",
|
||||||
config_path=config_path,
|
config_path=config_path,
|
||||||
)
|
)
|
||||||
self.reload()
|
|
||||||
|
|
||||||
def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
|
def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
|
||||||
entry["provider"] = ProviderType.CUSTOM
|
|
||||||
entry.setdefault("friendly_name", f"Custom ({entry['model_name']})")
|
|
||||||
filtered = {k: v for k, v in entry.items() if k in CAPABILITY_FIELD_NAMES}
|
filtered = {k: v for k, v in entry.items() if k in CAPABILITY_FIELD_NAMES}
|
||||||
filtered.setdefault("provider", ProviderType.CUSTOM)
|
filtered.setdefault("provider", ProviderType.CUSTOM)
|
||||||
capability = ModelCapabilities(**filtered)
|
capability = ModelCapabilities(**filtered)
|
||||||
19
providers/registries/dial.py
Normal file
19
providers/registries/dial.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
"""Registry loader for DIAL provider capabilities."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from ..shared import ProviderType
|
||||||
|
from .base import CapabilityModelRegistry
|
||||||
|
|
||||||
|
|
||||||
|
class DialModelRegistry(CapabilityModelRegistry):
|
||||||
|
"""Capability registry backed by ``conf/dial_models.json``."""
|
||||||
|
|
||||||
|
def __init__(self, config_path: str | None = None) -> None:
|
||||||
|
super().__init__(
|
||||||
|
env_var_name="DIAL_MODELS_CONFIG_PATH",
|
||||||
|
default_filename="dial_models.json",
|
||||||
|
provider=ProviderType.DIAL,
|
||||||
|
friendly_prefix="DIAL ({model})",
|
||||||
|
config_path=config_path,
|
||||||
|
)
|
||||||
@@ -2,12 +2,12 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .model_registry_base import CapabilityModelRegistry
|
from ..shared import ProviderType
|
||||||
from .shared import ProviderType
|
from .base import CapabilityModelRegistry
|
||||||
|
|
||||||
|
|
||||||
class GeminiModelRegistry(CapabilityModelRegistry):
|
class GeminiModelRegistry(CapabilityModelRegistry):
|
||||||
"""Capability registry backed by `conf/gemini_models.json`."""
|
"""Capability registry backed by ``conf/gemini_models.json``."""
|
||||||
|
|
||||||
def __init__(self, config_path: str | None = None) -> None:
|
def __init__(self, config_path: str | None = None) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -2,12 +2,12 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .model_registry_base import CapabilityModelRegistry
|
from ..shared import ProviderType
|
||||||
from .shared import ProviderType
|
from .base import CapabilityModelRegistry
|
||||||
|
|
||||||
|
|
||||||
class OpenAIModelRegistry(CapabilityModelRegistry):
|
class OpenAIModelRegistry(CapabilityModelRegistry):
|
||||||
"""Capability registry backed by `conf/openai_models.json`."""
|
"""Capability registry backed by ``conf/openai_models.json``."""
|
||||||
|
|
||||||
def __init__(self, config_path: str | None = None) -> None:
|
def __init__(self, config_path: str | None = None) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -2,12 +2,12 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .model_registry_base import CAPABILITY_FIELD_NAMES, CapabilityModelRegistry
|
from ..shared import ModelCapabilities, ProviderType
|
||||||
from .shared import ModelCapabilities, ProviderType
|
from .base import CAPABILITY_FIELD_NAMES, CapabilityModelRegistry
|
||||||
|
|
||||||
|
|
||||||
class OpenRouterModelRegistry(CapabilityModelRegistry):
|
class OpenRouterModelRegistry(CapabilityModelRegistry):
|
||||||
"""Capability registry backed by `conf/openrouter_models.json`."""
|
"""Capability registry backed by ``conf/openrouter_models.json``."""
|
||||||
|
|
||||||
def __init__(self, config_path: str | None = None) -> None:
|
def __init__(self, config_path: str | None = None) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -1,13 +1,13 @@
|
|||||||
"""Registry loader for X.AI (GROK) model capabilities."""
|
"""Registry loader for X.AI model capabilities."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .model_registry_base import CapabilityModelRegistry
|
from ..shared import ProviderType
|
||||||
from .shared import ProviderType
|
from .base import CapabilityModelRegistry
|
||||||
|
|
||||||
|
|
||||||
class XAIModelRegistry(CapabilityModelRegistry):
|
class XAIModelRegistry(CapabilityModelRegistry):
|
||||||
"""Capability registry backed by `conf/xai_models.json`."""
|
"""Capability registry backed by ``conf/xai_models.json``."""
|
||||||
|
|
||||||
def __init__(self, config_path: str | None = None) -> None:
|
def __init__(self, config_path: str | None = None) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -22,7 +22,7 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
from typing import ClassVar
|
from typing import ClassVar
|
||||||
|
|
||||||
from .model_registry_base import CapabilityModelRegistry
|
from .registries.base import CapabilityModelRegistry
|
||||||
from .shared import ModelCapabilities
|
from .shared import ModelCapabilities
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
"""X.AI (GROK) model provider implementation."""
|
"""X.AI (GROK) model provider implementation."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, ClassVar, Optional
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tools.models import ToolModelCategory
|
from tools.models import ToolModelCategory
|
||||||
|
|
||||||
from .openai_compatible import OpenAICompatibleProvider
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
|
from .registries.xai import XAIModelRegistry
|
||||||
from .registry_provider_mixin import RegistryBackedProviderMixin
|
from .registry_provider_mixin import RegistryBackedProviderMixin
|
||||||
from .shared import ModelCapabilities, ProviderType
|
from .shared import ModelCapabilities, ProviderType
|
||||||
from .xai_registry import XAIModelRegistry
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -24,7 +24,7 @@ class XAIModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider):
|
|||||||
FRIENDLY_NAME = "X.AI"
|
FRIENDLY_NAME = "X.AI"
|
||||||
|
|
||||||
REGISTRY_CLASS = XAIModelRegistry
|
REGISTRY_CLASS = XAIModelRegistry
|
||||||
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}
|
MODEL_CAPABILITIES: ClassVar[dict[str, ModelCapabilities]] = {}
|
||||||
|
|
||||||
def __init__(self, api_key: str, **kwargs):
|
def __init__(self, api_key: str, **kwargs):
|
||||||
"""Initialize X.AI provider with API key."""
|
"""Initialize X.AI provider with API key."""
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ py-modules = ["server", "config"]
|
|||||||
"conf/openai_models.json",
|
"conf/openai_models.json",
|
||||||
"conf/gemini_models.json",
|
"conf/gemini_models.json",
|
||||||
"conf/xai_models.json",
|
"conf/xai_models.json",
|
||||||
|
"conf/dial_models.json",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|||||||
@@ -395,7 +395,7 @@ def configure_providers():
|
|||||||
from providers.custom import CustomProvider
|
from providers.custom import CustomProvider
|
||||||
from providers.dial import DIALModelProvider
|
from providers.dial import DIALModelProvider
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
from providers.openrouter import OpenRouterProvider
|
from providers.openrouter import OpenRouterProvider
|
||||||
from providers.shared import ProviderType
|
from providers.shared import ProviderType
|
||||||
from providers.xai import XAIModelProvider
|
from providers.xai import XAIModelProvider
|
||||||
@@ -432,7 +432,7 @@ def configure_providers():
|
|||||||
azure_models_available = False
|
azure_models_available = False
|
||||||
if azure_key and azure_key != "your_azure_openai_key_here" and azure_endpoint:
|
if azure_key and azure_key != "your_azure_openai_key_here" and azure_endpoint:
|
||||||
try:
|
try:
|
||||||
from providers.azure_registry import AzureModelRegistry
|
from providers.registries.azure import AzureModelRegistry
|
||||||
|
|
||||||
azure_registry = AzureModelRegistry()
|
azure_registry = AzureModelRegistry()
|
||||||
if azure_registry.list_models():
|
if azure_registry.list_models():
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ if sys.platform == "win32":
|
|||||||
|
|
||||||
# Register providers for all tests
|
# Register providers for all tests
|
||||||
from providers.gemini import GeminiModelProvider # noqa: E402
|
from providers.gemini import GeminiModelProvider # noqa: E402
|
||||||
from providers.openai_provider import OpenAIModelProvider # noqa: E402
|
from providers.openai import OpenAIModelProvider # noqa: E402
|
||||||
from providers.registry import ModelProviderRegistry # noqa: E402
|
from providers.registry import ModelProviderRegistry # noqa: E402
|
||||||
from providers.shared import ProviderType # noqa: E402
|
from providers.shared import ProviderType # noqa: E402
|
||||||
from providers.xai import XAIModelProvider # noqa: E402
|
from providers.xai import XAIModelProvider # noqa: E402
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import os
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
from providers.shared import ProviderType
|
from providers.shared import ProviderType
|
||||||
from utils.model_restrictions import ModelRestrictionService
|
from utils.model_restrictions import ModelRestrictionService
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
from providers.registry import ModelProviderRegistry
|
from providers.registry import ModelProviderRegistry
|
||||||
from providers.shared import ProviderType
|
from providers.shared import ProviderType
|
||||||
from providers.xai import XAIModelProvider
|
from providers.xai import XAIModelProvider
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import pytest
|
|||||||
import utils.env as env_config
|
import utils.env as env_config
|
||||||
import utils.model_restrictions as model_restrictions
|
import utils.model_restrictions as model_restrictions
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
from providers.openrouter import OpenRouterProvider
|
from providers.openrouter import OpenRouterProvider
|
||||||
from providers.registry import ModelProviderRegistry
|
from providers.registry import ModelProviderRegistry
|
||||||
from providers.shared import ProviderType
|
from providers.shared import ProviderType
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ class TestAutoModeProviderSelection:
|
|||||||
os.environ.pop(key, None)
|
os.environ.pop(key, None)
|
||||||
|
|
||||||
# Register only OpenAI provider
|
# Register only OpenAI provider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
|
|
||||||
@@ -127,7 +127,7 @@ class TestAutoModeProviderSelection:
|
|||||||
|
|
||||||
# Register both providers
|
# Register both providers
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
@@ -212,7 +212,7 @@ class TestAutoModeProviderSelection:
|
|||||||
|
|
||||||
# Register both providers
|
# Register both providers
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
@@ -256,7 +256,7 @@ class TestAutoModeProviderSelection:
|
|||||||
|
|
||||||
# Register all providers
|
# Register all providers
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
from providers.xai import XAIModelProvider
|
from providers.xai import XAIModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
@@ -307,7 +307,7 @@ class TestAutoModeProviderSelection:
|
|||||||
|
|
||||||
# Register all providers
|
# Register all providers
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
from providers.xai import XAIModelProvider
|
from providers.xai import XAIModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
from providers.shared import ProviderType
|
from providers.shared import ProviderType
|
||||||
from utils.model_restrictions import ModelRestrictionService
|
from utils.model_restrictions import ModelRestrictionService
|
||||||
|
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ async def test_chat_cross_model_continuation(monkeypatch):
|
|||||||
|
|
||||||
ModelProviderRegistry.reset_for_testing()
|
ModelProviderRegistry.reset_for_testing()
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
@@ -170,7 +170,7 @@ async def test_chat_cross_model_continuation(monkeypatch):
|
|||||||
|
|
||||||
ModelProviderRegistry.reset_for_testing()
|
ModelProviderRegistry.reset_for_testing()
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ async def test_chat_auto_mode_with_openai(monkeypatch):
|
|||||||
|
|
||||||
# Reset registry and register only OpenAI provider
|
# Reset registry and register only OpenAI provider
|
||||||
ModelProviderRegistry.reset_for_testing()
|
ModelProviderRegistry.reset_for_testing()
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
|
|
||||||
@@ -115,7 +115,7 @@ async def test_chat_openai_continuation(monkeypatch):
|
|||||||
m.delenv(key, raising=False)
|
m.delenv(key, raising=False)
|
||||||
|
|
||||||
ModelProviderRegistry.reset_for_testing()
|
ModelProviderRegistry.reset_for_testing()
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
|
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ async def test_consensus_multi_model_consultations(monkeypatch):
|
|||||||
# Reset providers and register only OpenAI & Gemini for deterministic behavior
|
# Reset providers and register only OpenAI & Gemini for deterministic behavior
|
||||||
ModelProviderRegistry.reset_for_testing()
|
ModelProviderRegistry.reset_for_testing()
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import tempfile
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
|
|
||||||
class TestCustomOpenAITemperatureParameterFix:
|
class TestCustomOpenAITemperatureParameterFix:
|
||||||
@@ -79,7 +79,7 @@ class TestCustomOpenAITemperatureParameterFix:
|
|||||||
mock_client.chat.completions.create.return_value = mock_response
|
mock_client.chat.completions.create.return_value = mock_response
|
||||||
|
|
||||||
# Create provider with custom config
|
# Create provider with custom config
|
||||||
with patch("providers.openrouter_registry.OpenRouterModelRegistry") as mock_registry_class:
|
with patch("providers.registries.openrouter.OpenRouterModelRegistry") as mock_registry_class:
|
||||||
# Mock registry to load our test config
|
# Mock registry to load our test config
|
||||||
mock_registry = Mock()
|
mock_registry = Mock()
|
||||||
mock_registry_class.return_value = mock_registry
|
mock_registry_class.return_value = mock_registry
|
||||||
@@ -163,7 +163,7 @@ class TestCustomOpenAITemperatureParameterFix:
|
|||||||
mock_client.chat.completions.create.return_value = mock_response
|
mock_client.chat.completions.create.return_value = mock_response
|
||||||
|
|
||||||
# Create provider with custom config
|
# Create provider with custom config
|
||||||
with patch("providers.openrouter_registry.OpenRouterModelRegistry") as mock_registry_class:
|
with patch("providers.registries.openrouter.OpenRouterModelRegistry") as mock_registry_class:
|
||||||
# Mock registry to load our test config
|
# Mock registry to load our test config
|
||||||
mock_registry = Mock()
|
mock_registry = Mock()
|
||||||
mock_registry_class.return_value = mock_registry
|
mock_registry_class.return_value = mock_registry
|
||||||
@@ -221,7 +221,7 @@ class TestCustomOpenAITemperatureParameterFix:
|
|||||||
mock_service.is_allowed.return_value = True
|
mock_service.is_allowed.return_value = True
|
||||||
mock_restriction_service.return_value = mock_service
|
mock_restriction_service.return_value = mock_service
|
||||||
|
|
||||||
with patch("providers.openrouter_registry.OpenRouterModelRegistry") as mock_registry_class:
|
with patch("providers.registries.openrouter.OpenRouterModelRegistry") as mock_registry_class:
|
||||||
# Mock registry to return a custom OpenAI model
|
# Mock registry to return a custom OpenAI model
|
||||||
mock_registry = Mock()
|
mock_registry = Mock()
|
||||||
mock_registry_class.return_value = mock_registry
|
mock_registry_class.return_value = mock_registry
|
||||||
@@ -267,7 +267,7 @@ class TestCustomOpenAITemperatureParameterFix:
|
|||||||
mock_service.is_allowed.return_value = True
|
mock_service.is_allowed.return_value = True
|
||||||
mock_restriction_service.return_value = mock_service
|
mock_restriction_service.return_value = mock_service
|
||||||
|
|
||||||
with patch("providers.openrouter_registry.OpenRouterModelRegistry") as mock_registry_class:
|
with patch("providers.registries.openrouter.OpenRouterModelRegistry") as mock_registry_class:
|
||||||
# Mock registry to raise an exception
|
# Mock registry to raise an exception
|
||||||
mock_registry_class.side_effect = Exception("Registry not available")
|
mock_registry_class.side_effect = Exception("Registry not available")
|
||||||
|
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ class TestIntelligentFallback:
|
|||||||
def test_prefers_openai_o3_mini_when_available(self):
|
def test_prefers_openai_o3_mini_when_available(self):
|
||||||
"""Test that gpt-5 is preferred when OpenAI API key is available (based on new preference order)"""
|
"""Test that gpt-5 is preferred when OpenAI API key is available (based on new preference order)"""
|
||||||
# Register only OpenAI provider for this test
|
# Register only OpenAI provider for this test
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
|
|
||||||
@@ -62,7 +62,7 @@ class TestIntelligentFallback:
|
|||||||
"""Test that OpenAI is preferred when both API keys are available"""
|
"""Test that OpenAI is preferred when both API keys are available"""
|
||||||
# Register both OpenAI and Gemini providers
|
# Register both OpenAI and Gemini providers
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
@@ -75,7 +75,7 @@ class TestIntelligentFallback:
|
|||||||
"""Test fallback behavior when no API keys are available"""
|
"""Test fallback behavior when no API keys are available"""
|
||||||
# Register providers but with no API keys available
|
# Register providers but with no API keys available
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
@@ -86,7 +86,7 @@ class TestIntelligentFallback:
|
|||||||
def test_available_providers_with_keys(self):
|
def test_available_providers_with_keys(self):
|
||||||
"""Test the get_available_providers_with_keys method"""
|
"""Test the get_available_providers_with_keys method"""
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False):
|
with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False):
|
||||||
# Clear and register providers
|
# Clear and register providers
|
||||||
@@ -119,7 +119,7 @@ class TestIntelligentFallback:
|
|||||||
patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False),
|
patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False),
|
||||||
):
|
):
|
||||||
# Register only OpenAI provider for this test
|
# Register only OpenAI provider for this test
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ Issue: Custom OpenAI models (gpt-5, o3) use temperature despite the config havin
|
|||||||
|
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
|
|
||||||
def test_issue_245_custom_openai_temperature_ignored():
|
def test_issue_245_custom_openai_temperature_ignored():
|
||||||
@@ -14,7 +14,7 @@ def test_issue_245_custom_openai_temperature_ignored():
|
|||||||
|
|
||||||
with patch("utils.model_restrictions.get_restriction_service") as mock_restriction:
|
with patch("utils.model_restrictions.get_restriction_service") as mock_restriction:
|
||||||
with patch("providers.openai_compatible.OpenAI") as mock_openai:
|
with patch("providers.openai_compatible.OpenAI") as mock_openai:
|
||||||
with patch("providers.openrouter_registry.OpenRouterModelRegistry") as mock_registry_class:
|
with patch("providers.registries.openrouter.OpenRouterModelRegistry") as mock_registry_class:
|
||||||
|
|
||||||
# Mock restriction service
|
# Mock restriction service
|
||||||
mock_service = Mock()
|
mock_service = Mock()
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ class TestListModelsRestrictions(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
@patch("utils.model_restrictions.get_restriction_service")
|
@patch("utils.model_restrictions.get_restriction_service")
|
||||||
@patch("providers.openrouter_registry.OpenRouterModelRegistry")
|
@patch("providers.registries.openrouter.OpenRouterModelRegistry")
|
||||||
@patch.object(ModelProviderRegistry, "get_available_models")
|
@patch.object(ModelProviderRegistry, "get_available_models")
|
||||||
@patch.object(ModelProviderRegistry, "get_provider")
|
@patch.object(ModelProviderRegistry, "get_provider")
|
||||||
def test_listmodels_respects_openrouter_restrictions(
|
def test_listmodels_respects_openrouter_restrictions(
|
||||||
@@ -239,7 +239,7 @@ class TestListModelsRestrictions(unittest.TestCase):
|
|||||||
self.assertIn("OpenRouter models restricted by", result)
|
self.assertIn("OpenRouter models restricted by", result)
|
||||||
|
|
||||||
@patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key", "GEMINI_API_KEY": "gemini-test-key"}, clear=True)
|
@patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key", "GEMINI_API_KEY": "gemini-test-key"}, clear=True)
|
||||||
@patch("providers.openrouter_registry.OpenRouterModelRegistry")
|
@patch("providers.registries.openrouter.OpenRouterModelRegistry")
|
||||||
@patch.object(ModelProviderRegistry, "get_provider")
|
@patch.object(ModelProviderRegistry, "get_provider")
|
||||||
def test_listmodels_shows_all_models_without_restrictions(self, mock_get_provider, mock_registry_class):
|
def test_listmodels_shows_all_models_without_restrictions(self, mock_get_provider, mock_registry_class):
|
||||||
"""Test that listmodels shows all models when no restrictions are set."""
|
"""Test that listmodels shows all models when no restrictions are set."""
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
from providers.shared import ProviderType
|
from providers.shared import ProviderType
|
||||||
from utils.model_restrictions import ModelRestrictionService
|
from utils.model_restrictions import ModelRestrictionService
|
||||||
|
|
||||||
@@ -767,7 +767,7 @@ class TestAutoModeWithRestrictions:
|
|||||||
# Clear registry and register only OpenAI and Gemini providers
|
# Clear registry and register only OpenAI and Gemini providers
|
||||||
ModelProviderRegistry._instance = None
|
ModelProviderRegistry._instance = None
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ for O3 models while maintaining them for regular models.
|
|||||||
|
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
|
|
||||||
class TestO3TemperatureParameterFixSimple:
|
class TestO3TemperatureParameterFixSimple:
|
||||||
@@ -175,7 +175,7 @@ class TestO3TemperatureParameterFixSimple:
|
|||||||
@patch("utils.model_restrictions.get_restriction_service")
|
@patch("utils.model_restrictions.get_restriction_service")
|
||||||
def test_all_o3_models_have_correct_temperature_capability(self, mock_restriction_service):
|
def test_all_o3_models_have_correct_temperature_capability(self, mock_restriction_service):
|
||||||
"""Test that all O3/O4 models have supports_temperature=False in their capabilities."""
|
"""Test that all O3/O4 models have supports_temperature=False in their capabilities."""
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
# Mock restriction service to allow all models
|
# Mock restriction service to allow all models
|
||||||
mock_service = Mock()
|
mock_service = Mock()
|
||||||
@@ -211,7 +211,7 @@ class TestO3TemperatureParameterFixSimple:
|
|||||||
@patch("utils.model_restrictions.get_restriction_service")
|
@patch("utils.model_restrictions.get_restriction_service")
|
||||||
def test_openai_provider_temperature_constraints(self, mock_restriction_service):
|
def test_openai_provider_temperature_constraints(self, mock_restriction_service):
|
||||||
"""Test that OpenAI provider has correct temperature constraints for O3 models."""
|
"""Test that OpenAI provider has correct temperature constraints for O3 models."""
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
# Mock restriction service to allow all models
|
# Mock restriction service to allow all models
|
||||||
mock_service = Mock()
|
mock_service = Mock()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
import os
|
import os
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
from providers.shared import ProviderType
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -282,7 +282,7 @@ class TestOpenRouterRegistry:
|
|||||||
|
|
||||||
def test_registry_loading(self):
|
def test_registry_loading(self):
|
||||||
"""Test registry loads models from config."""
|
"""Test registry loads models from config."""
|
||||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
from providers.registries.openrouter import OpenRouterModelRegistry
|
||||||
|
|
||||||
registry = OpenRouterModelRegistry()
|
registry = OpenRouterModelRegistry()
|
||||||
|
|
||||||
@@ -301,7 +301,7 @@ class TestOpenRouterRegistry:
|
|||||||
|
|
||||||
def test_registry_capabilities(self):
|
def test_registry_capabilities(self):
|
||||||
"""Test registry provides correct capabilities."""
|
"""Test registry provides correct capabilities."""
|
||||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
from providers.registries.openrouter import OpenRouterModelRegistry
|
||||||
|
|
||||||
registry = OpenRouterModelRegistry()
|
registry = OpenRouterModelRegistry()
|
||||||
|
|
||||||
@@ -322,7 +322,7 @@ class TestOpenRouterRegistry:
|
|||||||
|
|
||||||
def test_multiple_aliases_same_model(self):
|
def test_multiple_aliases_same_model(self):
|
||||||
"""Test multiple aliases pointing to same model."""
|
"""Test multiple aliases pointing to same model."""
|
||||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
from providers.registries.openrouter import OpenRouterModelRegistry
|
||||||
|
|
||||||
registry = OpenRouterModelRegistry()
|
registry = OpenRouterModelRegistry()
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
from providers.registries.openrouter import OpenRouterModelRegistry
|
||||||
from providers.shared import ModelCapabilities, ProviderType
|
from providers.shared import ModelCapabilities, ProviderType
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ class TestModelSelection:
|
|||||||
ModelProviderRegistry.unregister_provider(provider_type)
|
ModelProviderRegistry.unregister_provider(provider_type)
|
||||||
|
|
||||||
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}, clear=False):
|
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}, clear=False):
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
|
|
||||||
@@ -125,7 +125,7 @@ class TestModelSelection:
|
|||||||
ModelProviderRegistry.unregister_provider(provider_type)
|
ModelProviderRegistry.unregister_provider(provider_type)
|
||||||
|
|
||||||
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}, clear=False):
|
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}, clear=False):
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
|
|
||||||
@@ -159,7 +159,7 @@ class TestModelSelection:
|
|||||||
ModelProviderRegistry.unregister_provider(provider_type)
|
ModelProviderRegistry.unregister_provider(provider_type)
|
||||||
|
|
||||||
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}, clear=False):
|
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}, clear=False):
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
|
|
||||||
@@ -220,7 +220,7 @@ class TestFlexibleModelSelection:
|
|||||||
with patch.dict(os.environ, case["env"], clear=False):
|
with patch.dict(os.environ, case["env"], clear=False):
|
||||||
# Register the appropriate provider
|
# Register the appropriate provider
|
||||||
if case["provider_type"] == ProviderType.OPENAI:
|
if case["provider_type"] == ProviderType.OPENAI:
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
elif case["provider_type"] == ProviderType.GOOGLE:
|
elif case["provider_type"] == ProviderType.GOOGLE:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from types import SimpleNamespace
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
|
|
||||||
def _mock_chat_response(content: str = "retry success") -> SimpleNamespace:
|
def _mock_chat_response(content: str = "retry success") -> SimpleNamespace:
|
||||||
|
|||||||
@@ -189,7 +189,7 @@ class TestProviderRoutingBugs:
|
|||||||
|
|
||||||
# Register providers in priority order (like server.py)
|
# Register providers in priority order (like server.py)
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
from providers.openrouter import OpenRouterProvider
|
from providers.openrouter import OpenRouterProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from unittest.mock import Mock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
from providers.shared import ProviderType
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import pytest
|
|||||||
|
|
||||||
from providers import ModelProviderRegistry, ModelResponse
|
from providers import ModelProviderRegistry, ModelResponse
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
from providers.shared import ProviderType
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ Test to verify structured error code-based retry logic.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
|
|
||||||
def test_openai_structured_error_retry_logic():
|
def test_openai_structured_error_retry_logic():
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from providers.dial import DIALModelProvider
|
from providers.dial import DIALModelProvider
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
from providers.xai import XAIModelProvider
|
from providers.xai import XAIModelProvider
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import tempfile
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
from providers.registries.openrouter import OpenRouterModelRegistry
|
||||||
|
|
||||||
|
|
||||||
class TestUvxPathResolution:
|
class TestUvxPathResolution:
|
||||||
@@ -55,7 +55,7 @@ class TestUvxPathResolution:
|
|||||||
assert registry.config_path == config_path
|
assert registry.config_path == config_path
|
||||||
assert len(registry.list_models()) > 0
|
assert len(registry.list_models()) > 0
|
||||||
|
|
||||||
@patch("providers.model_registry_base.importlib.resources.files")
|
@patch("providers.registries.base.importlib.resources.files")
|
||||||
def test_multiple_path_fallback(self, mock_files):
|
def test_multiple_path_fallback(self, mock_files):
|
||||||
"""Test that file-system fallback works when resource loading fails."""
|
"""Test that file-system fallback works when resource loading fails."""
|
||||||
mock_files.side_effect = Exception("Resource loading failed")
|
mock_files.side_effect = Exception("Resource loading failed")
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ def inject_transport(monkeypatch, cassette_path: str):
|
|||||||
transport = inject_transport(monkeypatch, "path/to/cassette.json")
|
transport = inject_transport(monkeypatch, "path/to/cassette.json")
|
||||||
"""
|
"""
|
||||||
# Ensure OpenAI provider is registered - always needed for transport injection
|
# Ensure OpenAI provider is registered - always needed for transport injection
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
from providers.registry import ModelProviderRegistry
|
from providers.registry import ModelProviderRegistry
|
||||||
from providers.shared import ProviderType
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
|
|||||||
@@ -11,8 +11,8 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
from mcp.types import TextContent
|
from mcp.types import TextContent
|
||||||
|
|
||||||
from providers.custom_registry import CustomEndpointModelRegistry
|
from providers.registries.custom import CustomEndpointModelRegistry
|
||||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
from providers.registries.openrouter import OpenRouterModelRegistry
|
||||||
from tools.models import ToolModelCategory, ToolOutput
|
from tools.models import ToolModelCategory, ToolOutput
|
||||||
from tools.shared.base_models import ToolRequest
|
from tools.shared.base_models import ToolRequest
|
||||||
from tools.shared.base_tool import BaseTool
|
from tools.shared.base_tool import BaseTool
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ class BaseTool(ABC):
|
|||||||
"""Get cached OpenRouter registry instance, creating if needed."""
|
"""Get cached OpenRouter registry instance, creating if needed."""
|
||||||
# Use BaseTool class directly to ensure cache is shared across all subclasses
|
# Use BaseTool class directly to ensure cache is shared across all subclasses
|
||||||
if BaseTool._openrouter_registry_cache is None:
|
if BaseTool._openrouter_registry_cache is None:
|
||||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
from providers.registries.openrouter import OpenRouterModelRegistry
|
||||||
|
|
||||||
BaseTool._openrouter_registry_cache = OpenRouterModelRegistry()
|
BaseTool._openrouter_registry_cache = OpenRouterModelRegistry()
|
||||||
logger.debug("Created cached OpenRouter registry instance")
|
logger.debug("Created cached OpenRouter registry instance")
|
||||||
@@ -99,7 +99,7 @@ class BaseTool(ABC):
|
|||||||
def _get_custom_registry(cls):
|
def _get_custom_registry(cls):
|
||||||
"""Get cached custom-endpoint registry instance."""
|
"""Get cached custom-endpoint registry instance."""
|
||||||
if BaseTool._custom_registry_cache is None:
|
if BaseTool._custom_registry_cache is None:
|
||||||
from providers.custom_registry import CustomEndpointModelRegistry
|
from providers.registries.custom import CustomEndpointModelRegistry
|
||||||
|
|
||||||
BaseTool._custom_registry_cache = CustomEndpointModelRegistry()
|
BaseTool._custom_registry_cache = CustomEndpointModelRegistry()
|
||||||
logger.debug("Created cached Custom registry instance")
|
logger.debug("Created cached Custom registry instance")
|
||||||
|
|||||||
Reference in New Issue
Block a user