"""OpenAI model provider implementation."""

import logging
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from tools.models import ToolModelCategory

from .openai_compatible import OpenAICompatibleProvider
from .openai_registry import OpenAIModelRegistry
from .shared import ModelCapabilities, ProviderType

logger = logging.getLogger(__name__)


class OpenAIModelProvider(OpenAICompatibleProvider):
    """Implementation that talks to api.openai.com using rich model metadata.

    In addition to the built-in catalogue, the provider can surface models
    defined in ``conf/custom_models.json`` (for organisations running their
    own OpenAI-compatible gateways) while still respecting restriction
    policies.
    """

    # Populated from OpenAIModelRegistry by _ensure_registry(); class-level so
    # every instance (and the import-time preload below) shares one copy.
    MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}
    # Loaded registry, or None if not yet loaded / the last load attempt failed.
    _registry: Optional[OpenAIModelRegistry] = None

    def __init__(self, api_key: str, **kwargs):
        """Initialize OpenAI provider with API key.

        Args:
            api_key: OpenAI API key forwarded to the compatible base provider.
            **kwargs: Extra options for the base provider. ``base_url``
                defaults to ``https://api.openai.com/v1`` when not supplied.
        """
        self._ensure_registry()
        # Set default OpenAI base URL, allow override for regions/custom endpoints
        kwargs.setdefault("base_url", "https://api.openai.com/v1")
        super().__init__(api_key, **kwargs)
        # Drop anything the base class cached before the registry-backed
        # capabilities were available.
        self._invalidate_capability_cache()

    # ------------------------------------------------------------------
    # Registry access
    # ------------------------------------------------------------------

    @classmethod
    def _ensure_registry(cls, *, force_reload: bool = False) -> None:
        """Load capability registry into MODEL_CAPABILITIES.

        This is a no-op when a registry is already loaded, unless
        ``force_reload`` is set. On failure it logs a warning and clears both
        ``_registry`` and ``MODEL_CAPABILITIES``. Because ``_registry`` is
        left as ``None``, the next call tries again.

        Args:
            force_reload: Re-read the registry even if one is already loaded.
        """
        if cls._registry is not None and not force_reload:
            return

        try:
            registry = OpenAIModelRegistry()
        except Exception as exc:  # pragma: no cover - defensive logging
            logger.warning("Unable to load OpenAI model registry: %s", exc)
            cls._registry = None
            cls.MODEL_CAPABILITIES = {}
            return

        cls._registry = registry
        # Copy so later edits to MODEL_CAPABILITIES do not alter the registry's map.
        cls.MODEL_CAPABILITIES = dict(registry.model_map)

    @classmethod
    def reload_registry(cls) -> None:
        """Force registry reload (primarily for tests)."""
        cls._ensure_registry(force_reload=True)

    def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
        """Return all known model capabilities, loading the registry first if needed."""
        self._ensure_registry()
        return super().get_all_model_capabilities()

    def get_model_registry(self) -> Optional[dict[str, ModelCapabilities]]:
        """Return a copy of the registry's model map, or None if it cannot be loaded.

        Like the other registry accessors, this first gives the registry a
        chance to load. A transient import-time failure therefore does not
        make this method return ``None`` forever.
        """
        self._ensure_registry()
        if self._registry is None:
            return None
        return dict(self._registry.model_map)

    # ------------------------------------------------------------------
    # Capability surface
    # ------------------------------------------------------------------

    def _lookup_capabilities(
        self,
        canonical_name: str,
        requested_name: Optional[str] = None,
    ) -> Optional[ModelCapabilities]:
        """Look up OpenAI capabilities from built-ins or the custom registry.

        Built-in (registry-backed) capabilities are consulted first. If none
        match, the custom OpenRouter model registry is queried. Its entry is
        accepted only when that entry declares ``ProviderType.OPENAI``.

        Args:
            canonical_name: Resolved model name to look up.
            requested_name: Name originally requested by the caller, if different.

        Returns:
            The matching capabilities, or None when no source knows the model.
        """
        self._ensure_registry()
        builtin = super()._lookup_capabilities(canonical_name, requested_name)
        if builtin is not None:
            return builtin

        try:
            # Imported lazily; a fresh registry instance is built on every miss.
            from .openrouter_registry import OpenRouterModelRegistry

            registry = OpenRouterModelRegistry()
            config = registry.get_model_config(canonical_name)
            if config and config.provider == ProviderType.OPENAI:
                return config
        except Exception as exc:  # pragma: no cover - registry failures are non-critical
            logger.debug("Could not resolve custom OpenAI model '%s': %s", canonical_name, exc)

        return None

    def _finalise_capabilities(
        self,
        capabilities: ModelCapabilities,
        canonical_name: str,
        requested_name: str,
    ) -> ModelCapabilities:
        """Ensure registry-sourced models report the correct provider type.

        Note: ``capabilities`` is modified in place and the same object is
        returned. If the object is shared (e.g. held in a registry map), the
        change is visible through that reference too.
        """
        if capabilities.provider != ProviderType.OPENAI:
            capabilities.provider = ProviderType.OPENAI
        return capabilities

    def _raise_unsupported_model(self, model_name: str) -> None:
        """Raise ``ValueError`` for a model this provider does not support."""
        raise ValueError(f"Unsupported OpenAI model: {model_name}")

    # ------------------------------------------------------------------
    # Provider identity
    # ------------------------------------------------------------------

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.OPENAI

    # ------------------------------------------------------------------
    # Provider preferences
    # ------------------------------------------------------------------

    def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
        """Get OpenAI's preferred model for a given category from allowed models.

        Args:
            category: The tool category requiring a model
            allowed_models: Pre-filtered list of models allowed by restrictions

        Returns:
            The first model from the category's preference list that appears in
            ``allowed_models``. If none do, ``allowed_models[0]`` is returned.
            Returns None only when ``allowed_models`` is empty.
        """
        from tools.models import ToolModelCategory

        if not allowed_models:
            return None

        if category == ToolModelCategory.EXTENDED_REASONING:
            # Prefer models with extended thinking support
            # GPT-5-Codex first for coding tasks
            preferences = ["gpt-5-codex", "gpt-5-pro", "o3", "o3-pro", "gpt-5"]
        elif category == ToolModelCategory.FAST_RESPONSE:
            # Prefer fast, cost-efficient models
            # GPT-5 models for speed, GPT-5-Codex after (premium pricing but cached)
            preferences = ["gpt-5", "gpt-5-mini", "gpt-5-codex", "o4-mini", "o3-mini"]
        else:  # BALANCED or default
            # Prefer balanced performance/cost models
            # Include GPT-5-Codex for coding workflows
            preferences = ["gpt-5", "gpt-5-codex", "gpt-5-pro", "gpt-5-mini", "o4-mini", "o3-mini"]

        preferred = next((model for model in preferences if model in allowed_models), None)
        # Fall back to the caller's first allowed model when no preference is available.
        return preferred if preferred is not None else allowed_models[0]


# Load registry data at import time so dependent providers (Azure) can reuse it
OpenAIModelProvider._ensure_registry()