diff --git a/.env.example b/.env.example index 81cb082..7986475 100644 --- a/.env.example +++ b/.env.example @@ -12,6 +12,7 @@ # Option 1: Use native APIs (recommended for direct access) # Get your Gemini API key from: https://makersuite.google.com/app/apikey GEMINI_API_KEY=your_gemini_api_key_here +# GEMINI_BASE_URL= # Optional: Custom Gemini endpoint (defaults to Google's API) # Get your OpenAI API key from: https://platform.openai.com/api-keys OPENAI_API_KEY=your_openai_api_key_here diff --git a/providers/gemini.py b/providers/gemini.py index 5c587ad..0cab004 100644 --- a/providers/gemini.py +++ b/providers/gemini.py @@ -117,16 +117,25 @@ class GeminiModelProvider(ModelProvider): } def __init__(self, api_key: str, **kwargs): - """Initialize Gemini provider with API key.""" + """Initialize Gemini provider with API key and optional base URL.""" super().__init__(api_key, **kwargs) self._client = None self._token_counters = {} # Cache for token counting + self._base_url = kwargs.get("base_url", None) # Optional custom endpoint @property def client(self): """Lazy initialization of Gemini client.""" if self._client is None: - self._client = genai.Client(api_key=self.api_key) + # Check if custom base URL is provided + if self._base_url: + # Use HttpOptions to set custom endpoint + http_options = types.HttpOptions(baseUrl=self._base_url) + logger.debug(f"Initializing Gemini client with custom endpoint: {self._base_url}") + self._client = genai.Client(api_key=self.api_key, http_options=http_options) + else: + # Use default Google endpoint + self._client = genai.Client(api_key=self.api_key) return self._client def get_capabilities(self, model_name: str) -> ModelCapabilities: diff --git a/providers/registry.py b/providers/registry.py index 1bb232d..7a1b94e 100644 --- a/providers/registry.py +++ b/providers/registry.py @@ -93,6 +93,16 @@ class ModelProviderRegistry: api_key = api_key or "" # Initialize custom provider with both API key and base URL provider = provider_class(api_key=api_key, base_url=custom_url) + elif provider_type == ProviderType.GOOGLE: + # For Gemini, check if custom base URL is configured + if not api_key: + return None + gemini_base_url = os.getenv("GEMINI_BASE_URL") + provider_kwargs = {"api_key": api_key} + if gemini_base_url: + provider_kwargs["base_url"] = gemini_base_url + logging.info(f"Initialized Gemini provider with custom endpoint: {gemini_base_url}") + provider = provider_class(**provider_kwargs) else: if not api_key: return None