Cleanup, use ModelCapabilities only

This commit is contained in:
Fahad
2025-06-23 17:39:47 +04:00
parent 498ea88293
commit 14eaf930ed
5 changed files with 47 additions and 84 deletions

View File

@@ -291,7 +291,6 @@ class CustomProvider(OpenAICompatibleProvider):
Returns:
Dictionary mapping model names to their ModelCapabilities objects
"""
from .base import ProviderType
configs = {}
@@ -302,12 +301,8 @@ class CustomProvider(OpenAICompatibleProvider):
if self.validate_model_name(model_name):
config = self._registry.resolve(model_name)
if config and config.is_custom:
# Convert OpenRouterModelConfig to ModelCapabilities
capabilities = config.to_capabilities()
# Override provider type to CUSTOM for local models
capabilities.provider = ProviderType.CUSTOM
capabilities.friendly_name = f"{self.FRIENDLY_NAME} ({config.model_name})"
configs[model_name] = capabilities
# Use ModelCapabilities directly from registry
configs[model_name] = config
return configs

View File

@@ -288,12 +288,8 @@ class OpenRouterProvider(OpenAICompatibleProvider):
if self.validate_model_name(model_name):
config = self._registry.resolve(model_name)
if config and not config.is_custom: # Only OpenRouter models, not custom ones
# Convert OpenRouterModelConfig to ModelCapabilities
capabilities = config.to_capabilities()
# Override provider type to OPENROUTER
capabilities.provider = ProviderType.OPENROUTER
capabilities.friendly_name = f"{self.FRIENDLY_NAME} ({config.model_name})"
configs[model_name] = capabilities
# Use ModelCapabilities directly from registry
configs[model_name] = config
return configs

View File

@@ -2,7 +2,6 @@
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
@@ -11,58 +10,10 @@ from utils.file_utils import read_json_file
from .base import (
ModelCapabilities,
ProviderType,
TemperatureConstraint,
create_temperature_constraint,
)
@dataclass
class OpenRouterModelConfig:
"""Configuration for an OpenRouter model."""
model_name: str
aliases: list[str] = field(default_factory=list)
context_window: int = 32768 # Total context window size in tokens
supports_extended_thinking: bool = False
supports_system_prompts: bool = True
supports_streaming: bool = True
supports_function_calling: bool = False
supports_json_mode: bool = False
supports_images: bool = False # Whether model can process images
max_image_size_mb: float = 0.0 # Maximum total size for all images in MB
supports_temperature: bool = True # Whether model accepts temperature parameter in API calls
temperature_constraint: Optional[str] = (
None # Type of temperature constraint: "fixed", "range", "discrete", or None for default range
)
is_custom: bool = False # True for models that should only be used with custom endpoints
description: str = ""
def _create_temperature_constraint(self) -> TemperatureConstraint:
"""Create temperature constraint object from configuration.
Returns:
TemperatureConstraint object based on configuration
"""
return create_temperature_constraint(self.temperature_constraint or "range")
def to_capabilities(self) -> ModelCapabilities:
"""Convert to ModelCapabilities object."""
return ModelCapabilities(
provider=ProviderType.OPENROUTER,
model_name=self.model_name,
friendly_name="OpenRouter",
context_window=self.context_window,
supports_extended_thinking=self.supports_extended_thinking,
supports_system_prompts=self.supports_system_prompts,
supports_streaming=self.supports_streaming,
supports_function_calling=self.supports_function_calling,
supports_images=self.supports_images,
max_image_size_mb=self.max_image_size_mb,
supports_temperature=self.supports_temperature,
temperature_constraint=self._create_temperature_constraint(),
)
class OpenRouterModelRegistry:
"""Registry for managing OpenRouter model configurations and aliases."""
@@ -73,7 +24,7 @@ class OpenRouterModelRegistry:
config_path: Path to config file. If None, uses default locations.
"""
self.alias_map: dict[str, str] = {} # alias -> model_name
self.model_map: dict[str, OpenRouterModelConfig] = {} # model_name -> config
self.model_map: dict[str, ModelCapabilities] = {} # model_name -> config
# Determine config path
if config_path:
@@ -139,7 +90,7 @@ class OpenRouterModelRegistry:
self.alias_map = {}
self.model_map = {}
def _read_config(self) -> list[OpenRouterModelConfig]:
def _read_config(self) -> list[ModelCapabilities]:
"""Read configuration from file.
Returns:
@@ -158,7 +109,27 @@ class OpenRouterModelRegistry:
# Parse models
configs = []
for model_data in data.get("models", []):
config = OpenRouterModelConfig(**model_data)
# Create ModelCapabilities directly from JSON data
# Handle temperature_constraint conversion
temp_constraint_str = model_data.get("temperature_constraint")
temp_constraint = create_temperature_constraint(temp_constraint_str or "range")
# Set provider-specific defaults based on is_custom flag
is_custom = model_data.get("is_custom", False)
if is_custom:
model_data.setdefault("provider", ProviderType.CUSTOM)
model_data.setdefault("friendly_name", f"Custom ({model_data.get('model_name', 'Unknown')})")
else:
model_data.setdefault("provider", ProviderType.OPENROUTER)
model_data.setdefault("friendly_name", f"OpenRouter ({model_data.get('model_name', 'Unknown')})")
model_data["temperature_constraint"] = temp_constraint
# Remove the string version of temperature_constraint before creating ModelCapabilities
if "temperature_constraint" in model_data and isinstance(model_data["temperature_constraint"], str):
del model_data["temperature_constraint"]
model_data["temperature_constraint"] = temp_constraint
config = ModelCapabilities(**model_data)
configs.append(config)
return configs
@@ -168,7 +139,7 @@ class OpenRouterModelRegistry:
except Exception as e:
raise ValueError(f"Error reading config from {self.config_path}: {e}")
def _build_maps(self, configs: list[OpenRouterModelConfig]) -> None:
def _build_maps(self, configs: list[ModelCapabilities]) -> None:
"""Build alias and model maps from configurations.
Args:
@@ -211,7 +182,7 @@ class OpenRouterModelRegistry:
self.alias_map = alias_map
self.model_map = model_map
def resolve(self, name_or_alias: str) -> Optional[OpenRouterModelConfig]:
def resolve(self, name_or_alias: str) -> Optional[ModelCapabilities]:
"""Resolve a model name or alias to configuration.
Args:
@@ -237,10 +208,8 @@ class OpenRouterModelRegistry:
Returns:
ModelCapabilities if found, None otherwise
"""
config = self.resolve(name_or_alias)
if config:
return config.to_capabilities()
return None
# Registry now returns ModelCapabilities directly
return self.resolve(name_or_alias)
def list_models(self) -> list[str]:
"""List all available model names."""

View File

@@ -57,7 +57,7 @@ class TestOpenRouterProvider:
caps = provider.get_capabilities("o3")
assert caps.provider == ProviderType.OPENROUTER
assert caps.model_name == "openai/o3" # Resolved name
assert caps.friendly_name == "OpenRouter"
assert caps.friendly_name == "OpenRouter (openai/o3)"
# Test with a model not in registry - should get generic capabilities
caps = provider.get_capabilities("unknown-model")

View File

@@ -6,8 +6,8 @@ import tempfile
import pytest
from providers.base import ProviderType
from providers.openrouter_registry import OpenRouterModelConfig, OpenRouterModelRegistry
from providers.base import ModelCapabilities, ProviderType
from providers.openrouter_registry import OpenRouterModelRegistry
class TestOpenRouterModelRegistry:
@@ -110,18 +110,18 @@ class TestOpenRouterModelRegistry:
assert registry.resolve("non-existent") is None
def test_model_capabilities_conversion(self):
"""Test conversion to ModelCapabilities."""
"""Test that registry returns ModelCapabilities directly."""
registry = OpenRouterModelRegistry()
config = registry.resolve("opus")
assert config is not None
caps = config.to_capabilities()
assert caps.provider == ProviderType.OPENROUTER
assert caps.model_name == "anthropic/claude-opus-4"
assert caps.friendly_name == "OpenRouter"
assert caps.context_window == 200000
assert not caps.supports_extended_thinking
# Registry now returns ModelCapabilities objects directly
assert config.provider == ProviderType.OPENROUTER
assert config.model_name == "anthropic/claude-opus-4"
assert config.friendly_name == "OpenRouter (anthropic/claude-opus-4)"
assert config.context_window == 200000
assert not config.supports_extended_thinking
def test_duplicate_alias_detection(self):
"""Test that duplicate aliases are detected."""
@@ -199,8 +199,12 @@ class TestOpenRouterModelRegistry:
def test_model_with_all_capabilities(self):
"""Test model with all capability flags."""
config = OpenRouterModelConfig(
from providers.base import create_temperature_constraint
caps = ModelCapabilities(
provider=ProviderType.OPENROUTER,
model_name="test/full-featured",
friendly_name="OpenRouter (test/full-featured)",
aliases=["full"],
context_window=128000,
supports_extended_thinking=True,
@@ -209,9 +213,8 @@ class TestOpenRouterModelRegistry:
supports_function_calling=True,
supports_json_mode=True,
description="Fully featured test model",
temperature_constraint=create_temperature_constraint("range"),
)
caps = config.to_capabilities()
assert caps.context_window == 128000
assert caps.supports_extended_thinking
assert caps.supports_system_prompts