Breaking change: openrouter_models.json -> custom_models.json
* Support for custom URLs and custom models, including locally hosted models such as Ollama
* Support for native + OpenRouter + local models (i.e. dozens of models) means you can start delegating sub-tasks to particular models, or hand off boring work such as localizations to local models
* Several tests added
* Pre-commit now also includes untracked (new) files
* Log file auto-rollover
* Improved logging
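As a minimal sketch of the new workflow (the import path and constructor signature are taken from the diff below; the endpoint URL and model name are assumptions for illustration):

    from providers.custom import CustomProvider

    # Ollama typically requires no API key; the provider substitutes a dummy key internally.
    provider = CustomProvider(base_url="http://localhost:11434/v1")
    response = provider.generate_content(
        prompt="Localize this UI string to German: 'Save file'",
        model_name="llama3.2:latest",  # the ':latest' version tag is stripped during resolution
    )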
providers/custom.py (new file, 273 lines)
@@ -0,0 +1,273 @@
"""Custom API provider implementation."""

import logging
import os
from typing import Optional

from .base import (
    ModelCapabilities,
    ModelResponse,
    ProviderType,
    RangeTemperatureConstraint,
)
from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry

class CustomProvider(OpenAICompatibleProvider):
    """Custom API provider for local models.

    Supports local inference servers like Ollama, vLLM, LM Studio,
    and any OpenAI-compatible API endpoint.
    """

    FRIENDLY_NAME = "Custom API"

    # Model registry for managing configurations and aliases (shared with OpenRouter)
    _registry: Optional[OpenRouterModelRegistry] = None

    def __init__(self, api_key: str = "", base_url: str = "", **kwargs):
        """Initialize Custom provider for local/self-hosted models.

        This provider supports any OpenAI-compatible API endpoint, including:
        - Ollama (typically no API key required)
        - vLLM (may require an API key)
        - LM Studio (may require an API key)
        - Text Generation WebUI (may require an API key)
        - Enterprise/self-hosted APIs (typically require an API key)

        Args:
            api_key: API key for the custom endpoint. Can be an empty string for
                providers that don't require authentication (like Ollama).
                Falls back to the CUSTOM_API_KEY environment variable if not provided.
            base_url: Base URL for the custom API endpoint (e.g., 'http://host.docker.internal:11434/v1').
                Falls back to the CUSTOM_API_URL environment variable if not provided.
            **kwargs: Additional configuration passed to the parent OpenAI-compatible provider

        Raises:
            ValueError: If no base_url is provided via parameter or environment variable
        """
        # Fall back to environment variables only if not provided
        if not base_url:
            base_url = os.getenv("CUSTOM_API_URL", "")
        if not api_key:
            api_key = os.getenv("CUSTOM_API_KEY", "")

        if not base_url:
            raise ValueError(
                "Custom API URL must be provided via base_url parameter or CUSTOM_API_URL environment variable"
            )

        # For Ollama and other providers that don't require authentication,
        # set a dummy API key to avoid OpenAI client header issues
        if not api_key:
            api_key = "dummy-key-for-unauthenticated-endpoint"
            logging.debug("Using dummy API key for unauthenticated custom endpoint")

        logging.info(f"Initializing Custom provider with endpoint: {base_url}")

        super().__init__(api_key, base_url=base_url, **kwargs)

        # Initialize model registry (shared with OpenRouter for consistent aliases)
        if CustomProvider._registry is None:
            CustomProvider._registry = OpenRouterModelRegistry()

        # Log loaded models and aliases
        models = self._registry.list_models()
        aliases = self._registry.list_aliases()
        logging.info(f"Custom provider loaded {len(models)} models with {len(aliases)} aliases")

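A sketch of the fallback behavior implemented above (hypothetical endpoint value; CUSTOM_API_URL and CUSTOM_API_KEY are the variables the code actually reads):

    import os

    # Environment variables apply only when no explicit arguments are passed.
    os.environ["CUSTOM_API_URL"] = "http://localhost:11434/v1"
    provider = CustomProvider()  # picks up CUSTOM_API_URL

    # With neither an argument nor an environment variable, construction fails fast.
    del os.environ["CUSTOM_API_URL"]
    try:
        CustomProvider()
    except ValueError as exc:
        print(exc)  # "Custom API URL must be provided via base_url parameter or ..."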
    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model aliases to actual model names.

        For Ollama-style models, strips version tags (e.g., 'llama3.2:latest' -> 'llama3.2'),
        since the base model name is what's typically used in API calls.

        Args:
            model_name: Input model name or alias

        Returns:
            Resolved model name, with version tags stripped if applicable
        """
        # First, try to resolve through the registry as-is
        config = self._registry.resolve(model_name)

        if config:
            if config.model_name != model_name:
                logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
            return config.model_name
        else:
            # If not found in registry, handle version tags for local models.
            # Strip version tags (anything after ':') for Ollama-style models.
            if ":" in model_name:
                base_model = model_name.split(":")[0]
                logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'")

                # Try to resolve the base model through the registry
                base_config = self._registry.resolve(base_model)
                if base_config:
                    logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
                    return base_config.model_name
                else:
                    return base_model
            else:
                # If not found in registry and no version tag, return as-is
                logging.debug(f"Model '{model_name}' not found in registry, using as-is")
                return model_name

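Continuing the sketch from above, the resolution order in practice (assuming these names are absent from the registry):

    provider._resolve_model_name("llama3.2:latest")  # -> "llama3.2" (version tag stripped)
    provider._resolve_model_name("my-finetune")      # -> "my-finetune" (unknown, returned as-is)
    # A registered alias (e.g., a hypothetical "local-llama" entry in custom_models.json)
    # would instead return the registry's canonical model_name.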
    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a custom model.

        Args:
            model_name: Name of the model (or alias)

        Returns:
            ModelCapabilities from the registry, or generic defaults
        """
        # Try to get from the registry first
        capabilities = self._registry.get_capabilities(model_name)

        if capabilities:
            # Update provider type to CUSTOM
            capabilities.provider = ProviderType.CUSTOM
            return capabilities
        else:
            # Resolve any potential aliases and create generic capabilities
            resolved_name = self._resolve_model_name(model_name)

            logging.debug(
                f"Using generic capabilities for '{resolved_name}' via Custom API. "
                "Consider adding it to custom_models.json for specific capabilities."
            )

            # Create generic capabilities with conservative defaults
            capabilities = ModelCapabilities(
                provider=ProviderType.CUSTOM,
                model_name=resolved_name,
                friendly_name=f"{self.FRIENDLY_NAME} ({resolved_name})",
                context_window=32_768,  # Conservative default
                supports_extended_thinking=False,  # Most custom models don't support this
                supports_system_prompts=True,
                supports_streaming=True,
                supports_function_calling=False,  # Conservative default
                temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
            )

            # Mark as generic for validation purposes
            capabilities._is_generic = True

            return capabilities

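For a model with no custom_models.json entry (hypothetical name), the generic fallback looks like this:

    caps = provider.get_capabilities("my-finetune")
    caps.context_window             # 32768, the conservative default
    caps.supports_function_calling  # False
    caps.temperature_constraint     # RangeTemperatureConstraint(0.0, 2.0, 0.7)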
    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.CUSTOM

    def validate_model_name(self, model_name: str) -> bool:
        """Validate whether the model name is allowed.

        For custom endpoints, only accept models that are explicitly intended for
        local/custom usage. This provider should NOT handle OpenRouter or cloud models.

        Args:
            model_name: Model name to validate

        Returns:
            True if the model is intended for the custom/local endpoint
        """
        logging.debug(f"Custom provider validating model: '{model_name}'")

        # Try to resolve through the registry first
        config = self._registry.resolve(model_name)
        if config:
            model_id = config.model_name
            # Only accept models that are clearly local/custom based on the resolved
            # name. Local models should not have a vendor/ prefix, unless an alias
            # carries an explicit local indicator.
            is_local_model = (
                "/" not in model_id  # Simple names like "llama3.2"
                or "local" in model_id.lower()  # Explicit local indicator
                # Or any alias contains a local indicator
                or any(
                    "local" in alias.lower() or "ollama" in alias.lower()
                    for alias in getattr(config, "aliases", [])
                )
            )

            if is_local_model:
                logging.debug(f"Model '{model_name}' -> '{model_id}' validated via registry (local model)")
                return True
            else:
                # This is a cloud/OpenRouter model - reject it for the custom provider
                logging.debug(f"Model '{model_name}' -> '{model_id}' rejected (cloud model for OpenRouter)")
                return False

        # Strip the :latest suffix (it's just a version tag) and try validation again
        clean_model_name = model_name
        if model_name.endswith(":latest"):
            clean_model_name = model_name[:-7]  # Remove ":latest"
            logging.debug(f"Stripped :latest from '{model_name}' -> '{clean_model_name}'")
            # Try to resolve the clean name
            config = self._registry.resolve(clean_model_name)
            if config:
                return self.validate_model_name(clean_model_name)  # Recursively validate the clean name

        # Accept models with explicit local indicators in the name
        if any(indicator in clean_model_name.lower() for indicator in ["local", "ollama", "vllm", "lmstudio"]):
            logging.debug(f"Model '{clean_model_name}' validated via local indicators")
            return True

        # Accept simple model names without a vendor prefix ONLY if they're not in
        # the registry. This allows unknown local models such as custom fine-tunes.
        if "/" not in clean_model_name and ":" not in clean_model_name and not config:
            logging.debug(f"Model '{clean_model_name}' validated via simple name pattern (unknown local model)")
            return True

        logging.debug(f"Model '{model_name}' NOT validated by custom provider")
        return False

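Illustrative outcomes of the rules above (hypothetical model names, none present in the registry):

    provider.validate_model_name("llama3.2")            # True: simple name, no vendor prefix
    provider.validate_model_name("llama3.2:latest")     # True: ':latest' is stripped first
    provider.validate_model_name("vendor/cloud-model")  # False: vendor-prefixed cloud model
    provider.validate_model_name("ollama-mistral")      # True: explicit local indicator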
    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using the custom API.

        Args:
            prompt: User prompt to send to the model
            model_name: Name of the model to use
            system_prompt: Optional system prompt for model behavior
            temperature: Sampling temperature
            max_output_tokens: Maximum tokens to generate
            **kwargs: Additional provider-specific parameters

        Returns:
            ModelResponse with generated content and metadata
        """
        # Resolve the model alias to the actual model name
        resolved_model = self._resolve_model_name(model_name)

        # Call the parent method with the resolved model name
        return super().generate_content(
            prompt=prompt,
            model_name=resolved_model,
            system_prompt=system_prompt,
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            **kwargs,
        )

    def supports_thinking_mode(self, model_name: str) -> bool:
        """Check if the model supports extended thinking mode.

        Most custom/local models don't support extended thinking.

        Args:
            model_name: Model to check

        Returns:
            False (custom models generally don't support thinking mode)
        """
        return False