diff --git a/providers/custom.py b/providers/custom.py index 52d9b94..021bba5 100644 --- a/providers/custom.py +++ b/providers/custom.py @@ -291,7 +291,6 @@ class CustomProvider(OpenAICompatibleProvider): Returns: Dictionary mapping model names to their ModelCapabilities objects """ - from .base import ProviderType configs = {} @@ -302,12 +301,8 @@ class CustomProvider(OpenAICompatibleProvider): if self.validate_model_name(model_name): config = self._registry.resolve(model_name) if config and config.is_custom: - # Convert OpenRouterModelConfig to ModelCapabilities - capabilities = config.to_capabilities() - # Override provider type to CUSTOM for local models - capabilities.provider = ProviderType.CUSTOM - capabilities.friendly_name = f"{self.FRIENDLY_NAME} ({config.model_name})" - configs[model_name] = capabilities + # Use ModelCapabilities directly from registry + configs[model_name] = config return configs diff --git a/providers/openrouter.py b/providers/openrouter.py index 5d29514..3d90238 100644 --- a/providers/openrouter.py +++ b/providers/openrouter.py @@ -288,12 +288,8 @@ class OpenRouterProvider(OpenAICompatibleProvider): if self.validate_model_name(model_name): config = self._registry.resolve(model_name) if config and not config.is_custom: # Only OpenRouter models, not custom ones - # Convert OpenRouterModelConfig to ModelCapabilities - capabilities = config.to_capabilities() - # Override provider type to OPENROUTER - capabilities.provider = ProviderType.OPENROUTER - capabilities.friendly_name = f"{self.FRIENDLY_NAME} ({config.model_name})" - configs[model_name] = capabilities + # Use ModelCapabilities directly from registry + configs[model_name] = config return configs diff --git a/providers/openrouter_registry.py b/providers/openrouter_registry.py index 47258c8..97b8f60 100644 --- a/providers/openrouter_registry.py +++ b/providers/openrouter_registry.py @@ -2,7 +2,6 @@ import logging import os -from dataclasses import dataclass, field from pathlib import Path from typing import Optional @@ -11,58 +10,10 @@ from utils.file_utils import read_json_file from .base import ( ModelCapabilities, ProviderType, - TemperatureConstraint, create_temperature_constraint, ) -@dataclass -class OpenRouterModelConfig: - """Configuration for an OpenRouter model.""" - - model_name: str - aliases: list[str] = field(default_factory=list) - context_window: int = 32768 # Total context window size in tokens - supports_extended_thinking: bool = False - supports_system_prompts: bool = True - supports_streaming: bool = True - supports_function_calling: bool = False - supports_json_mode: bool = False - supports_images: bool = False # Whether model can process images - max_image_size_mb: float = 0.0 # Maximum total size for all images in MB - supports_temperature: bool = True # Whether model accepts temperature parameter in API calls - temperature_constraint: Optional[str] = ( - None # Type of temperature constraint: "fixed", "range", "discrete", or None for default range - ) - is_custom: bool = False # True for models that should only be used with custom endpoints - description: str = "" - - def _create_temperature_constraint(self) -> TemperatureConstraint: - """Create temperature constraint object from configuration. - - Returns: - TemperatureConstraint object based on configuration - """ - return create_temperature_constraint(self.temperature_constraint or "range") - - def to_capabilities(self) -> ModelCapabilities: - """Convert to ModelCapabilities object.""" - return ModelCapabilities( - provider=ProviderType.OPENROUTER, - model_name=self.model_name, - friendly_name="OpenRouter", - context_window=self.context_window, - supports_extended_thinking=self.supports_extended_thinking, - supports_system_prompts=self.supports_system_prompts, - supports_streaming=self.supports_streaming, - supports_function_calling=self.supports_function_calling, - supports_images=self.supports_images, - max_image_size_mb=self.max_image_size_mb, - supports_temperature=self.supports_temperature, - temperature_constraint=self._create_temperature_constraint(), - ) - - class OpenRouterModelRegistry: """Registry for managing OpenRouter model configurations and aliases.""" @@ -73,7 +24,7 @@ class OpenRouterModelRegistry: config_path: Path to config file. If None, uses default locations. """ self.alias_map: dict[str, str] = {} # alias -> model_name - self.model_map: dict[str, OpenRouterModelConfig] = {} # model_name -> config + self.model_map: dict[str, ModelCapabilities] = {} # model_name -> config # Determine config path if config_path: @@ -139,7 +90,7 @@ class OpenRouterModelRegistry: self.alias_map = {} self.model_map = {} - def _read_config(self) -> list[OpenRouterModelConfig]: + def _read_config(self) -> list[ModelCapabilities]: """Read configuration from file. Returns: @@ -158,7 +109,27 @@ class OpenRouterModelRegistry: # Parse models configs = [] for model_data in data.get("models", []): - config = OpenRouterModelConfig(**model_data) + # Create ModelCapabilities directly from JSON data + # Handle temperature_constraint conversion + temp_constraint_str = model_data.get("temperature_constraint") + temp_constraint = create_temperature_constraint(temp_constraint_str or "range") + + # Set provider-specific defaults based on is_custom flag + is_custom = model_data.get("is_custom", False) + if is_custom: + model_data.setdefault("provider", ProviderType.CUSTOM) + model_data.setdefault("friendly_name", f"Custom ({model_data.get('model_name', 'Unknown')})") + else: + model_data.setdefault("provider", ProviderType.OPENROUTER) + model_data.setdefault("friendly_name", f"OpenRouter ({model_data.get('model_name', 'Unknown')})") + model_data["temperature_constraint"] = temp_constraint + + # Remove the string version of temperature_constraint before creating ModelCapabilities + if "temperature_constraint" in model_data and isinstance(model_data["temperature_constraint"], str): + del model_data["temperature_constraint"] + model_data["temperature_constraint"] = temp_constraint + + config = ModelCapabilities(**model_data) configs.append(config) return configs @@ -168,7 +139,7 @@ class OpenRouterModelRegistry: except Exception as e: raise ValueError(f"Error reading config from {self.config_path}: {e}") - def _build_maps(self, configs: list[OpenRouterModelConfig]) -> None: + def _build_maps(self, configs: list[ModelCapabilities]) -> None: """Build alias and model maps from configurations. Args: @@ -211,7 +182,7 @@ class OpenRouterModelRegistry: self.alias_map = alias_map self.model_map = model_map - def resolve(self, name_or_alias: str) -> Optional[OpenRouterModelConfig]: + def resolve(self, name_or_alias: str) -> Optional[ModelCapabilities]: """Resolve a model name or alias to configuration. Args: @@ -237,10 +208,8 @@ class OpenRouterModelRegistry: Returns: ModelCapabilities if found, None otherwise """ - config = self.resolve(name_or_alias) - if config: - return config.to_capabilities() - return None + # Registry now returns ModelCapabilities directly + return self.resolve(name_or_alias) def list_models(self) -> list[str]: """List all available model names.""" diff --git a/tests/test_openrouter_provider.py b/tests/test_openrouter_provider.py index da10678..6d427ba 100644 --- a/tests/test_openrouter_provider.py +++ b/tests/test_openrouter_provider.py @@ -57,7 +57,7 @@ class TestOpenRouterProvider: caps = provider.get_capabilities("o3") assert caps.provider == ProviderType.OPENROUTER assert caps.model_name == "openai/o3" # Resolved name - assert caps.friendly_name == "OpenRouter" + assert caps.friendly_name == "OpenRouter (openai/o3)" # Test with a model not in registry - should get generic capabilities caps = provider.get_capabilities("unknown-model") diff --git a/tests/test_openrouter_registry.py b/tests/test_openrouter_registry.py index 4b8bbbf..f6ea000 100644 --- a/tests/test_openrouter_registry.py +++ b/tests/test_openrouter_registry.py @@ -6,8 +6,8 @@ import tempfile import pytest -from providers.base import ProviderType -from providers.openrouter_registry import OpenRouterModelConfig, OpenRouterModelRegistry +from providers.base import ModelCapabilities, ProviderType +from providers.openrouter_registry import OpenRouterModelRegistry class TestOpenRouterModelRegistry: @@ -110,18 +110,18 @@ class TestOpenRouterModelRegistry: assert registry.resolve("non-existent") is None def test_model_capabilities_conversion(self): - """Test conversion to ModelCapabilities.""" + """Test that registry returns ModelCapabilities directly.""" registry = OpenRouterModelRegistry() config = registry.resolve("opus") assert config is not None - caps = config.to_capabilities() - assert caps.provider == ProviderType.OPENROUTER - assert caps.model_name == "anthropic/claude-opus-4" - assert caps.friendly_name == "OpenRouter" - assert caps.context_window == 200000 - assert not caps.supports_extended_thinking + # Registry now returns ModelCapabilities objects directly + assert config.provider == ProviderType.OPENROUTER + assert config.model_name == "anthropic/claude-opus-4" + assert config.friendly_name == "OpenRouter (anthropic/claude-opus-4)" + assert config.context_window == 200000 + assert not config.supports_extended_thinking def test_duplicate_alias_detection(self): """Test that duplicate aliases are detected.""" @@ -199,8 +199,12 @@ class TestOpenRouterModelRegistry: def test_model_with_all_capabilities(self): """Test model with all capability flags.""" - config = OpenRouterModelConfig( + from providers.base import create_temperature_constraint + + caps = ModelCapabilities( + provider=ProviderType.OPENROUTER, model_name="test/full-featured", + friendly_name="OpenRouter (test/full-featured)", aliases=["full"], context_window=128000, supports_extended_thinking=True, @@ -209,9 +213,8 @@ class TestOpenRouterModelRegistry: supports_function_calling=True, supports_json_mode=True, description="Fully featured test model", + temperature_constraint=create_temperature_constraint("range"), ) - - caps = config.to_capabilities() assert caps.context_window == 128000 assert caps.supports_extended_thinking assert caps.supports_system_prompts