WIP major refactor and features

Fahad
2025-06-12 07:14:59 +04:00
parent e06a6fd1fc
commit 2a067a7f4e
46 changed files with 2960 additions and 1011 deletions


@@ -20,13 +20,12 @@ import re
from abc import ABC, abstractmethod
from typing import Any, Literal, Optional
from google import genai
from google.genai import types
from mcp.types import TextContent
from pydantic import BaseModel, Field
from config import DEFAULT_MODEL, MAX_CONTEXT_TOKENS, MCP_PROMPT_SIZE_LIMIT
from utils import check_token_limit
from providers import ModelProviderRegistry, ModelProvider, ModelResponse
from utils.conversation_memory import (
MAX_CONVERSATION_TURNS,
add_turn,
@@ -52,7 +51,7 @@ class ToolRequest(BaseModel):
model: Optional[str] = Field(
None,
description=f"Model to use: 'pro' (Gemini 2.5 Pro with extended thinking) or 'flash' (Gemini 2.0 Flash - faster). Defaults to '{DEFAULT_MODEL}' if not specified.",
description="Model to use. See tool's input schema for available models and their capabilities.",
)
temperature: Optional[float] = Field(None, description="Temperature for response (tool-specific defaults)")
# Thinking mode controls how much computational budget the model uses for reasoning
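For readers skimming the hunks below, here is a consolidated sketch of the request model these fields belong to. Only `model` and `temperature` appear in this hunk; the remaining fields are inferred from later hunks and may not match the real class exactly.

```python
# Consolidated sketch of the request model implied by this hunk.
# thinking_mode, continuation_id and files are inferred from later hunks (assumption).
from typing import Optional

from pydantic import BaseModel, Field


class ToolRequestSketch(BaseModel):
    model: Optional[str] = Field(
        None,
        description="Model to use. See tool's input schema for available models and their capabilities.",
    )
    temperature: Optional[float] = Field(None, description="Temperature for response (tool-specific defaults)")
    thinking_mode: Optional[str] = Field(None, description="Reasoning depth: minimal, low, medium, high, or max")
    continuation_id: Optional[str] = Field(None, description="Thread ID when continuing an existing conversation")
    files: Optional[list[str]] = Field(None, description="File paths to include as context")
```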
@@ -144,6 +143,38 @@ class BaseTool(ABC):
"""
pass
def get_model_field_schema(self) -> dict[str, Any]:
"""
Generate the model field schema based on auto mode configuration.
When auto mode is enabled, the model parameter becomes required
and includes detailed descriptions of each model's capabilities.
Returns:
Dict containing the model field JSON schema
"""
from config import DEFAULT_MODEL, IS_AUTO_MODE, MODEL_CAPABILITIES_DESC
if IS_AUTO_MODE:
# In auto mode, model is required and we provide detailed descriptions
model_desc_parts = ["Choose the best model for this task based on these capabilities:"]
for model, desc in MODEL_CAPABILITIES_DESC.items():
model_desc_parts.append(f"- '{model}': {desc}")
return {
"type": "string",
"description": "\n".join(model_desc_parts),
"enum": list(MODEL_CAPABILITIES_DESC.keys()),
}
else:
# Normal mode - model is optional with default
available_models = list(MODEL_CAPABILITIES_DESC.keys())
models_str = ', '.join(f"'{m}'" for m in available_models)
return {
"type": "string",
"description": f"Model to use. Available: {models_str}. Defaults to '{DEFAULT_MODEL}' if not specified.",
}
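A quick illustration of the two schema shapes `get_model_field_schema()` produces. The capability strings below are placeholders, not the real `MODEL_CAPABILITIES_DESC` contents.

```python
# Illustration of the two schema shapes produced above.
# The capability strings are placeholders, not the real MODEL_CAPABILITIES_DESC.
MODEL_CAPABILITIES_DESC = {
    "pro": "deep reasoning, large context",
    "flash": "fast responses, lower cost",
}


def build_model_field(is_auto_mode: bool, default_model: str) -> dict:
    if is_auto_mode:
        lines = ["Choose the best model for this task based on these capabilities:"]
        lines += [f"- '{name}': {desc}" for name, desc in MODEL_CAPABILITIES_DESC.items()]
        # Auto mode: the field is an enum and the caller must pick a model explicitly
        return {"type": "string", "description": "\n".join(lines), "enum": list(MODEL_CAPABILITIES_DESC)}
    models = ", ".join(f"'{name}'" for name in MODEL_CAPABILITIES_DESC)
    # Normal mode: free-form string with a documented default
    return {
        "type": "string",
        "description": f"Model to use. Available: {models}. Defaults to '{default_model}' if not specified.",
    }


print(build_model_field(True, "pro")["enum"])  # ['pro', 'flash']
```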
def get_default_temperature(self) -> float:
"""
Return the default temperature setting for this tool.
@@ -293,6 +324,11 @@ class BaseTool(ABC):
"""
if not request_files:
return ""
# If conversation history is already embedded, skip file processing
if hasattr(self, '_has_embedded_history') and self._has_embedded_history:
logger.debug(f"[FILES] {self.name}: Skipping file processing - conversation history already embedded")
return ""
# Extract remaining budget from arguments if available
if remaining_budget is None:
@@ -300,15 +336,59 @@ class BaseTool(ABC):
args_to_use = arguments or getattr(self, "_current_arguments", {})
remaining_budget = args_to_use.get("_remaining_tokens")
# Use remaining budget if provided, otherwise fall back to max_tokens or default
# Use remaining budget if provided, otherwise fall back to max_tokens or model-specific default
if remaining_budget is not None:
effective_max_tokens = remaining_budget - reserve_tokens
elif max_tokens is not None:
effective_max_tokens = max_tokens - reserve_tokens
else:
from config import MAX_CONTENT_TOKENS
effective_max_tokens = MAX_CONTENT_TOKENS - reserve_tokens
# Get model-specific limits
# First check if model_context was passed from server.py
model_context = None
if arguments:
model_context = arguments.get("_model_context") or getattr(self, "_current_arguments", {}).get("_model_context")
if model_context:
# Use the passed model context
try:
token_allocation = model_context.calculate_token_allocation()
effective_max_tokens = token_allocation.file_tokens - reserve_tokens
logger.debug(f"[FILES] {self.name}: Using passed model context for {model_context.model_name}: "
f"{token_allocation.file_tokens:,} file tokens from {token_allocation.total_tokens:,} total")
except Exception as e:
logger.warning(f"[FILES] {self.name}: Error using passed model context: {e}")
# Fall through to manual calculation
model_context = None
if not model_context:
# Manual calculation as fallback
model_name = getattr(self, "_current_model_name", None) or DEFAULT_MODEL
try:
provider = self.get_model_provider(model_name)
capabilities = provider.get_capabilities(model_name)
# Calculate content allocation based on model capacity
if capabilities.max_tokens < 300_000:
# Smaller context models: 60% content, 40% response
model_content_tokens = int(capabilities.max_tokens * 0.6)
else:
# Larger context models: 80% content, 20% response
model_content_tokens = int(capabilities.max_tokens * 0.8)
effective_max_tokens = model_content_tokens - reserve_tokens
logger.debug(f"[FILES] {self.name}: Using model-specific limit for {model_name}: "
f"{model_content_tokens:,} content tokens from {capabilities.max_tokens:,} total")
except (ValueError, AttributeError) as e:
# Handle specific errors: provider not found, model not supported, missing attributes
logger.warning(f"[FILES] {self.name}: Could not get model capabilities for {model_name}: {type(e).__name__}: {e}")
# Fall back to conservative default for safety
from config import MAX_CONTENT_TOKENS
effective_max_tokens = min(MAX_CONTENT_TOKENS, 100_000) - reserve_tokens
except Exception as e:
# Catch any other unexpected errors
logger.error(f"[FILES] {self.name}: Unexpected error getting model capabilities: {type(e).__name__}: {e}")
from config import MAX_CONTENT_TOKENS
effective_max_tokens = min(MAX_CONTENT_TOKENS, 100_000) - reserve_tokens
# Ensure we have a reasonable minimum budget
effective_max_tokens = max(1000, effective_max_tokens)
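The fallback budget math above, condensed into a standalone sketch. The 300,000-token threshold, 60%/80% ratios, and 1,000-token floor mirror the diff; the example capacities are illustrative only.

```python
# Fallback content-budget calculation from above, condensed.
def effective_file_budget(model_max_tokens: int, reserve_tokens: int = 1000) -> int:
    # Smaller-context models keep 40% of capacity for the response, larger ones 20%
    ratio = 0.6 if model_max_tokens < 300_000 else 0.8
    content_tokens = int(model_max_tokens * ratio)
    return max(1000, content_tokens - reserve_tokens)


assert effective_file_budget(200_000) == 119_000    # 60% of 200k, minus reserve
assert effective_file_budget(1_000_000) == 799_000  # 80% of 1M, minus reserve
```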
@@ -601,34 +681,59 @@ If any of these would strengthen your analysis, specify what Claude should searc
)
return [TextContent(type="text", text=error_output.model_dump_json())]
# Prepare the full prompt by combining system prompt with user request
# This is delegated to the tool implementation for customization
prompt = await self.prepare_prompt(request)
# Add follow-up instructions for new conversations (not threaded)
# Check if we have continuation_id - if so, conversation history is already embedded
continuation_id = getattr(request, "continuation_id", None)
if not continuation_id:
# Import here to avoid circular imports
if continuation_id:
# When continuation_id is present, server.py has already injected the
# conversation history into the appropriate field. We need to check if
# the prompt already contains conversation history marker.
logger.debug(f"Continuing {self.name} conversation with thread {continuation_id}")
# Store the original arguments to detect enhanced prompts
self._has_embedded_history = False
# Check if conversation history is already embedded in the prompt field
field_value = getattr(request, "prompt", "")
field_name = "prompt"
if "=== CONVERSATION HISTORY ===" in field_value:
# Conversation history is already embedded, use it directly
prompt = field_value
self._has_embedded_history = True
logger.debug(f"{self.name}: Using pre-embedded conversation history from {field_name}")
else:
# No embedded history, prepare prompt normally
prompt = await self.prepare_prompt(request)
logger.debug(f"{self.name}: No embedded history found, prepared prompt normally")
else:
# New conversation, prepare prompt normally
prompt = await self.prepare_prompt(request)
# Add follow-up instructions for new conversations
from server import get_follow_up_instructions
follow_up_instructions = get_follow_up_instructions(0) # New conversation, turn 0
prompt = f"{prompt}\n\n{follow_up_instructions}"
logger.debug(f"Added follow-up instructions for new {self.name} conversation")
# Also log to file for debugging MCP issues
try:
with open("/tmp/gemini_debug.log", "a") as f:
f.write(f"[{self.name}] Added follow-up instructions for new conversation\n")
except Exception:
pass
else:
logger.debug(f"Continuing {self.name} conversation with thread {continuation_id}")
# History reconstruction is handled by server.py:reconstruct_thread_context
# No need to rebuild it here - prompt already contains conversation history
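The branch above hinges on a literal marker that server.py injects when it reconstructs a thread. A sketch of that decision, factored into a helper for clarity; the helper itself is hypothetical, while `prepare_prompt` and `get_follow_up_instructions` are the calls used in the diff.

```python
# The history-detection branch above, factored into a helper for clarity.
# HISTORY_MARKER matches the literal in the diff; the helper itself is hypothetical.
HISTORY_MARKER = "=== CONVERSATION HISTORY ==="


async def resolve_prompt(tool, request) -> str:
    continuation_id = getattr(request, "continuation_id", None)
    tool._has_embedded_history = False
    if continuation_id:
        field_value = getattr(request, "prompt", "")
        if HISTORY_MARKER in field_value:
            # server.py already rebuilt the thread; reuse the enriched prompt verbatim
            tool._has_embedded_history = True
            return field_value
        # Continuation without embedded history: build the prompt normally
        return await tool.prepare_prompt(request)
    # New conversation: normal prompt plus follow-up instructions for turn 0
    from server import get_follow_up_instructions

    prompt = await tool.prepare_prompt(request)
    return f"{prompt}\n\n{get_follow_up_instructions(0)}"
```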
# Extract model configuration from request or use defaults
model_name = getattr(request, "model", None) or DEFAULT_MODEL
model_name = getattr(request, "model", None)
if not model_name:
model_name = DEFAULT_MODEL
# In auto mode, model parameter is required
from config import IS_AUTO_MODE
if IS_AUTO_MODE and model_name.lower() == "auto":
error_output = ToolOutput(
status="error",
content="Model parameter is required. Please specify which model to use for this task.",
content_type="text",
)
return [TextContent(type="text", text=error_output.model_dump_json())]
# Store model name for use by helper methods like _prepare_file_content_for_prompt
self._current_model_name = model_name
temperature = getattr(request, "temperature", None)
if temperature is None:
temperature = self.get_default_temperature()
@@ -636,28 +741,45 @@ If any of these would strengthen your analysis, specify what Claude should searc
if thinking_mode is None:
thinking_mode = self.get_default_thinking_mode()
# Create model instance with appropriate configuration
# This handles both regular models and thinking-enabled models
model = self.create_model(model_name, temperature, thinking_mode)
# Get the appropriate model provider
provider = self.get_model_provider(model_name)
# Get system prompt for this tool
system_prompt = self.get_system_prompt()
# Generate AI response using the configured model
logger.info(f"Sending request to Gemini API for {self.name}")
# Generate AI response using the provider
logger.info(f"Sending request to {provider.get_provider_type().value} API for {self.name}")
logger.debug(f"Prompt length: {len(prompt)} characters")
response = model.generate_content(prompt)
logger.info(f"Received response from Gemini API for {self.name}")
# Generate content with provider abstraction
model_response = provider.generate_content(
prompt=prompt,
model_name=model_name,
system_prompt=system_prompt,
temperature=temperature,
thinking_mode=thinking_mode if provider.supports_thinking_mode(model_name) else None
)
logger.info(f"Received response from {provider.get_provider_type().value} API for {self.name}")
# Process the model's response
if response.candidates and response.candidates[0].content.parts:
raw_text = response.candidates[0].content.parts[0].text
if model_response.content:
raw_text = model_response.content
# Parse response to check for clarification requests or format output
tool_output = self._parse_response(raw_text, request)
# Pass model info for conversation tracking
model_info = {
"provider": provider,
"model_name": model_name,
"model_response": model_response
}
tool_output = self._parse_response(raw_text, request, model_info)
logger.info(f"Successfully completed {self.name} tool execution")
else:
# Handle cases where the model couldn't generate a response
# This might happen due to safety filters or other constraints
finish_reason = response.candidates[0].finish_reason if response.candidates else "Unknown"
finish_reason = model_response.metadata.get("finish_reason", "Unknown")
logger.warning(f"Response blocked or incomplete for {self.name}. Finish reason: {finish_reason}")
tool_output = ToolOutput(
status="error",
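The generation call above only depends on a narrow provider surface. A sketch of that surface, assuming the method names used in the diff; the exact signatures and `ModelResponse` fields are assumptions.

```python
# The narrow provider surface the call above depends on. Method names come from
# the diff; exact signatures and ModelResponse fields are assumptions.
from typing import Optional, Protocol


class ModelResponseLike(Protocol):
    content: str
    usage: dict
    metadata: dict


class ModelProviderLike(Protocol):
    def get_provider_type(self): ...

    def supports_thinking_mode(self, model_name: str) -> bool: ...

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        thinking_mode: Optional[str] = None,
    ) -> ModelResponseLike: ...
```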
@@ -678,13 +800,24 @@ If any of these would strengthen your analysis, specify what Claude should searc
if "500 INTERNAL" in error_msg and "Please retry" in error_msg:
logger.warning(f"500 INTERNAL error in {self.name} - attempting retry")
try:
# Single retry attempt
model = self._get_model_wrapper(request)
raw_response = await model.generate_content(prompt)
response = raw_response.text
# If successful, process normally
return [TextContent(type="text", text=self._process_response(response, request).model_dump_json())]
# Single retry attempt using provider
retry_response = provider.generate_content(
prompt=prompt,
model_name=model_name,
system_prompt=system_prompt,
temperature=temperature,
thinking_mode=thinking_mode if provider.supports_thinking_mode(model_name) else None
)
if retry_response.content:
# If successful, process normally
retry_model_info = {
"provider": provider,
"model_name": model_name,
"model_response": retry_response
}
tool_output = self._parse_response(retry_response.content, request, retry_model_info)
return [TextContent(type="text", text=tool_output.model_dump_json())]
except Exception as retry_e:
logger.error(f"Retry failed for {self.name} tool: {str(retry_e)}")
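The retry block repeats the main generation call. For reference, a generic single-retry wrapper in the same spirit; the transient-error check mirrors the diff's string match, the rest is illustrative.

```python
# Generic single-retry wrapper in the spirit of the 500 INTERNAL handling above.
# The transient-error check mirrors the diff's string match; the rest is illustrative.
import logging

logger = logging.getLogger(__name__)


def generate_with_retry(provider, attempts: int = 2, **kwargs):
    for attempt in range(1, attempts + 1):
        try:
            return provider.generate_content(**kwargs)
        except Exception as e:  # provider errors are not uniformly typed
            transient = "500 INTERNAL" in str(e) and "Please retry" in str(e)
            if not transient or attempt == attempts:
                raise
            logger.warning("Transient provider error, retrying (%d/%d): %s", attempt, attempts, e)
    raise RuntimeError("attempts must be >= 1")
```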
@@ -699,7 +832,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
)
return [TextContent(type="text", text=error_output.model_dump_json())]
def _parse_response(self, raw_text: str, request) -> ToolOutput:
def _parse_response(self, raw_text: str, request, model_info: Optional[dict] = None) -> ToolOutput:
"""
Parse the raw response and determine if it's a clarification request or follow-up.
@@ -745,11 +878,11 @@ If any of these would strengthen your analysis, specify what Claude should searc
pass
# Normal text response - format using tool-specific formatting
formatted_content = self.format_response(raw_text, request)
formatted_content = self.format_response(raw_text, request, model_info)
# If we found a follow-up question, prepare the threading response
if follow_up_question:
return self._create_follow_up_response(formatted_content, follow_up_question, request)
return self._create_follow_up_response(formatted_content, follow_up_question, request, model_info)
# Check if we should offer Claude a continuation opportunity
continuation_offer = self._check_continuation_opportunity(request)
@@ -758,7 +891,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
logger.debug(
f"Creating continuation offer for {self.name} with {continuation_offer['remaining_turns']} turns remaining"
)
return self._create_continuation_offer_response(formatted_content, continuation_offer, request)
return self._create_continuation_offer_response(formatted_content, continuation_offer, request, model_info)
else:
logger.debug(f"No continuation offer created for {self.name}")
@@ -766,12 +899,32 @@ If any of these would strengthen your analysis, specify what Claude should searc
continuation_id = getattr(request, "continuation_id", None)
if continuation_id:
request_files = getattr(request, "files", []) or []
# Extract model metadata for conversation tracking
model_provider = None
model_name = None
model_metadata = None
if model_info:
provider = model_info.get("provider")
if provider:
model_provider = provider.get_provider_type().value
model_name = model_info.get("model_name")
model_response = model_info.get("model_response")
if model_response:
model_metadata = {
"usage": model_response.usage,
"metadata": model_response.metadata
}
success = add_turn(
continuation_id,
"assistant",
formatted_content,
files=request_files,
tool_name=self.name,
model_provider=model_provider,
model_name=model_name,
model_metadata=model_metadata,
)
if not success:
logging.warning(f"Failed to add turn to thread {continuation_id} for {self.name}")
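The provider/model/usage unpacking above reappears nearly verbatim in `_create_follow_up_response` and `_create_continuation_offer_response`. A small helper like this (hypothetical, not part of the commit) would remove the duplication.

```python
# The same unpacking reappears in _create_follow_up_response and
# _create_continuation_offer_response; a helper like this (hypothetical,
# not part of the commit) would remove the duplication.
from typing import Optional


def extract_model_metadata(model_info: Optional[dict]) -> tuple[Optional[str], Optional[str], Optional[dict]]:
    if not model_info:
        return None, None, None
    provider = model_info.get("provider")
    model_provider = provider.get_provider_type().value if provider else None
    model_response = model_info.get("model_response")
    model_metadata = (
        {"usage": model_response.usage, "metadata": model_response.metadata} if model_response else None
    )
    return model_provider, model_info.get("model_name"), model_metadata


# At each call site:
# model_provider, model_name, model_metadata = extract_model_metadata(model_info)
```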
@@ -820,7 +973,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
return None
def _create_follow_up_response(self, content: str, follow_up_data: dict, request) -> ToolOutput:
def _create_follow_up_response(self, content: str, follow_up_data: dict, request, model_info: Optional[dict] = None) -> ToolOutput:
"""
Create a response with follow-up question for conversation threading.
@@ -832,56 +985,57 @@ If any of these would strengthen your analysis, specify what Claude should searc
Returns:
ToolOutput configured for conversation continuation
"""
# Create or get thread ID
# Always create a new thread (with parent linkage if continuation)
continuation_id = getattr(request, "continuation_id", None)
request_files = getattr(request, "files", []) or []
try:
# Create new thread with parent linkage if continuing
thread_id = create_thread(
tool_name=self.name,
initial_request=request.model_dump() if hasattr(request, "model_dump") else {},
parent_thread_id=continuation_id # Link to parent thread if continuing
)
if continuation_id:
# This is a continuation - add this turn to existing thread
request_files = getattr(request, "files", []) or []
success = add_turn(
continuation_id,
# Add the assistant's response with follow-up
# Extract model metadata
model_provider = None
model_name = None
model_metadata = None
if model_info:
provider = model_info.get("provider")
if provider:
model_provider = provider.get_provider_type().value
model_name = model_info.get("model_name")
model_response = model_info.get("model_response")
if model_response:
model_metadata = {
"usage": model_response.usage,
"metadata": model_response.metadata
}
add_turn(
thread_id, # Add to the new thread
"assistant",
content,
follow_up_question=follow_up_data.get("follow_up_question"),
files=request_files,
tool_name=self.name,
model_provider=model_provider,
model_name=model_name,
model_metadata=model_metadata,
)
except Exception as e:
# Threading failed, return normal response
logger = logging.getLogger(f"tools.{self.name}")
logger.warning(f"Follow-up threading failed in {self.name}: {str(e)}")
return ToolOutput(
status="success",
content=content,
content_type="markdown",
metadata={"tool_name": self.name, "follow_up_error": str(e)},
)
if not success:
# Thread not found or at limit, return normal response
return ToolOutput(
status="success",
content=content,
content_type="markdown",
metadata={"tool_name": self.name},
)
thread_id = continuation_id
else:
# Create new thread
try:
thread_id = create_thread(
tool_name=self.name, initial_request=request.model_dump() if hasattr(request, "model_dump") else {}
)
# Add the assistant's response with follow-up
request_files = getattr(request, "files", []) or []
add_turn(
thread_id,
"assistant",
content,
follow_up_question=follow_up_data.get("follow_up_question"),
files=request_files,
tool_name=self.name,
)
except Exception as e:
# Threading failed, return normal response
logger = logging.getLogger(f"tools.{self.name}")
logger.warning(f"Follow-up threading failed in {self.name}: {str(e)}")
return ToolOutput(
status="success",
content=content,
content_type="markdown",
metadata={"tool_name": self.name, "follow_up_error": str(e)},
)
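The rewritten path always creates a fresh thread and links it to its parent when continuing, instead of appending to the existing thread. A compact restatement of that shape, using the keyword names from the diff; the wrapper function itself is hypothetical.

```python
# Compact restatement of the new threading shape: always a fresh thread, linked
# back to the parent when this is a continuation. Keyword names follow the diff;
# the wrapper function itself is hypothetical.
from utils.conversation_memory import add_turn, create_thread


def thread_follow_up(tool, request, content, follow_up_question, model_provider, model_name, model_metadata):
    parent_id = getattr(request, "continuation_id", None)
    thread_id = create_thread(
        tool_name=tool.name,
        initial_request=request.model_dump() if hasattr(request, "model_dump") else {},
        parent_thread_id=parent_id,  # None for brand-new conversations
    )
    add_turn(
        thread_id,
        "assistant",
        content,
        follow_up_question=follow_up_question,
        files=getattr(request, "files", []) or [],
        tool_name=tool.name,
        model_provider=model_provider,
        model_name=model_name,
        model_metadata=model_metadata,
    )
    return thread_id
```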
# Create follow-up request
follow_up_request = FollowUpRequest(
@@ -925,13 +1079,14 @@ If any of these would strengthen your analysis, specify what Claude should searc
try:
if continuation_id:
# Check remaining turns in existing thread
from utils.conversation_memory import get_thread
# Check remaining turns in thread chain
from utils.conversation_memory import get_thread_chain
context = get_thread(continuation_id)
if context:
current_turns = len(context.turns)
remaining_turns = MAX_CONVERSATION_TURNS - current_turns - 1 # -1 for this response
chain = get_thread_chain(continuation_id)
if chain:
# Count total turns across all threads in chain
total_turns = sum(len(thread.turns) for thread in chain)
remaining_turns = MAX_CONVERSATION_TURNS - total_turns - 1 # -1 for this response
else:
# Thread not found, don't offer continuation
return None
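Remaining-turn accounting now spans the whole parent chain rather than a single thread. A short sketch of the calculation, assuming each thread in the chain exposes a `.turns` list as the diff implies; the helper name is hypothetical.

```python
# Chain-wide turn accounting introduced above. Assumes each thread in the chain
# exposes a `.turns` list, as the diff implies; the helper name is hypothetical.
from typing import Optional

from utils.conversation_memory import MAX_CONVERSATION_TURNS, get_thread_chain


def remaining_turns_for(continuation_id: str) -> Optional[int]:
    chain = get_thread_chain(continuation_id)
    if not chain:
        return None  # thread not found: do not offer continuation
    total_turns = sum(len(thread.turns) for thread in chain)
    return MAX_CONVERSATION_TURNS - total_turns - 1  # reserve one turn for this response
```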
@@ -949,7 +1104,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
# If anything fails, don't offer continuation
return None
def _create_continuation_offer_response(self, content: str, continuation_data: dict, request) -> ToolOutput:
def _create_continuation_offer_response(self, content: str, continuation_data: dict, request, model_info: Optional[dict] = None) -> ToolOutput:
"""
Create a response offering Claude the opportunity to continue conversation.
@@ -962,14 +1117,43 @@ If any of these would strengthen your analysis, specify what Claude should searc
ToolOutput configured with continuation offer
"""
try:
# Create new thread for potential continuation
# Create new thread for potential continuation (with parent link if continuing)
continuation_id = getattr(request, "continuation_id", None)
thread_id = create_thread(
tool_name=self.name, initial_request=request.model_dump() if hasattr(request, "model_dump") else {}
tool_name=self.name,
initial_request=request.model_dump() if hasattr(request, "model_dump") else {},
parent_thread_id=continuation_id # Link to parent if this is a continuation
)
# Add this response as the first turn (assistant turn)
request_files = getattr(request, "files", []) or []
add_turn(thread_id, "assistant", content, files=request_files, tool_name=self.name)
# Extract model metadata
model_provider = None
model_name = None
model_metadata = None
if model_info:
provider = model_info.get("provider")
if provider:
model_provider = provider.get_provider_type().value
model_name = model_info.get("model_name")
model_response = model_info.get("model_response")
if model_response:
model_metadata = {
"usage": model_response.usage,
"metadata": model_response.metadata
}
add_turn(
thread_id,
"assistant",
content,
files=request_files,
tool_name=self.name,
model_provider=model_provider,
model_name=model_name,
model_metadata=model_metadata,
)
# Create continuation offer
remaining_turns = continuation_data["remaining_turns"]
@@ -1022,7 +1206,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
"""
pass
def format_response(self, response: str, request) -> str:
def format_response(self, response: str, request, model_info: Optional[dict] = None) -> str:
"""
Format the model's response for display.
@@ -1033,6 +1217,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
Args:
response: The raw response from the model
request: The original request for context
model_info: Optional dict with model metadata (provider, model_name, model_response)
Returns:
str: Formatted response
@@ -1059,154 +1244,41 @@ If any of these would strengthen your analysis, specify what Claude should searc
f"{context_type} too large (~{estimated_tokens:,} tokens). Maximum is {MAX_CONTEXT_TOKENS:,} tokens."
)
def create_model(self, model_name: str, temperature: float, thinking_mode: str = "medium"):
def get_model_provider(self, model_name: str) -> ModelProvider:
"""
Create a configured Gemini model instance.
This method handles model creation with appropriate settings including
temperature and thinking budget configuration for models that support it.
Get a model provider for the specified model.
Args:
model_name: Name of the Gemini model to use (or shorthand like 'flash', 'pro')
temperature: Temperature setting for response generation
thinking_mode: Thinking depth mode (affects computational budget)
model_name: Name of the model to use (can be provider-specific or generic)
Returns:
Model instance configured and ready for generation
ModelProvider instance configured for the model
Raises:
ValueError: If no provider supports the requested model
"""
# Define model shorthands for user convenience
model_shorthands = {
"pro": "gemini-2.5-pro-preview-06-05",
"flash": "gemini-2.0-flash-exp",
}
# Resolve shorthand to full model name
resolved_model_name = model_shorthands.get(model_name.lower(), model_name)
# Map thinking modes to computational budget values
# Higher budgets allow for more complex reasoning but increase latency
thinking_budgets = {
"minimal": 128, # Minimum for 2.5 Pro - fast responses
"low": 2048, # Light reasoning tasks
"medium": 8192, # Balanced reasoning (default)
"high": 16384, # Complex analysis
"max": 32768, # Maximum reasoning depth
}
thinking_budget = thinking_budgets.get(thinking_mode, 8192)
# Gemini 2.5 models support thinking configuration for enhanced reasoning
# Skip special handling in test environment to allow mocking
if "2.5" in resolved_model_name and not os.environ.get("PYTEST_CURRENT_TEST"):
try:
# Retrieve API key for Gemini client creation
api_key = os.environ.get("GEMINI_API_KEY")
if not api_key:
raise ValueError("GEMINI_API_KEY environment variable is required")
client = genai.Client(api_key=api_key)
# Create a wrapper class to provide a consistent interface
# This abstracts the differences between API versions
class ModelWrapper:
def __init__(self, client, model_name, temperature, thinking_budget):
self.client = client
self.model_name = model_name
self.temperature = temperature
self.thinking_budget = thinking_budget
def generate_content(self, prompt):
response = self.client.models.generate_content(
model=self.model_name,
contents=prompt,
config=types.GenerateContentConfig(
temperature=self.temperature,
candidate_count=1,
thinking_config=types.ThinkingConfig(thinking_budget=self.thinking_budget),
),
)
# Wrap the response to match the expected format
# This ensures compatibility across different API versions
class ResponseWrapper:
def __init__(self, text):
self.text = text
self.candidates = [
type(
"obj",
(object,),
{
"content": type(
"obj",
(object,),
{
"parts": [
type(
"obj",
(object,),
{"text": text},
)
]
},
)(),
"finish_reason": "STOP",
},
)
]
return ResponseWrapper(response.text)
return ModelWrapper(client, resolved_model_name, temperature, thinking_budget)
except Exception:
# Fall back to regular API if thinking configuration fails
# This ensures the tool remains functional even with API changes
pass
# For models that don't support thinking configuration, use standard API
api_key = os.environ.get("GEMINI_API_KEY")
if not api_key:
raise ValueError("GEMINI_API_KEY environment variable is required")
client = genai.Client(api_key=api_key)
# Create a simple wrapper for models without thinking configuration
# This provides the same interface as the thinking-enabled wrapper
class SimpleModelWrapper:
def __init__(self, client, model_name, temperature):
self.client = client
self.model_name = model_name
self.temperature = temperature
def generate_content(self, prompt):
response = self.client.models.generate_content(
model=self.model_name,
contents=prompt,
config=types.GenerateContentConfig(
temperature=self.temperature,
candidate_count=1,
),
)
# Convert to match expected format
class ResponseWrapper:
def __init__(self, text):
self.text = text
self.candidates = [
type(
"obj",
(object,),
{
"content": type(
"obj",
(object,),
{"parts": [type("obj", (object,), {"text": text})]},
)(),
"finish_reason": "STOP",
},
)
]
return ResponseWrapper(response.text)
return SimpleModelWrapper(client, resolved_model_name, temperature)
# Get provider from registry
provider = ModelProviderRegistry.get_provider_for_model(model_name)
if not provider:
# Try to determine provider from model name patterns
if "gemini" in model_name.lower() or model_name.lower() in ["flash", "pro"]:
# Register Gemini provider if not already registered
from providers.gemini import GeminiModelProvider
from providers.base import ProviderType
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)
elif "gpt" in model_name.lower() or "o3" in model_name.lower():
# Register OpenAI provider if not already registered
from providers.openai import OpenAIModelProvider
from providers.base import ProviderType
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
provider = ModelProviderRegistry.get_provider(ProviderType.OPENAI)
if not provider:
raise ValueError(
f"No provider found for model '{model_name}'. "
f"Ensure the appropriate API key is set and the model name is correct."
)
return provider
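From a caller's perspective, the fallback above routes by name pattern and registers providers lazily. A usage sketch; the routing described in the trailing comment is the intent of the code above, not verified output.

```python
# Usage sketch for the registry fallback above. Registry and provider classes are
# the ones named in the diff; the routing in the comment is the intent, not verified output.
def resolve_provider_type(tool, model_name: str) -> str:
    provider = tool.get_model_provider(model_name)  # lazily registers Gemini/OpenAI if needed
    return provider.get_provider_type().value


# 'flash', 'pro', or any name containing 'gemini' routes to the Google provider;
# names containing 'gpt' or 'o3' route to OpenAI; anything else raises ValueError.
```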