Breaking change: openrouter_models.json -> custom_models.json
* Support for custom URLs and custom models, including locally hosted models such as Ollama
* Support for native + OpenRouter + local models (i.e. dozens of models) means you can start delegating sub-tasks to particular models, or hand boring work such as localizations off to a local model (see the sketch below)
* Several tests added
* Precommit now also includes untracked (new) files
* Logfile auto-rollover
* Improved logging
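For illustration, a minimal sketch of wiring the new provider to a local Ollama instance. The endpoint URL and model name are illustrative assumptions, not part of this commit; CUSTOM_API_URL, CUSTOM_API_KEY, and CustomProvider are introduced in the diff below.

import os

# Ollama exposes an OpenAI-compatible API; no API key is needed, and the
# provider substitutes a dummy key internally for unauthenticated endpoints.
os.environ["CUSTOM_API_URL"] = "http://localhost:11434/v1"

from providers.custom import CustomProvider

provider = CustomProvider()  # falls back to the env vars set above
response = provider.generate_content(
    prompt="Translate this changelog entry to German.",
    model_name="llama3.2:latest",  # the ':latest' version tag is stripped during resolution
)
# response is a ModelResponse (see providers/custom.py below)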
providers/base.py

@@ -12,6 +12,7 @@ class ProviderType(Enum):
     GOOGLE = "google"
     OPENAI = "openai"
     OPENROUTER = "openrouter"
+    CUSTOM = "custom"


 class TemperatureConstraint(ABC):
providers/custom.py (new file, 273 lines)

@@ -0,0 +1,273 @@
"""Custom API provider implementation."""

import logging
import os
from typing import Optional

from .base import (
    ModelCapabilities,
    ModelResponse,
    ProviderType,
    RangeTemperatureConstraint,
)
from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry


class CustomProvider(OpenAICompatibleProvider):
    """Custom API provider for local models.

    Supports local inference servers like Ollama, vLLM, LM Studio,
    and any OpenAI-compatible API endpoint.
    """

    FRIENDLY_NAME = "Custom API"

    # Model registry for managing configurations and aliases (shared with OpenRouter)
    _registry: Optional[OpenRouterModelRegistry] = None

    def __init__(self, api_key: str = "", base_url: str = "", **kwargs):
        """Initialize Custom provider for local/self-hosted models.

        This provider supports any OpenAI-compatible API endpoint, including:
        - Ollama (typically no API key required)
        - vLLM (may require API key)
        - LM Studio (may require API key)
        - Text Generation WebUI (may require API key)
        - Enterprise/self-hosted APIs (typically require API key)

        Args:
            api_key: API key for the custom endpoint. Can be an empty string for
                providers that don't require authentication (like Ollama).
                Falls back to the CUSTOM_API_KEY environment variable if not provided.
            base_url: Base URL for the custom API endpoint (e.g., 'http://host.docker.internal:11434/v1').
                Falls back to the CUSTOM_API_URL environment variable if not provided.
            **kwargs: Additional configuration passed to the parent OpenAI-compatible provider

        Raises:
            ValueError: If no base_url is provided via parameter or environment variable
        """
        # Fall back to environment variables only if not provided
        if not base_url:
            base_url = os.getenv("CUSTOM_API_URL", "")
        if not api_key:
            api_key = os.getenv("CUSTOM_API_KEY", "")

        if not base_url:
            raise ValueError(
                "Custom API URL must be provided via base_url parameter or CUSTOM_API_URL environment variable"
            )

        # For Ollama and other providers that don't require authentication,
        # set a dummy API key to avoid OpenAI client header issues
        if not api_key:
            api_key = "dummy-key-for-unauthenticated-endpoint"
            logging.debug("Using dummy API key for unauthenticated custom endpoint")

        logging.info(f"Initializing Custom provider with endpoint: {base_url}")

        super().__init__(api_key, base_url=base_url, **kwargs)

        # Initialize model registry (shared with OpenRouter for consistent aliases)
        if CustomProvider._registry is None:
            CustomProvider._registry = OpenRouterModelRegistry()

            # Log loaded models and aliases
            models = self._registry.list_models()
            aliases = self._registry.list_aliases()
            logging.info(f"Custom provider loaded {len(models)} models with {len(aliases)} aliases")

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model aliases to actual model names.

        For Ollama-style models, strips version tags (e.g., 'llama3.2:latest' -> 'llama3.2')
        since the base model name is what's typically used in API calls.

        Args:
            model_name: Input model name or alias

        Returns:
            Resolved model name, with version tags stripped if applicable
        """
        # First, try to resolve through the registry as-is
        config = self._registry.resolve(model_name)

        if config:
            if config.model_name != model_name:
                logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
            return config.model_name
        else:
            # If not found in the registry, handle version tags for local models:
            # strip version tags (anything after ':') for Ollama-style models
            if ":" in model_name:
                base_model = model_name.split(":")[0]
                logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'")

                # Try to resolve the base model through the registry
                base_config = self._registry.resolve(base_model)
                if base_config:
                    logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
                    return base_config.model_name
                else:
                    return base_model
            else:
                # Not found in the registry and no version tag - return as-is
                logging.debug(f"Model '{model_name}' not found in registry, using as-is")
                return model_name

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a custom model.

        Args:
            model_name: Name of the model (or alias)

        Returns:
            ModelCapabilities from the registry, or generic defaults
        """
        # Try to get from the registry first
        capabilities = self._registry.get_capabilities(model_name)

        if capabilities:
            # Update provider type to CUSTOM
            capabilities.provider = ProviderType.CUSTOM
            return capabilities
        else:
            # Resolve any potential aliases and create generic capabilities
            resolved_name = self._resolve_model_name(model_name)

            logging.debug(
                f"Using generic capabilities for '{resolved_name}' via Custom API. "
                "Consider adding it to custom_models.json for specific capabilities."
            )

            # Create generic capabilities with conservative defaults
            capabilities = ModelCapabilities(
                provider=ProviderType.CUSTOM,
                model_name=resolved_name,
                friendly_name=f"{self.FRIENDLY_NAME} ({resolved_name})",
                context_window=32_768,  # Conservative default
                supports_extended_thinking=False,  # Most custom models don't support this
                supports_system_prompts=True,
                supports_streaming=True,
                supports_function_calling=False,  # Conservative default
                temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
            )

            # Mark as generic for validation purposes
            capabilities._is_generic = True

            return capabilities

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.CUSTOM

    def validate_model_name(self, model_name: str) -> bool:
        """Validate whether the model name is allowed.

        For custom endpoints, only accept models that are explicitly intended for
        local/custom usage. This provider should NOT handle OpenRouter or cloud models.

        Args:
            model_name: Model name to validate

        Returns:
            True if the model is intended for the custom/local endpoint
        """
        logging.debug(f"Custom provider validating model: '{model_name}'")

        # Try to resolve through the registry first
        config = self._registry.resolve(model_name)
        if config:
            model_id = config.model_name
            # Only accept models that are clearly local/custom based on the resolved name.
            # Local models should not have a vendor/ prefix (except for special cases).
            is_local_model = (
                "/" not in model_id  # Simple names like "llama3.2"
                or "local" in model_id.lower()  # Explicit local indicator
                # Or any of the aliases carry a local indicator
                or (
                    any("local" in alias.lower() or "ollama" in alias.lower() for alias in config.aliases)
                    if hasattr(config, "aliases")
                    else False
                )
            )

            if is_local_model:
                logging.debug(f"Model '{model_name}' -> '{model_id}' validated via registry (local model)")
                return True
            else:
                # This is a cloud/OpenRouter model - reject it for the custom provider
                logging.debug(f"Model '{model_name}' -> '{model_id}' rejected (cloud model for OpenRouter)")
                return False

        # Strip the :latest suffix and try validation again (it's just a version tag)
        clean_model_name = model_name
        if model_name.endswith(":latest"):
            clean_model_name = model_name[:-7]  # Remove ":latest"
            logging.debug(f"Stripped :latest from '{model_name}' -> '{clean_model_name}'")
            # Try to resolve the clean name
            config = self._registry.resolve(clean_model_name)
            if config:
                return self.validate_model_name(clean_model_name)  # Recursively validate the clean name

        # Accept models with explicit local indicators in the name
        if any(indicator in clean_model_name.lower() for indicator in ["local", "ollama", "vllm", "lmstudio"]):
            logging.debug(f"Model '{clean_model_name}' validated via local indicators")
            return True

        # Accept simple model names without a vendor prefix ONLY if they're not in the registry.
        # This allows unknown local models such as custom fine-tunes.
        if "/" not in clean_model_name and ":" not in clean_model_name and not config:
            logging.debug(f"Model '{clean_model_name}' validated via simple name pattern (unknown local model)")
            return True

        logging.debug(f"Model '{model_name}' NOT validated by custom provider")
        return False

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using the custom API.

        Args:
            prompt: User prompt to send to the model
            model_name: Name of the model to use
            system_prompt: Optional system prompt for model behavior
            temperature: Sampling temperature
            max_output_tokens: Maximum tokens to generate
            **kwargs: Additional provider-specific parameters

        Returns:
            ModelResponse with generated content and metadata
        """
        # Resolve the model alias to the actual model name
        resolved_model = self._resolve_model_name(model_name)

        # Call the parent method with the resolved model name
        return super().generate_content(
            prompt=prompt,
            model_name=resolved_model,
            system_prompt=system_prompt,
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            **kwargs,
        )

    def supports_thinking_mode(self, model_name: str) -> bool:
        """Check whether the model supports extended thinking mode.

        Most custom/local models don't support extended thinking.

        Args:
            model_name: Model to check

        Returns:
            False (custom models generally don't support thinking mode)
        """
        return False
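The resolution and validation rules above combine as in this usage sketch; the model names and endpoint are illustrative, and _resolve_model_name is a private helper shown only to demonstrate the behavior.

provider = CustomProvider(base_url="http://localhost:11434/v1")

# Version tags are stripped when the tagged name is not in the registry:
provider._resolve_model_name("llama3.2:latest")  # expected: "llama3.2" (or its registry mapping)

# Simple, prefix-free names are accepted as unknown local models:
provider.validate_model_name("my-custom-finetune")  # expected: True

# Vendor-prefixed cloud names are rejected and left for OpenRouter:
provider.validate_model_name("anthropic/claude-3-opus")  # expected: False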
providers/openai_compatible.py

@@ -3,7 +3,6 @@
 import ipaddress
 import logging
 import os
 import socket
 from abc import abstractmethod
 from typing import Optional
 from urllib.parse import urlparse
@@ -36,7 +35,7 @@ class OpenAICompatibleProvider(ModelProvider):
         Args:
             api_key: API key for authentication
             base_url: Base URL for the API endpoint
-            **kwargs: Additional configuration options
+            **kwargs: Additional configuration options including timeout
         """
         super().__init__(api_key, **kwargs)
         self._client = None
@@ -44,6 +43,9 @@ class OpenAICompatibleProvider(ModelProvider):
         self.organization = kwargs.get("organization")
         self.allowed_models = self._parse_allowed_models()

+        # Configure timeouts - especially important for custom/local endpoints
+        self.timeout_config = self._configure_timeouts(**kwargs)
+
         # Validate base URL for security
         if self.base_url:
             self._validate_base_url()
@@ -82,11 +84,59 @@ class OpenAICompatibleProvider(ModelProvider):

         return None

-    def _is_localhost_url(self) -> bool:
-        """Check if the base URL points to localhost.
+    def _configure_timeouts(self, **kwargs):
+        """Configure timeout settings based on provider type and custom settings.
+
+        Custom URLs and local models often need longer timeouts due to:
+        - Network latency on local networks
+        - Extended thinking models taking longer to respond
+        - Local inference being slower than cloud APIs

         Returns:
-            True if URL is localhost, False otherwise
+            httpx.Timeout object with appropriate timeout settings
         """
+        import httpx
+
+        # Default timeouts - more generous for custom/local endpoints
+        default_connect = 30.0  # 30 seconds for connection (vs OpenAI's 5s)
+        default_read = 600.0  # 10 minutes for reading (same as OpenAI default)
+        default_write = 600.0  # 10 minutes for writing
+        default_pool = 600.0  # 10 minutes for pool
+
+        # For custom/local URLs, use even longer timeouts
+        if self.base_url and self._is_localhost_url():
+            default_connect = 60.0  # 1 minute for local connections
+            default_read = 1800.0  # 30 minutes for local models (extended thinking)
+            default_write = 1800.0  # 30 minutes for local models
+            default_pool = 1800.0  # 30 minutes for local models
+            logging.info(f"Using extended timeouts for local endpoint: {self.base_url}")
+        elif self.base_url:
+            default_connect = 45.0  # 45 seconds for custom remote endpoints
+            default_read = 900.0  # 15 minutes for custom remote endpoints
+            default_write = 900.0  # 15 minutes for custom remote endpoints
+            default_pool = 900.0  # 15 minutes for custom remote endpoints
+            logging.info(f"Using extended timeouts for custom endpoint: {self.base_url}")
+
+        # Allow overrides via kwargs or environment variables
+        connect_timeout = kwargs.get("connect_timeout", float(os.getenv("CUSTOM_CONNECT_TIMEOUT", default_connect)))
+        read_timeout = kwargs.get("read_timeout", float(os.getenv("CUSTOM_READ_TIMEOUT", default_read)))
+        write_timeout = kwargs.get("write_timeout", float(os.getenv("CUSTOM_WRITE_TIMEOUT", default_write)))
+        pool_timeout = kwargs.get("pool_timeout", float(os.getenv("CUSTOM_POOL_TIMEOUT", default_pool)))
+
+        timeout = httpx.Timeout(connect=connect_timeout, read=read_timeout, write=write_timeout, pool=pool_timeout)
+
+        logging.debug(
+            f"Configured timeouts - Connect: {connect_timeout}s, Read: {read_timeout}s, "
+            f"Write: {write_timeout}s, Pool: {pool_timeout}s"
+        )
+
+        return timeout
+
+    def _is_localhost_url(self) -> bool:
+        """Check if the base URL points to localhost or a local network.
+
+        Returns:
+            True if the URL is localhost or a local network address, False otherwise
+        """
         if not self.base_url:
             return False
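If the defaults above don't fit a particularly slow local model, the environment-variable overrides read in _configure_timeouts can be set before the provider is created. A sketch; the values are illustrative:

import os

# Give a slow local model an hour to stream a long response.
# These variables are read by _configure_timeouts() above.
os.environ["CUSTOM_CONNECT_TIMEOUT"] = "120"  # seconds
os.environ["CUSTOM_READ_TIMEOUT"] = "3600"
os.environ["CUSTOM_WRITE_TIMEOUT"] = "3600"
os.environ["CUSTOM_POOL_TIMEOUT"] = "3600"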
@@ -99,6 +149,19 @@ class OpenAICompatibleProvider(ModelProvider):
             if hostname in ["localhost", "127.0.0.1", "::1"]:
                 return True

+            # Check for Docker internal hostnames (like host.docker.internal)
+            if hostname and ("docker.internal" in hostname or "host.docker.internal" in hostname):
+                return True
+
+            # Check for private network ranges (local network)
+            if hostname:
+                try:
+                    ip = ipaddress.ip_address(hostname)
+                    return ip.is_private or ip.is_loopback
+                except ValueError:
+                    # Not an IP address, might be a hostname
+                    pass
+
             return False
         except Exception:
             return False
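The ipaddress check just added is what makes RFC 1918 addresses count as "local"; a standalone sketch of the same classification:

import ipaddress

for host in ["127.0.0.1", "192.168.1.50", "10.0.0.7", "8.8.8.8"]:
    ip = ipaddress.ip_address(host)
    print(host, ip.is_private or ip.is_loopback)
# 127.0.0.1 True / 192.168.1.50 True / 10.0.0.7 True / 8.8.8.8 False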
@@ -123,64 +186,10 @@ class OpenAICompatibleProvider(ModelProvider):
             if not parsed.hostname:
                 raise ValueError("URL must include a hostname")

-            # Check port - allow only standard HTTP/HTTPS ports
+            # Check port is valid (if specified)
             port = parsed.port
-            if port is None:
-                port = 443 if parsed.scheme == "https" else 80
-
-            # Allow common HTTP ports and some alternative ports
-            allowed_ports = {80, 443, 8080, 8443, 4000, 3000}  # Common API ports
-            if port not in allowed_ports:
-                raise ValueError(f"Port {port} not allowed. Allowed ports: {sorted(allowed_ports)}")
-
-            # Check against allowed domains if configured
-            allowed_domains = os.getenv("ALLOWED_BASE_DOMAINS", "").split(",")
-            allowed_domains = [d.strip().lower() for d in allowed_domains if d.strip()]
-
-            if allowed_domains:
-                hostname_lower = parsed.hostname.lower()
-                if not any(
-                    hostname_lower == domain or hostname_lower.endswith("." + domain) for domain in allowed_domains
-                ):
-                    raise ValueError(
-                        f"Domain not in allow-list: {parsed.hostname}. " f"Allowed domains: {allowed_domains}"
-                    )
-
-            # Try to resolve hostname and check if it's a private IP
-            # Skip for localhost addresses which are commonly used for development
-            if parsed.hostname not in ["localhost", "127.0.0.1", "::1"]:
-                try:
-                    # Get all IP addresses for the hostname
-                    addr_info = socket.getaddrinfo(parsed.hostname, port, proto=socket.IPPROTO_TCP)
-
-                    for _family, _, _, _, sockaddr in addr_info:
-                        ip_str = sockaddr[0]
-                        try:
-                            ip = ipaddress.ip_address(ip_str)
-
-                            # Check for dangerous IP ranges
-                            if (
-                                ip.is_private
-                                or ip.is_loopback
-                                or ip.is_link_local
-                                or ip.is_multicast
-                                or ip.is_reserved
-                                or ip.is_unspecified
-                            ):
-                                raise ValueError(
-                                    f"URL resolves to restricted IP address: {ip_str}. "
-                                    "This could be a security risk (SSRF)."
-                                )
-                        except ValueError as ve:
-                            # Invalid IP address format or restricted IP - re-raise if it's our security error
-                            if "restricted IP address" in str(ve):
-                                raise
-                            continue
-
-                except socket.gaierror as e:
-                    # If we can't resolve the hostname, it's suspicious
-                    raise ValueError(f"Cannot resolve hostname '{parsed.hostname}': {e}")
+            if port is not None and (port < 1 or port > 65535):
+                raise ValueError(f"Invalid port number: {port}. Must be between 1 and 65535.")
         except Exception as e:
             if isinstance(e, ValueError):
                 raise
@@ -188,7 +197,7 @@ class OpenAICompatibleProvider(ModelProvider):

     @property
     def client(self):
-        """Lazy initialization of OpenAI client with security checks."""
+        """Lazy initialization of OpenAI client with security checks and timeout configuration."""
         if self._client is None:
             client_kwargs = {
                 "api_key": self.api_key,
@@ -204,6 +213,11 @@ class OpenAICompatibleProvider(ModelProvider):
             if self.DEFAULT_HEADERS:
                 client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()

+            # Add configured timeout settings
+            if hasattr(self, "timeout_config") and self.timeout_config:
+                client_kwargs["timeout"] = self.timeout_config
+                logging.debug(f"OpenAI client initialized with custom timeout: {self.timeout_config}")
+
             self._client = OpenAI(**client_kwargs)

         return self._client
providers/openrouter.py

@@ -39,8 +39,8 @@ class OpenRouterProvider(OpenAICompatibleProvider):
             api_key: OpenRouter API key
             **kwargs: Additional configuration
         """
-        # Always use OpenRouter's base URL
-        super().__init__(api_key, base_url="https://openrouter.ai/api/v1", **kwargs)
+        base_url = "https://openrouter.ai/api/v1"
+        super().__init__(api_key, base_url=base_url, **kwargs)

         # Initialize model registry
         if OpenRouterProvider._registry is None:
@@ -101,7 +101,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):

         logging.debug(
             f"Using generic capabilities for '{resolved_name}' via OpenRouter. "
-            "Consider adding to openrouter_models.json for specific capabilities."
+            "Consider adding to custom_models.json for specific capabilities."
         )

         # Create generic capabilities with conservative defaults
@@ -129,16 +129,18 @@ class OpenRouterProvider(OpenAICompatibleProvider):
     def validate_model_name(self, model_name: str) -> bool:
         """Validate if the model name is allowed.

-        For OpenRouter, we accept any model name. OpenRouter will
-        validate based on the API key's permissions.
+        As the catch-all provider, OpenRouter accepts any model name that wasn't
+        handled by higher-priority providers. OpenRouter will validate based on
+        the API key's permissions.

         Args:
             model_name: Model name to validate

         Returns:
-            Always True - OpenRouter handles validation
+            Always True - OpenRouter is the catch-all provider
         """
-        # Accept any model name - OpenRouter will validate based on API key permissions
+        # Accept any model name - OpenRouter is the fallback provider
+        # Higher priority providers (native APIs, custom endpoints) get first chance
         return True

     def generate_content(
providers/openrouter_registry.py

@@ -7,6 +7,8 @@ from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Optional

+from utils.file_utils import translate_path_for_environment
+
 from .base import ModelCapabilities, ProviderType, RangeTemperatureConstraint

@@ -53,15 +55,19 @@ class OpenRouterModelRegistry:

         # Determine config path
         if config_path:
-            self.config_path = Path(config_path)
+            # Direct config_path parameter - translate for Docker if needed
+            translated_path = translate_path_for_environment(config_path)
+            self.config_path = Path(translated_path)
         else:
             # Check environment variable first
-            env_path = os.getenv("OPENROUTER_MODELS_PATH")
+            env_path = os.getenv("CUSTOM_MODELS_CONFIG_PATH")
             if env_path:
-                self.config_path = Path(env_path)
+                # Environment variable path - translate for Docker if needed
+                translated_path = translate_path_for_environment(env_path)
+                self.config_path = Path(translated_path)
             else:
-                # Default to conf/openrouter_models.json
-                self.config_path = Path(__file__).parent.parent / "conf" / "openrouter_models.json"
+                # Default to conf/custom_models.json (already in container)
+                self.config_path = Path(__file__).parent.parent / "conf" / "custom_models.json"

         # Load configuration
         self.reload()
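Since the registry path is now controlled by CUSTOM_MODELS_CONFIG_PATH, overriding the config location looks roughly like this; the path is an illustrative assumption:

import os

# Point the shared registry at a custom config before it is constructed.
os.environ["CUSTOM_MODELS_CONFIG_PATH"] = "/opt/zen/conf/custom_models.json"

from providers.openrouter_registry import OpenRouterModelRegistry

registry = OpenRouterModelRegistry()  # picks up the env var, translating the path for Docker if needed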
@@ -125,6 +131,22 @@ class OpenRouterModelRegistry:
                 # Add to model map
                 model_map[config.model_name] = config

+                # Add the model_name itself as an alias for case-insensitive lookup,
+                # but only if it's not already in the aliases list
+                model_name_lower = config.model_name.lower()
+                aliases_lower = [alias.lower() for alias in config.aliases]
+
+                if model_name_lower not in aliases_lower:
+                    if model_name_lower in alias_map:
+                        existing_model = alias_map[model_name_lower]
+                        if existing_model != config.model_name:
+                            raise ValueError(
+                                f"Duplicate model name '{config.model_name}' (case-insensitive) found for models "
+                                f"'{existing_model}' and '{config.model_name}'"
+                            )
+                    else:
+                        alias_map[model_name_lower] = config.model_name
+
                 # Add aliases
                 for alias in config.aliases:
                     alias_lower = alias.lower()
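The effect of registering the model name as its own alias is that lookups become case-insensitive in a single pass. A minimal sketch of the resulting maps; the entries are illustrative assumptions:

# Illustrative shape of the registry maps after reload():
config = {"model_name": "meta-llama/llama-3.2", "aliases": ["llama"]}

model_map = {"meta-llama/llama-3.2": config}         # canonical name -> config
alias_map = {
    "llama": "meta-llama/llama-3.2",                 # explicit alias from custom_models.json
    "meta-llama/llama-3.2": "meta-llama/llama-3.2",  # model name doubles as its own alias
}

# resolve() now needs only one case-insensitive dict lookup:
looked_up = alias_map.get("META-LLAMA/LLAMA-3.2".lower())
print(model_map.get(looked_up))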
@@ -148,14 +170,13 @@ class OpenRouterModelRegistry:
         Returns:
             Model configuration if found, None otherwise
         """
-        # Try alias first (case-insensitive)
+        # Try alias lookup (case-insensitive) - this now includes model names too
         alias_lower = name_or_alias.lower()
         if alias_lower in self.alias_map:
             model_name = self.alias_map[alias_lower]
             return self.model_map.get(model_name)

-        # Try as direct model name
-        return self.model_map.get(name_or_alias)
+        return None

     def get_capabilities(self, name_or_alias: str) -> Optional[ModelCapabilities]:
         """Get model capabilities for a name or alias.
providers/registry.py

@@ -1,5 +1,6 @@
 """Model provider registry for managing available providers."""

+import logging
 import os
 from typing import Optional

@@ -10,13 +11,18 @@ class ModelProviderRegistry:
     """Registry for managing model providers."""

+    _instance = None
     _providers: dict[ProviderType, type[ModelProvider]] = {}
     _initialized_providers: dict[ProviderType, ModelProvider] = {}

+    def __new__(cls):
+        """Singleton pattern for registry."""
+        if cls._instance is None:
+            logging.debug("REGISTRY: Creating new registry instance")
+            cls._instance = super().__new__(cls)
+            # Initialize instance dictionaries on first creation
+            cls._instance._providers = {}
+            cls._instance._initialized_providers = {}
+            logging.debug(f"REGISTRY: Created instance {cls._instance}")
+        else:
+            logging.debug(f"REGISTRY: Returning existing instance {cls._instance}")
+        return cls._instance
+
     @classmethod
@@ -27,7 +33,8 @@ class ModelProviderRegistry:
             provider_type: Type of the provider (e.g., ProviderType.GOOGLE)
             provider_class: Class that implements ModelProvider interface
         """
-        cls._providers[provider_type] = provider_class
+        instance = cls()
+        instance._providers[provider_type] = provider_class

     @classmethod
     def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]:
@@ -40,25 +47,48 @@ class ModelProviderRegistry:
         Returns:
             Initialized ModelProvider instance or None if not available
         """
+        instance = cls()
+
         # Return cached instance if available and not forcing new
-        if not force_new and provider_type in cls._initialized_providers:
-            return cls._initialized_providers[provider_type]
+        if not force_new and provider_type in instance._initialized_providers:
+            return instance._initialized_providers[provider_type]

         # Check if provider class is registered
-        if provider_type not in cls._providers:
+        if provider_type not in instance._providers:
             return None

         # Get API key from environment
         api_key = cls._get_api_key_for_provider(provider_type)
-        if not api_key:
-            return None
-
-        # Initialize provider
-        provider_class = cls._providers[provider_type]
-        provider = provider_class(api_key=api_key)
+
+        # Get provider class or factory function
+        provider_class = instance._providers[provider_type]
+
+        # For custom providers, handle special initialization requirements
+        if provider_type == ProviderType.CUSTOM:
+            # Check if it's a factory function (callable but not a class)
+            if callable(provider_class) and not isinstance(provider_class, type):
+                # Factory function - call it with api_key parameter
+                provider = provider_class(api_key=api_key)
+            else:
+                # Regular class - need to handle URL requirement
+                custom_url = os.getenv("CUSTOM_API_URL", "")
+                if not custom_url:
+                    if api_key:  # Key is set but URL is missing
+                        logging.warning("CUSTOM_API_KEY set but CUSTOM_API_URL missing - skipping Custom provider")
+                    return None
+                # Use empty string as API key for custom providers that don't need auth (e.g., Ollama).
+                # This allows the provider to be created even without CUSTOM_API_KEY being set.
+                api_key = api_key or ""
+                # Initialize custom provider with both API key and base URL
+                provider = provider_class(api_key=api_key, base_url=custom_url)
+        else:
+            if not api_key:
+                return None
+            # Initialize non-custom provider with just API key
+            provider = provider_class(api_key=api_key)

         # Cache the instance
-        cls._initialized_providers[provider_type] = provider
+        instance._initialized_providers[provider_type] = provider

         return provider
@@ -66,25 +96,55 @@ class ModelProviderRegistry:
     def get_provider_for_model(cls, model_name: str) -> Optional[ModelProvider]:
         """Get provider instance for a specific model name.

+        Provider priority order:
+        1. Native APIs (GOOGLE, OPENAI) - Most direct and efficient
+        2. CUSTOM - For local/private models with specific endpoints
+        3. OPENROUTER - Catch-all for cloud models via unified API
+
         Args:
             model_name: Name of the model (e.g., "gemini-2.5-flash-preview-05-20", "o3-mini")

         Returns:
             ModelProvider instance that supports this model
         """
-        # Check each registered provider
-        for provider_type, _provider_class in cls._providers.items():
-            # Get or create provider instance
-            provider = cls.get_provider(provider_type)
-            if provider and provider.validate_model_name(model_name):
-                return provider
+        logging.debug(f"get_provider_for_model called with model_name='{model_name}'")
+
+        # Define explicit provider priority order:
+        # native APIs first, then custom endpoints, then catch-all providers
+        PROVIDER_PRIORITY_ORDER = [
+            ProviderType.GOOGLE,  # Direct Gemini access
+            ProviderType.OPENAI,  # Direct OpenAI access
+            ProviderType.CUSTOM,  # Local/self-hosted models
+            ProviderType.OPENROUTER,  # Catch-all for cloud models
+        ]
+
+        # Check providers in priority order
+        instance = cls()
+        logging.debug(f"Registry instance: {instance}")
+        logging.debug(f"Available providers in registry: {list(instance._providers.keys())}")
+
+        for provider_type in PROVIDER_PRIORITY_ORDER:
+            logging.debug(f"Checking provider_type: {provider_type}")
+            if provider_type in instance._providers:
+                logging.debug(f"Found {provider_type} in registry")
+                # Get or create provider instance
+                provider = cls.get_provider(provider_type)
+                if provider and provider.validate_model_name(model_name):
+                    logging.debug(f"{provider_type} validates model {model_name}")
+                    return provider
+                else:
+                    logging.debug(f"{provider_type} does not validate model {model_name}")
+            else:
+                logging.debug(f"{provider_type} not found in registry")
+
+        logging.debug(f"No provider found for model {model_name}")
         return None

     @classmethod
     def get_available_providers(cls) -> list[ProviderType]:
         """Get list of registered provider types."""
-        return list(cls._providers.keys())
+        instance = cls()
+        return list(instance._providers.keys())

     @classmethod
     def get_available_models(cls) -> dict[str, ProviderType]:
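With the explicit priority list, the registry should route simple local names to the custom provider and vendor-prefixed cloud names to OpenRouter. A hedged sketch of the expected routing, assuming all four providers are registered and configured; the import path and the third model name are assumptions:

from providers.registry import ModelProviderRegistry

ModelProviderRegistry.get_provider_for_model("o3-mini")           # expected: OpenAI provider (native API wins)
ModelProviderRegistry.get_provider_for_model("llama3.2")          # expected: CustomProvider (simple local name)
ModelProviderRegistry.get_provider_for_model("mistral/devstral")  # expected: OpenRouter (catch-all)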
@@ -94,8 +154,9 @@ class ModelProviderRegistry:
             Dict mapping model names to provider types
         """
         models = {}
+        instance = cls()

-        for provider_type in cls._providers:
+        for provider_type in instance._providers:
             provider = cls.get_provider(provider_type)
             if provider:
                 # This assumes providers have a method to list supported models
@@ -118,6 +179,7 @@ class ModelProviderRegistry:
             ProviderType.GOOGLE: "GEMINI_API_KEY",
             ProviderType.OPENAI: "OPENAI_API_KEY",
             ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
+            ProviderType.CUSTOM: "CUSTOM_API_KEY",  # Can be empty for providers that don't need auth
         }

         env_var = key_mapping.get(provider_type)
@@ -165,7 +227,8 @@ class ModelProviderRegistry:
             List of ProviderType values for providers with valid API keys
         """
         available = []
-        for provider_type in cls._providers:
+        instance = cls()
+        for provider_type in instance._providers:
             if cls.get_provider(provider_type) is not None:
                 available.append(provider_type)
         return available
@@ -173,10 +236,12 @@ class ModelProviderRegistry:
     @classmethod
     def clear_cache(cls) -> None:
         """Clear cached provider instances."""
-        cls._initialized_providers.clear()
+        instance = cls()
+        instance._initialized_providers.clear()

     @classmethod
     def unregister_provider(cls, provider_type: ProviderType) -> None:
         """Unregister a provider (mainly for testing)."""
-        cls._providers.pop(provider_type, None)
-        cls._initialized_providers.pop(provider_type, None)
+        instance = cls()
+        instance._providers.pop(provider_type, None)
+        instance._initialized_providers.pop(provider_type, None)