Cleanup, use ModelCapabilities only

Fahad
2025-06-23 17:39:47 +04:00
parent 498ea88293
commit 14eaf930ed
5 changed files with 47 additions and 84 deletions

View File

@@ -291,7 +291,6 @@ class CustomProvider(OpenAICompatibleProvider):
         Returns:
             Dictionary mapping model names to their ModelCapabilities objects
         """
-        from .base import ProviderType

         configs = {}

@@ -302,12 +301,8 @@ class CustomProvider(OpenAICompatibleProvider):
             if self.validate_model_name(model_name):
                 config = self._registry.resolve(model_name)
                 if config and config.is_custom:
-                    # Convert OpenRouterModelConfig to ModelCapabilities
-                    capabilities = config.to_capabilities()
-                    # Override provider type to CUSTOM for local models
-                    capabilities.provider = ProviderType.CUSTOM
-                    capabilities.friendly_name = f"{self.FRIENDLY_NAME} ({config.model_name})"
-                    configs[model_name] = capabilities
+                    # Use ModelCapabilities directly from registry
+                    configs[model_name] = config

        return configs

View File

@@ -288,12 +288,8 @@ class OpenRouterProvider(OpenAICompatibleProvider):
            if self.validate_model_name(model_name):
                config = self._registry.resolve(model_name)
                if config and not config.is_custom:  # Only OpenRouter models, not custom ones
-                   # Convert OpenRouterModelConfig to ModelCapabilities
-                   capabilities = config.to_capabilities()
-                   # Override provider type to OPENROUTER
-                   capabilities.provider = ProviderType.OPENROUTER
-                   capabilities.friendly_name = f"{self.FRIENDLY_NAME} ({config.model_name})"
-                   configs[model_name] = capabilities
+                   # Use ModelCapabilities directly from registry
+                   configs[model_name] = config

        return configs
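
Note: both provider hunks above now follow the same shape. A minimal sketch of that shared pattern is below; the surrounding loop and the use of list_models() as the source of names are assumptions (the hunks only show the body of the loop), and build_model_configs / is_custom_provider are hypothetical names for illustration.

    # Sketch of the pattern CustomProvider and OpenRouterProvider now share:
    # the registry hands back ModelCapabilities directly, so providers only
    # filter by the is_custom flag instead of converting per model.
    def build_model_configs(provider, is_custom_provider: bool) -> dict:
        configs = {}
        for model_name in provider._registry.list_models():  # assumed source of names
            if provider.validate_model_name(model_name):
                config = provider._registry.resolve(model_name)  # ModelCapabilities or None
                if config and config.is_custom == is_custom_provider:
                    configs[model_name] = config  # no to_capabilities() step anymore
        return configs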

View File

@@ -2,7 +2,6 @@
 import logging
 import os
-from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Optional

@@ -11,58 +10,10 @@ from utils.file_utils import read_json_file
 from .base import (
     ModelCapabilities,
     ProviderType,
-    TemperatureConstraint,
     create_temperature_constraint,
 )
-@dataclass
-class OpenRouterModelConfig:
-    """Configuration for an OpenRouter model."""
-
-    model_name: str
-    aliases: list[str] = field(default_factory=list)
-    context_window: int = 32768  # Total context window size in tokens
-    supports_extended_thinking: bool = False
-    supports_system_prompts: bool = True
-    supports_streaming: bool = True
-    supports_function_calling: bool = False
-    supports_json_mode: bool = False
-    supports_images: bool = False  # Whether model can process images
-    max_image_size_mb: float = 0.0  # Maximum total size for all images in MB
-    supports_temperature: bool = True  # Whether model accepts temperature parameter in API calls
-    temperature_constraint: Optional[str] = (
-        None  # Type of temperature constraint: "fixed", "range", "discrete", or None for default range
-    )
-    is_custom: bool = False  # True for models that should only be used with custom endpoints
-    description: str = ""
-
-    def _create_temperature_constraint(self) -> TemperatureConstraint:
-        """Create temperature constraint object from configuration.
-
-        Returns:
-            TemperatureConstraint object based on configuration
-        """
-        return create_temperature_constraint(self.temperature_constraint or "range")
-
-    def to_capabilities(self) -> ModelCapabilities:
-        """Convert to ModelCapabilities object."""
-        return ModelCapabilities(
-            provider=ProviderType.OPENROUTER,
-            model_name=self.model_name,
-            friendly_name="OpenRouter",
-            context_window=self.context_window,
-            supports_extended_thinking=self.supports_extended_thinking,
-            supports_system_prompts=self.supports_system_prompts,
-            supports_streaming=self.supports_streaming,
-            supports_function_calling=self.supports_function_calling,
-            supports_images=self.supports_images,
-            max_image_size_mb=self.max_image_size_mb,
-            supports_temperature=self.supports_temperature,
-            temperature_constraint=self._create_temperature_constraint(),
-        )
-
-
 class OpenRouterModelRegistry:
     """Registry for managing OpenRouter model configurations and aliases."""
@@ -73,7 +24,7 @@ class OpenRouterModelRegistry:
            config_path: Path to config file. If None, uses default locations.
        """
        self.alias_map: dict[str, str] = {}  # alias -> model_name
-       self.model_map: dict[str, OpenRouterModelConfig] = {}  # model_name -> config
+       self.model_map: dict[str, ModelCapabilities] = {}  # model_name -> config

        # Determine config path
        if config_path:
@@ -139,7 +90,7 @@ class OpenRouterModelRegistry:
            self.alias_map = {}
            self.model_map = {}

-   def _read_config(self) -> list[OpenRouterModelConfig]:
+   def _read_config(self) -> list[ModelCapabilities]:
        """Read configuration from file.

        Returns:
@@ -158,7 +109,27 @@ class OpenRouterModelRegistry:
            # Parse models
            configs = []
            for model_data in data.get("models", []):
-               config = OpenRouterModelConfig(**model_data)
+               # Create ModelCapabilities directly from JSON data
+               # Handle temperature_constraint conversion
+               temp_constraint_str = model_data.get("temperature_constraint")
+               temp_constraint = create_temperature_constraint(temp_constraint_str or "range")
+
+               # Set provider-specific defaults based on is_custom flag
+               is_custom = model_data.get("is_custom", False)
+               if is_custom:
+                   model_data.setdefault("provider", ProviderType.CUSTOM)
+                   model_data.setdefault("friendly_name", f"Custom ({model_data.get('model_name', 'Unknown')})")
+               else:
+                   model_data.setdefault("provider", ProviderType.OPENROUTER)
+                   model_data.setdefault("friendly_name", f"OpenRouter ({model_data.get('model_name', 'Unknown')})")
+
+               model_data["temperature_constraint"] = temp_constraint
+
+               # Remove the string version of temperature_constraint before creating ModelCapabilities
+               if "temperature_constraint" in model_data and isinstance(model_data["temperature_constraint"], str):
+                   del model_data["temperature_constraint"]
+               model_data["temperature_constraint"] = temp_constraint
+
+               config = ModelCapabilities(**model_data)
                configs.append(config)

            return configs
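
Note: a minimal sketch of the new parsing path in the hunk above, assuming a hypothetical "models" entry; field names are taken from the removed OpenRouterModelConfig dataclass, the model id "vendor/example-model" is invented for illustration, and it is assumed ModelCapabilities accepts the same JSON fields the registry passes through.

    from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint

    # Hypothetical entry from the registry's JSON "models" array
    model_data = {
        "model_name": "vendor/example-model",
        "aliases": ["example"],
        "context_window": 32768,
        "supports_function_calling": True,
        "temperature_constraint": "range",  # stored as a string in JSON
        "is_custom": False,
    }

    # Mirror what _read_config now does before unpacking into ModelCapabilities
    temp_constraint = create_temperature_constraint(model_data.get("temperature_constraint") or "range")
    model_data.setdefault("provider", ProviderType.OPENROUTER)
    model_data.setdefault("friendly_name", f"OpenRouter ({model_data['model_name']})")
    model_data["temperature_constraint"] = temp_constraint

    caps = ModelCapabilities(**model_data)
    assert caps.friendly_name == "OpenRouter (vendor/example-model)"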
@@ -168,7 +139,7 @@ class OpenRouterModelRegistry:
        except Exception as e:
            raise ValueError(f"Error reading config from {self.config_path}: {e}")

-   def _build_maps(self, configs: list[OpenRouterModelConfig]) -> None:
+   def _build_maps(self, configs: list[ModelCapabilities]) -> None:
        """Build alias and model maps from configurations.

        Args:
@@ -211,7 +182,7 @@ class OpenRouterModelRegistry:
        self.alias_map = alias_map
        self.model_map = model_map

-   def resolve(self, name_or_alias: str) -> Optional[OpenRouterModelConfig]:
+   def resolve(self, name_or_alias: str) -> Optional[ModelCapabilities]:
        """Resolve a model name or alias to configuration.

        Args:
@@ -237,10 +208,8 @@ class OpenRouterModelRegistry:
        Returns:
            ModelCapabilities if found, None otherwise
        """
-       config = self.resolve(name_or_alias)
-       if config:
-           return config.to_capabilities()
-       return None
+       # Registry now returns ModelCapabilities directly
+       return self.resolve(name_or_alias)

    def list_models(self) -> list[str]:
        """List all available model names."""

View File

@@ -57,7 +57,7 @@ class TestOpenRouterProvider:
        caps = provider.get_capabilities("o3")
        assert caps.provider == ProviderType.OPENROUTER
        assert caps.model_name == "openai/o3"  # Resolved name
-       assert caps.friendly_name == "OpenRouter"
+       assert caps.friendly_name == "OpenRouter (openai/o3)"

        # Test with a model not in registry - should get generic capabilities
        caps = provider.get_capabilities("unknown-model")

View File

@@ -6,8 +6,8 @@ import tempfile

 import pytest

-from providers.base import ProviderType
-from providers.openrouter_registry import OpenRouterModelConfig, OpenRouterModelRegistry
+from providers.base import ModelCapabilities, ProviderType
+from providers.openrouter_registry import OpenRouterModelRegistry


 class TestOpenRouterModelRegistry:
@@ -110,18 +110,18 @@ class TestOpenRouterModelRegistry:
        assert registry.resolve("non-existent") is None

    def test_model_capabilities_conversion(self):
-       """Test conversion to ModelCapabilities."""
+       """Test that registry returns ModelCapabilities directly."""
        registry = OpenRouterModelRegistry()

        config = registry.resolve("opus")
        assert config is not None

-       caps = config.to_capabilities()
-       assert caps.provider == ProviderType.OPENROUTER
-       assert caps.model_name == "anthropic/claude-opus-4"
-       assert caps.friendly_name == "OpenRouter"
-       assert caps.context_window == 200000
-       assert not caps.supports_extended_thinking
+       # Registry now returns ModelCapabilities objects directly
+       assert config.provider == ProviderType.OPENROUTER
+       assert config.model_name == "anthropic/claude-opus-4"
+       assert config.friendly_name == "OpenRouter (anthropic/claude-opus-4)"
+       assert config.context_window == 200000
+       assert not config.supports_extended_thinking

    def test_duplicate_alias_detection(self):
        """Test that duplicate aliases are detected."""
@@ -199,8 +199,12 @@ class TestOpenRouterModelRegistry:
    def test_model_with_all_capabilities(self):
        """Test model with all capability flags."""
-       config = OpenRouterModelConfig(
+       from providers.base import create_temperature_constraint
+
+       caps = ModelCapabilities(
+           provider=ProviderType.OPENROUTER,
            model_name="test/full-featured",
+           friendly_name="OpenRouter (test/full-featured)",
            aliases=["full"],
            context_window=128000,
            supports_extended_thinking=True,

@@ -209,9 +213,8 @@ class TestOpenRouterModelRegistry:
            supports_function_calling=True,
            supports_json_mode=True,
            description="Fully featured test model",
+           temperature_constraint=create_temperature_constraint("range"),
        )

-       caps = config.to_capabilities()
        assert caps.context_window == 128000
        assert caps.supports_extended_thinking
        assert caps.supports_system_prompts