Cleanup, use ModelCapabilities only
This commit is contained in:
@@ -291,7 +291,6 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
Returns:
|
||||
Dictionary mapping model names to their ModelCapabilities objects
|
||||
"""
|
||||
from .base import ProviderType
|
||||
|
||||
configs = {}
|
||||
|
||||
@@ -302,12 +301,8 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
if self.validate_model_name(model_name):
|
||||
config = self._registry.resolve(model_name)
|
||||
if config and config.is_custom:
|
||||
# Convert OpenRouterModelConfig to ModelCapabilities
|
||||
capabilities = config.to_capabilities()
|
||||
# Override provider type to CUSTOM for local models
|
||||
capabilities.provider = ProviderType.CUSTOM
|
||||
capabilities.friendly_name = f"{self.FRIENDLY_NAME} ({config.model_name})"
|
||||
configs[model_name] = capabilities
|
||||
# Use ModelCapabilities directly from registry
|
||||
configs[model_name] = config
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
@@ -288,12 +288,8 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
if self.validate_model_name(model_name):
|
||||
config = self._registry.resolve(model_name)
|
||||
if config and not config.is_custom: # Only OpenRouter models, not custom ones
|
||||
# Convert OpenRouterModelConfig to ModelCapabilities
|
||||
capabilities = config.to_capabilities()
|
||||
# Override provider type to OPENROUTER
|
||||
capabilities.provider = ProviderType.OPENROUTER
|
||||
capabilities.friendly_name = f"{self.FRIENDLY_NAME} ({config.model_name})"
|
||||
configs[model_name] = capabilities
|
||||
# Use ModelCapabilities directly from registry
|
||||
configs[model_name] = config
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -11,58 +10,10 @@ from utils.file_utils import read_json_file
|
||||
from .base import (
|
||||
ModelCapabilities,
|
||||
ProviderType,
|
||||
TemperatureConstraint,
|
||||
create_temperature_constraint,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenRouterModelConfig:
|
||||
"""Configuration for an OpenRouter model."""
|
||||
|
||||
model_name: str
|
||||
aliases: list[str] = field(default_factory=list)
|
||||
context_window: int = 32768 # Total context window size in tokens
|
||||
supports_extended_thinking: bool = False
|
||||
supports_system_prompts: bool = True
|
||||
supports_streaming: bool = True
|
||||
supports_function_calling: bool = False
|
||||
supports_json_mode: bool = False
|
||||
supports_images: bool = False # Whether model can process images
|
||||
max_image_size_mb: float = 0.0 # Maximum total size for all images in MB
|
||||
supports_temperature: bool = True # Whether model accepts temperature parameter in API calls
|
||||
temperature_constraint: Optional[str] = (
|
||||
None # Type of temperature constraint: "fixed", "range", "discrete", or None for default range
|
||||
)
|
||||
is_custom: bool = False # True for models that should only be used with custom endpoints
|
||||
description: str = ""
|
||||
|
||||
def _create_temperature_constraint(self) -> TemperatureConstraint:
|
||||
"""Create temperature constraint object from configuration.
|
||||
|
||||
Returns:
|
||||
TemperatureConstraint object based on configuration
|
||||
"""
|
||||
return create_temperature_constraint(self.temperature_constraint or "range")
|
||||
|
||||
def to_capabilities(self) -> ModelCapabilities:
|
||||
"""Convert to ModelCapabilities object."""
|
||||
return ModelCapabilities(
|
||||
provider=ProviderType.OPENROUTER,
|
||||
model_name=self.model_name,
|
||||
friendly_name="OpenRouter",
|
||||
context_window=self.context_window,
|
||||
supports_extended_thinking=self.supports_extended_thinking,
|
||||
supports_system_prompts=self.supports_system_prompts,
|
||||
supports_streaming=self.supports_streaming,
|
||||
supports_function_calling=self.supports_function_calling,
|
||||
supports_images=self.supports_images,
|
||||
max_image_size_mb=self.max_image_size_mb,
|
||||
supports_temperature=self.supports_temperature,
|
||||
temperature_constraint=self._create_temperature_constraint(),
|
||||
)
|
||||
|
||||
|
||||
class OpenRouterModelRegistry:
|
||||
"""Registry for managing OpenRouter model configurations and aliases."""
|
||||
|
||||
@@ -73,7 +24,7 @@ class OpenRouterModelRegistry:
|
||||
config_path: Path to config file. If None, uses default locations.
|
||||
"""
|
||||
self.alias_map: dict[str, str] = {} # alias -> model_name
|
||||
self.model_map: dict[str, OpenRouterModelConfig] = {} # model_name -> config
|
||||
self.model_map: dict[str, ModelCapabilities] = {} # model_name -> config
|
||||
|
||||
# Determine config path
|
||||
if config_path:
|
||||
@@ -139,7 +90,7 @@ class OpenRouterModelRegistry:
|
||||
self.alias_map = {}
|
||||
self.model_map = {}
|
||||
|
||||
def _read_config(self) -> list[OpenRouterModelConfig]:
|
||||
def _read_config(self) -> list[ModelCapabilities]:
|
||||
"""Read configuration from file.
|
||||
|
||||
Returns:
|
||||
@@ -158,7 +109,27 @@ class OpenRouterModelRegistry:
|
||||
# Parse models
|
||||
configs = []
|
||||
for model_data in data.get("models", []):
|
||||
config = OpenRouterModelConfig(**model_data)
|
||||
# Create ModelCapabilities directly from JSON data
|
||||
# Handle temperature_constraint conversion
|
||||
temp_constraint_str = model_data.get("temperature_constraint")
|
||||
temp_constraint = create_temperature_constraint(temp_constraint_str or "range")
|
||||
|
||||
# Set provider-specific defaults based on is_custom flag
|
||||
is_custom = model_data.get("is_custom", False)
|
||||
if is_custom:
|
||||
model_data.setdefault("provider", ProviderType.CUSTOM)
|
||||
model_data.setdefault("friendly_name", f"Custom ({model_data.get('model_name', 'Unknown')})")
|
||||
else:
|
||||
model_data.setdefault("provider", ProviderType.OPENROUTER)
|
||||
model_data.setdefault("friendly_name", f"OpenRouter ({model_data.get('model_name', 'Unknown')})")
|
||||
model_data["temperature_constraint"] = temp_constraint
|
||||
|
||||
# Remove the string version of temperature_constraint before creating ModelCapabilities
|
||||
if "temperature_constraint" in model_data and isinstance(model_data["temperature_constraint"], str):
|
||||
del model_data["temperature_constraint"]
|
||||
model_data["temperature_constraint"] = temp_constraint
|
||||
|
||||
config = ModelCapabilities(**model_data)
|
||||
configs.append(config)
|
||||
|
||||
return configs
|
||||
@@ -168,7 +139,7 @@ class OpenRouterModelRegistry:
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading config from {self.config_path}: {e}")
|
||||
|
||||
def _build_maps(self, configs: list[OpenRouterModelConfig]) -> None:
|
||||
def _build_maps(self, configs: list[ModelCapabilities]) -> None:
|
||||
"""Build alias and model maps from configurations.
|
||||
|
||||
Args:
|
||||
@@ -211,7 +182,7 @@ class OpenRouterModelRegistry:
|
||||
self.alias_map = alias_map
|
||||
self.model_map = model_map
|
||||
|
||||
def resolve(self, name_or_alias: str) -> Optional[OpenRouterModelConfig]:
|
||||
def resolve(self, name_or_alias: str) -> Optional[ModelCapabilities]:
|
||||
"""Resolve a model name or alias to configuration.
|
||||
|
||||
Args:
|
||||
@@ -237,10 +208,8 @@ class OpenRouterModelRegistry:
|
||||
Returns:
|
||||
ModelCapabilities if found, None otherwise
|
||||
"""
|
||||
config = self.resolve(name_or_alias)
|
||||
if config:
|
||||
return config.to_capabilities()
|
||||
return None
|
||||
# Registry now returns ModelCapabilities directly
|
||||
return self.resolve(name_or_alias)
|
||||
|
||||
def list_models(self) -> list[str]:
|
||||
"""List all available model names."""
|
||||
|
||||
@@ -57,7 +57,7 @@ class TestOpenRouterProvider:
|
||||
caps = provider.get_capabilities("o3")
|
||||
assert caps.provider == ProviderType.OPENROUTER
|
||||
assert caps.model_name == "openai/o3" # Resolved name
|
||||
assert caps.friendly_name == "OpenRouter"
|
||||
assert caps.friendly_name == "OpenRouter (openai/o3)"
|
||||
|
||||
# Test with a model not in registry - should get generic capabilities
|
||||
caps = provider.get_capabilities("unknown-model")
|
||||
|
||||
@@ -6,8 +6,8 @@ import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.openrouter_registry import OpenRouterModelConfig, OpenRouterModelRegistry
|
||||
from providers.base import ModelCapabilities, ProviderType
|
||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
||||
|
||||
|
||||
class TestOpenRouterModelRegistry:
|
||||
@@ -110,18 +110,18 @@ class TestOpenRouterModelRegistry:
|
||||
assert registry.resolve("non-existent") is None
|
||||
|
||||
def test_model_capabilities_conversion(self):
|
||||
"""Test conversion to ModelCapabilities."""
|
||||
"""Test that registry returns ModelCapabilities directly."""
|
||||
registry = OpenRouterModelRegistry()
|
||||
|
||||
config = registry.resolve("opus")
|
||||
assert config is not None
|
||||
|
||||
caps = config.to_capabilities()
|
||||
assert caps.provider == ProviderType.OPENROUTER
|
||||
assert caps.model_name == "anthropic/claude-opus-4"
|
||||
assert caps.friendly_name == "OpenRouter"
|
||||
assert caps.context_window == 200000
|
||||
assert not caps.supports_extended_thinking
|
||||
# Registry now returns ModelCapabilities objects directly
|
||||
assert config.provider == ProviderType.OPENROUTER
|
||||
assert config.model_name == "anthropic/claude-opus-4"
|
||||
assert config.friendly_name == "OpenRouter (anthropic/claude-opus-4)"
|
||||
assert config.context_window == 200000
|
||||
assert not config.supports_extended_thinking
|
||||
|
||||
def test_duplicate_alias_detection(self):
|
||||
"""Test that duplicate aliases are detected."""
|
||||
@@ -199,8 +199,12 @@ class TestOpenRouterModelRegistry:
|
||||
|
||||
def test_model_with_all_capabilities(self):
|
||||
"""Test model with all capability flags."""
|
||||
config = OpenRouterModelConfig(
|
||||
from providers.base import create_temperature_constraint
|
||||
|
||||
caps = ModelCapabilities(
|
||||
provider=ProviderType.OPENROUTER,
|
||||
model_name="test/full-featured",
|
||||
friendly_name="OpenRouter (test/full-featured)",
|
||||
aliases=["full"],
|
||||
context_window=128000,
|
||||
supports_extended_thinking=True,
|
||||
@@ -209,9 +213,8 @@ class TestOpenRouterModelRegistry:
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=True,
|
||||
description="Fully featured test model",
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
)
|
||||
|
||||
caps = config.to_capabilities()
|
||||
assert caps.context_window == 128000
|
||||
assert caps.supports_extended_thinking
|
||||
assert caps.supports_system_prompts
|
||||
|
||||
Reference in New Issue
Block a user