"""Custom API provider implementation.""" import logging import os from typing import Optional from .base import ( ModelCapabilities, ModelResponse, ProviderType, RangeTemperatureConstraint, ) from .openai_compatible import OpenAICompatibleProvider from .openrouter_registry import OpenRouterModelRegistry class CustomProvider(OpenAICompatibleProvider): """Custom API provider for local models. Supports local inference servers like Ollama, vLLM, LM Studio, and any OpenAI-compatible API endpoint. """ FRIENDLY_NAME = "Custom API" # Model registry for managing configurations and aliases (shared with OpenRouter) _registry: Optional[OpenRouterModelRegistry] = None def __init__(self, api_key: str = "", base_url: str = "", **kwargs): """Initialize Custom provider for local/self-hosted models. This provider supports any OpenAI-compatible API endpoint including: - Ollama (typically no API key required) - vLLM (may require API key) - LM Studio (may require API key) - Text Generation WebUI (may require API key) - Enterprise/self-hosted APIs (typically require API key) Args: api_key: API key for the custom endpoint. Can be empty string for providers that don't require authentication (like Ollama). Falls back to CUSTOM_API_KEY environment variable if not provided. base_url: Base URL for the custom API endpoint (e.g., 'http://host.docker.internal:11434/v1'). Falls back to CUSTOM_API_URL environment variable if not provided. **kwargs: Additional configuration passed to parent OpenAI-compatible provider Raises: ValueError: If no base_url is provided via parameter or environment variable """ # Fall back to environment variables only if not provided if not base_url: base_url = os.getenv("CUSTOM_API_URL", "") if not api_key: api_key = os.getenv("CUSTOM_API_KEY", "") if not base_url: raise ValueError( "Custom API URL must be provided via base_url parameter or CUSTOM_API_URL environment variable" ) # For Ollama and other providers that don't require authentication, # set a dummy API key to avoid OpenAI client header issues if not api_key: api_key = "dummy-key-for-unauthenticated-endpoint" logging.debug("Using dummy API key for unauthenticated custom endpoint") logging.info(f"Initializing Custom provider with endpoint: {base_url}") super().__init__(api_key, base_url=base_url, **kwargs) # Initialize model registry (shared with OpenRouter for consistent aliases) if CustomProvider._registry is None: CustomProvider._registry = OpenRouterModelRegistry() # Log loaded models and aliases models = self._registry.list_models() aliases = self._registry.list_aliases() logging.info(f"Custom provider loaded {len(models)} models with {len(aliases)} aliases") def _resolve_model_name(self, model_name: str) -> str: """Resolve model aliases to actual model names. For Ollama-style models, strips version tags (e.g., 'llama3.2:latest' -> 'llama3.2') since the base model name is what's typically used in API calls. 
    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model aliases to actual model names.

        For Ollama-style models, strips version tags (e.g., 'llama3.2:latest' -> 'llama3.2')
        since the base model name is what's typically used in API calls.

        Args:
            model_name: Input model name or alias

        Returns:
            Resolved model name with version tags stripped if applicable
        """
        # First, try to resolve through registry as-is
        config = self._registry.resolve(model_name)
        if config:
            if config.model_name != model_name:
                logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
            return config.model_name
        else:
            # If not found in registry, handle version tags for local models
            # Strip version tags (anything after ':') for Ollama-style models
            if ":" in model_name:
                base_model = model_name.split(":")[0]
                logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'")

                # Try to resolve the base model through registry
                base_config = self._registry.resolve(base_model)
                if base_config:
                    logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
                    return base_config.model_name
                else:
                    return base_model
            else:
                # If not found in registry and no version tag, return as-is
                logging.debug(f"Model '{model_name}' not found in registry, using as-is")
                return model_name

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a custom model.

        Args:
            model_name: Name of the model (or alias)

        Returns:
            ModelCapabilities from registry or generic defaults
        """
        # Try to get from registry first
        capabilities = self._registry.get_capabilities(model_name)
        if capabilities:
            # Update provider type to CUSTOM
            capabilities.provider = ProviderType.CUSTOM
            return capabilities
        else:
            # Resolve any potential aliases and create generic capabilities
            resolved_name = self._resolve_model_name(model_name)
            logging.debug(
                f"Using generic capabilities for '{resolved_name}' via Custom API. "
                "Consider adding to custom_models.json for specific capabilities."
            )

            # Create generic capabilities with conservative defaults
            capabilities = ModelCapabilities(
                provider=ProviderType.CUSTOM,
                model_name=resolved_name,
                friendly_name=f"{self.FRIENDLY_NAME} ({resolved_name})",
                context_window=32_768,  # Conservative default
                supports_extended_thinking=False,  # Most custom models don't support this
                supports_system_prompts=True,
                supports_streaming=True,
                supports_function_calling=False,  # Conservative default
                temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
            )

            # Mark as generic for validation purposes
            capabilities._is_generic = True

            return capabilities

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.CUSTOM
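    # Capability-lookup sketch: a model absent from custom_models.json falls back to
    # the generic defaults above ("my-finetune" is a hypothetical, unregistered name):
    #
    #   caps = provider.get_capabilities("my-finetune")
    #   # -> provider=ProviderType.CUSTOM, context_window=32_768, conservative feature flags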
    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is allowed.

        For custom endpoints, only accept models that are explicitly intended for
        local/custom usage. This provider should NOT handle OpenRouter or cloud models.

        Args:
            model_name: Model name to validate

        Returns:
            True if model is intended for custom/local endpoint
        """
        logging.debug(f"Custom provider validating model: '{model_name}'")

        # Try to resolve through registry first
        config = self._registry.resolve(model_name)
        if config:
            model_id = config.model_name

            # Only accept models that are clearly local/custom based on the resolved name
            # Local models should not have vendor/ prefix (except for special cases)
            is_local_model = (
                "/" not in model_id  # Simple names like "llama3.2"
                or "local" in model_id.lower()  # Explicit local indicator
                or
                # Check if any of the aliases contain local indicators
                any("local" in alias.lower() or "ollama" in alias.lower() for alias in config.aliases)
                if hasattr(config, "aliases")
                else False
            )

            if is_local_model:
                logging.debug(f"Model '{model_name}' -> '{model_id}' validated via registry (local model)")
                return True
            else:
                # This is a cloud/OpenRouter model - reject it for custom provider
                logging.debug(f"Model '{model_name}' -> '{model_id}' rejected (cloud model for OpenRouter)")
                return False

        # Strip :latest suffix and try validation again (it's just a version tag)
        clean_model_name = model_name
        if model_name.endswith(":latest"):
            clean_model_name = model_name[:-7]  # Remove ":latest"
            logging.debug(f"Stripped :latest from '{model_name}' -> '{clean_model_name}'")

            # Try to resolve the clean name
            config = self._registry.resolve(clean_model_name)
            if config:
                return self.validate_model_name(clean_model_name)  # Recursively validate clean name

        # Accept models with explicit local indicators in the name
        if any(indicator in clean_model_name.lower() for indicator in ["local", "ollama", "vllm", "lmstudio"]):
            logging.debug(f"Model '{clean_model_name}' validated via local indicators")
            return True

        # Accept simple model names without vendor prefix ONLY if they're not in registry
        # This allows for unknown local models like custom fine-tunes
        if "/" not in clean_model_name and ":" not in clean_model_name and not config:
            logging.debug(f"Model '{clean_model_name}' validated via simple name pattern (unknown local model)")
            return True

        logging.debug(f"Model '{model_name}' NOT validated by custom provider")
        return False

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using the custom API.

        Args:
            prompt: User prompt to send to the model
            model_name: Name of the model to use
            system_prompt: Optional system prompt for model behavior
            temperature: Sampling temperature
            max_output_tokens: Maximum tokens to generate
            **kwargs: Additional provider-specific parameters

        Returns:
            ModelResponse with generated content and metadata
        """
        # Resolve model alias to actual model name
        resolved_model = self._resolve_model_name(model_name)

        # Call parent method with resolved model name
        return super().generate_content(
            prompt=prompt,
            model_name=resolved_model,
            system_prompt=system_prompt,
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            **kwargs,
        )

    def supports_thinking_mode(self, model_name: str) -> bool:
        """Check if the model supports extended thinking mode.

        Most custom/local models don't support extended thinking.

        Args:
            model_name: Model to check

        Returns:
            False (custom models generally don't support thinking mode)
        """
        return False
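# Validation sketch (model names are illustrative assumptions, presumed absent from
# custom_models.json):
#
#   provider.validate_model_name("llama3.2:latest")     # True  - ':latest' stripped, simple local-style name
#   provider.validate_model_name("vendor/cloud-model")  # False - vendor-prefixed names are left to other providers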