Use ModelCapabilities consistently instead of dictionaries

Moved aliases as part of SUPPORTED_MODELS instead of shorthand, more in line with how custom_models are declared
Further refactoring to cleanup some code
This commit is contained in:
Fahad
2025-06-23 16:58:59 +04:00
parent e94c028a3f
commit 498ea88293
16 changed files with 850 additions and 605 deletions

View File

@@ -99,15 +99,11 @@ class ListModelsTool(BaseTool):
output_lines.append("**Status**: Configured and available")
output_lines.append("\n**Models**:")
# Get models from the provider's SUPPORTED_MODELS
for model_name, config in provider.SUPPORTED_MODELS.items():
# Skip alias entries (string values)
if isinstance(config, str):
continue
# Get description and context from the model config
description = config.get("description", "No description available")
context_window = config.get("context_window", 0)
# Get models from the provider's model configurations
for model_name, capabilities in provider.get_model_configurations().items():
# Get description and context from the ModelCapabilities object
description = capabilities.description or "No description available"
context_window = capabilities.context_window
# Format context window
if context_window >= 1_000_000:
@@ -133,13 +129,14 @@ class ListModelsTool(BaseTool):
# Show aliases for this provider
aliases = []
for alias_name, target in provider.SUPPORTED_MODELS.items():
if isinstance(target, str): # This is an alias
aliases.append(f"- `{alias_name}` → `{target}`")
for model_name, capabilities in provider.get_model_configurations().items():
if capabilities.aliases:
for alias in capabilities.aliases:
aliases.append(f"- `{alias}` → `{model_name}`")
if aliases:
output_lines.append("\n**Aliases**:")
output_lines.extend(aliases)
output_lines.extend(sorted(aliases)) # Sort for consistent output
else:
output_lines.append(f"**Status**: Not configured (set {info['env_key']})")
@@ -237,7 +234,7 @@ class ListModelsTool(BaseTool):
for alias in registry.list_aliases():
config = registry.resolve(alias)
if config and hasattr(config, "is_custom") and config.is_custom:
if config and config.is_custom:
custom_models.append((alias, config))
if custom_models:

View File

@@ -256,8 +256,8 @@ class BaseTool(ABC):
# Find all custom models (is_custom=true)
for alias in registry.list_aliases():
config = registry.resolve(alias)
# Use hasattr for defensive programming - is_custom is optional with default False
if config and hasattr(config, "is_custom") and config.is_custom:
# Check if this is a custom model that requires custom endpoints
if config and config.is_custom:
if alias not in all_models:
all_models.append(alias)
except Exception as e:
@@ -311,12 +311,16 @@ class BaseTool(ABC):
ProviderType.GOOGLE: "Gemini models",
ProviderType.OPENAI: "OpenAI models",
ProviderType.XAI: "X.AI GROK models",
ProviderType.DIAL: "DIAL models",
ProviderType.CUSTOM: "Custom models",
ProviderType.OPENROUTER: "OpenRouter models",
}
# Check available providers and add their model descriptions
for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI]:
# Start with native providers
for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI, ProviderType.DIAL]:
# Only if this is registered / available
provider = ModelProviderRegistry.get_provider(provider_type)
if provider:
provider_section_added = False
@@ -324,13 +328,13 @@ class BaseTool(ABC):
try:
# Get model config to extract description
model_config = provider.SUPPORTED_MODELS.get(model_name)
if isinstance(model_config, dict) and "description" in model_config:
if model_config and model_config.description:
if not provider_section_added:
model_desc_parts.append(
f"\n{provider_names[provider_type]} - Available when {provider_type.value.upper()}_API_KEY is configured:"
)
provider_section_added = True
model_desc_parts.append(f"- '{model_name}': {model_config['description']}")
model_desc_parts.append(f"- '{model_name}': {model_config.description}")
except Exception:
# Skip models without descriptions
continue
@@ -346,8 +350,8 @@ class BaseTool(ABC):
# Find all custom models (is_custom=true)
for alias in registry.list_aliases():
config = registry.resolve(alias)
# Use hasattr for defensive programming - is_custom is optional with default False
if config and hasattr(config, "is_custom") and config.is_custom:
# Check if this is a custom model that requires custom endpoints
if config and config.is_custom:
# Format context window
context_tokens = config.context_window
if context_tokens >= 1_000_000: