Rebranding, refactoring, renaming, cleanup, updated docs
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
"""Tests for the model provider abstraction system"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from providers import ModelProviderRegistry, ModelProvider, ModelResponse, ModelCapabilities
|
||||
from providers import ModelProviderRegistry, ModelResponse
|
||||
from providers.base import ProviderType
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
@@ -12,56 +11,56 @@ from providers.openai import OpenAIModelProvider
|
||||
|
||||
class TestModelProviderRegistry:
|
||||
"""Test the model provider registry"""
|
||||
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear registry before each test"""
|
||||
ModelProviderRegistry._providers.clear()
|
||||
ModelProviderRegistry._initialized_providers.clear()
|
||||
|
||||
|
||||
def test_register_provider(self):
|
||||
"""Test registering a provider"""
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
|
||||
|
||||
assert ProviderType.GOOGLE in ModelProviderRegistry._providers
|
||||
assert ModelProviderRegistry._providers[ProviderType.GOOGLE] == GeminiModelProvider
|
||||
|
||||
|
||||
@patch.dict(os.environ, {"GEMINI_API_KEY": "test-key"})
|
||||
def test_get_provider(self):
|
||||
"""Test getting a provider instance"""
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
|
||||
|
||||
provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)
|
||||
|
||||
|
||||
assert provider is not None
|
||||
assert isinstance(provider, GeminiModelProvider)
|
||||
assert provider.api_key == "test-key"
|
||||
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_get_provider_no_api_key(self):
|
||||
"""Test getting provider without API key returns None"""
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
|
||||
|
||||
provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)
|
||||
|
||||
|
||||
assert provider is None
|
||||
|
||||
|
||||
@patch.dict(os.environ, {"GEMINI_API_KEY": "test-key"})
|
||||
def test_get_provider_for_model(self):
|
||||
"""Test getting provider for a specific model"""
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
|
||||
|
||||
provider = ModelProviderRegistry.get_provider_for_model("gemini-2.0-flash-exp")
|
||||
|
||||
|
||||
assert provider is not None
|
||||
assert isinstance(provider, GeminiModelProvider)
|
||||
|
||||
|
||||
def test_get_available_providers(self):
|
||||
"""Test getting list of available providers"""
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
|
||||
|
||||
providers = ModelProviderRegistry.get_available_providers()
|
||||
|
||||
|
||||
assert len(providers) == 2
|
||||
assert ProviderType.GOOGLE in providers
|
||||
assert ProviderType.OPENAI in providers
|
||||
@@ -69,50 +68,50 @@ class TestModelProviderRegistry:
|
||||
|
||||
class TestGeminiProvider:
|
||||
"""Test Gemini model provider"""
|
||||
|
||||
|
||||
def test_provider_initialization(self):
|
||||
"""Test provider initialization"""
|
||||
provider = GeminiModelProvider(api_key="test-key")
|
||||
|
||||
|
||||
assert provider.api_key == "test-key"
|
||||
assert provider.get_provider_type() == ProviderType.GOOGLE
|
||||
|
||||
|
||||
def test_get_capabilities(self):
|
||||
"""Test getting model capabilities"""
|
||||
provider = GeminiModelProvider(api_key="test-key")
|
||||
|
||||
|
||||
capabilities = provider.get_capabilities("gemini-2.0-flash-exp")
|
||||
|
||||
|
||||
assert capabilities.provider == ProviderType.GOOGLE
|
||||
assert capabilities.model_name == "gemini-2.0-flash-exp"
|
||||
assert capabilities.max_tokens == 1_048_576
|
||||
assert not capabilities.supports_extended_thinking
|
||||
|
||||
|
||||
def test_get_capabilities_pro_model(self):
|
||||
"""Test getting capabilities for Pro model with thinking support"""
|
||||
provider = GeminiModelProvider(api_key="test-key")
|
||||
|
||||
|
||||
capabilities = provider.get_capabilities("gemini-2.5-pro-preview-06-05")
|
||||
|
||||
|
||||
assert capabilities.supports_extended_thinking
|
||||
|
||||
|
||||
def test_model_shorthand_resolution(self):
|
||||
"""Test model shorthand resolution"""
|
||||
provider = GeminiModelProvider(api_key="test-key")
|
||||
|
||||
|
||||
assert provider.validate_model_name("flash")
|
||||
assert provider.validate_model_name("pro")
|
||||
|
||||
|
||||
capabilities = provider.get_capabilities("flash")
|
||||
assert capabilities.model_name == "gemini-2.0-flash-exp"
|
||||
|
||||
|
||||
def test_supports_thinking_mode(self):
|
||||
"""Test thinking mode support detection"""
|
||||
provider = GeminiModelProvider(api_key="test-key")
|
||||
|
||||
|
||||
assert not provider.supports_thinking_mode("gemini-2.0-flash-exp")
|
||||
assert provider.supports_thinking_mode("gemini-2.5-pro-preview-06-05")
|
||||
|
||||
|
||||
@patch("google.genai.Client")
|
||||
def test_generate_content(self, mock_client_class):
|
||||
"""Test content generation"""
|
||||
@@ -131,15 +130,11 @@ class TestGeminiProvider:
|
||||
mock_response.usage_metadata = mock_usage
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
|
||||
provider = GeminiModelProvider(api_key="test-key")
|
||||
|
||||
response = provider.generate_content(
|
||||
prompt="Test prompt",
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
temperature=0.7
|
||||
)
|
||||
|
||||
|
||||
response = provider.generate_content(prompt="Test prompt", model_name="gemini-2.0-flash-exp", temperature=0.7)
|
||||
|
||||
assert isinstance(response, ModelResponse)
|
||||
assert response.content == "Generated content"
|
||||
assert response.model_name == "gemini-2.0-flash-exp"
|
||||
@@ -151,38 +146,38 @@ class TestGeminiProvider:
|
||||
|
||||
class TestOpenAIProvider:
|
||||
"""Test OpenAI model provider"""
|
||||
|
||||
|
||||
def test_provider_initialization(self):
|
||||
"""Test provider initialization"""
|
||||
provider = OpenAIModelProvider(api_key="test-key", organization="test-org")
|
||||
|
||||
|
||||
assert provider.api_key == "test-key"
|
||||
assert provider.organization == "test-org"
|
||||
assert provider.get_provider_type() == ProviderType.OPENAI
|
||||
|
||||
|
||||
def test_get_capabilities_o3(self):
|
||||
"""Test getting O3 model capabilities"""
|
||||
provider = OpenAIModelProvider(api_key="test-key")
|
||||
|
||||
|
||||
capabilities = provider.get_capabilities("o3-mini")
|
||||
|
||||
|
||||
assert capabilities.provider == ProviderType.OPENAI
|
||||
assert capabilities.model_name == "o3-mini"
|
||||
assert capabilities.max_tokens == 200_000
|
||||
assert not capabilities.supports_extended_thinking
|
||||
|
||||
|
||||
def test_validate_model_names(self):
|
||||
"""Test model name validation"""
|
||||
provider = OpenAIModelProvider(api_key="test-key")
|
||||
|
||||
|
||||
assert provider.validate_model_name("o3")
|
||||
assert provider.validate_model_name("o3-mini")
|
||||
assert not provider.validate_model_name("gpt-4o")
|
||||
assert not provider.validate_model_name("invalid-model")
|
||||
|
||||
|
||||
def test_no_thinking_mode_support(self):
|
||||
"""Test that no OpenAI models support thinking mode"""
|
||||
provider = OpenAIModelProvider(api_key="test-key")
|
||||
|
||||
|
||||
assert not provider.supports_thinking_mode("o3")
|
||||
assert not provider.supports_thinking_mode("o3-mini")
|
||||
assert not provider.supports_thinking_mode("o3-mini")
|
||||
|
||||
Reference in New Issue
Block a user