refactor: address code review feedback from Gemini
- Extract restriction checking logic into reusable helper method
- Refactor validate_model_name to reduce code duplication
- Fix logging import by using existing module-level logger
- Clean up test file by removing print statement and main block
- All tests continue to pass after refactoring
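For context, a minimal sketch of how the extracted helper is expected to behave under a blocking policy. This is a hypothetical test, not part of the commit: the module path providers.openai_provider and the provider constructor are assumptions, and the patch target works because the helper imports get_restriction_service at call time (as the diff below shows it does).

from unittest.mock import MagicMock, patch

import pytest

from providers.openai_provider import OpenAIModelProvider  # path assumed


def test_check_model_restrictions_raises_when_blocked():
    provider = OpenAIModelProvider(api_key="test-key")  # constructor assumed

    blocked = MagicMock()
    blocked.is_allowed.return_value = False  # policy rejects everything

    # Patch at the source module; the helper imports the function inside
    # its body, so this patched lookup is what the call resolves to.
    with patch("utils.model_restrictions.get_restriction_service", return_value=blocked):
        with pytest.raises(ValueError, match="not allowed by restriction policy"):
            # Arbitrary model names; the service itself is mocked.
            provider._check_model_restrictions("some-model", "some-alias")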
@@ -178,12 +178,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
         """Get capabilities for a specific OpenAI model."""
         # First check if it's a key in SUPPORTED_MODELS
         if model_name in self.SUPPORTED_MODELS:
-            # Check if model is allowed by restrictions
-            from utils.model_restrictions import get_restriction_service
-
-            restriction_service = get_restriction_service()
-            if not restriction_service.is_allowed(ProviderType.OPENAI, model_name, model_name):
-                raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
+            self._check_model_restrictions(model_name, model_name)
             return self.SUPPORTED_MODELS[model_name]

         # Try resolving as alias
@@ -191,23 +186,13 @@ class OpenAIModelProvider(OpenAICompatibleProvider):

         # Check if resolved name is a key
         if resolved_name in self.SUPPORTED_MODELS:
-            # Check if model is allowed by restrictions
-            from utils.model_restrictions import get_restriction_service
-
-            restriction_service = get_restriction_service()
-            if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
-                raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
+            self._check_model_restrictions(resolved_name, model_name)
             return self.SUPPORTED_MODELS[resolved_name]

         # Finally check if resolved name matches any API model name
         for key, capabilities in self.SUPPORTED_MODELS.items():
             if resolved_name == capabilities.model_name:
-                # Check if model is allowed by restrictions
-                from utils.model_restrictions import get_restriction_service
-
-                restriction_service = get_restriction_service()
-                if not restriction_service.is_allowed(ProviderType.OPENAI, key, model_name):
-                    raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
+                self._check_model_restrictions(key, model_name)
                 return capabilities

         # Check custom models registry for user-configured OpenAI models
@@ -218,12 +203,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
             config = registry.get_model_config(resolved_name)

             if config and config.provider == ProviderType.OPENAI:
-                # Check if model is allowed by restrictions
-                from utils.model_restrictions import get_restriction_service
-
-                restriction_service = get_restriction_service()
-                if not restriction_service.is_allowed(ProviderType.OPENAI, config.model_name, model_name):
-                    raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
+                self._check_model_restrictions(config.model_name, model_name)

                 # Update provider type to ensure consistency
                 config.provider = ProviderType.OPENAI
@@ -231,13 +211,26 @@ class OpenAIModelProvider(OpenAICompatibleProvider):

         except Exception as e:
             # Log but don't fail - registry might not be available
-            import logging
-
-            logger = logging.getLogger(__name__)
             logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}")

         raise ValueError(f"Unsupported OpenAI model: {model_name}")

+    def _check_model_restrictions(self, provider_model_name: str, user_model_name: str) -> None:
+        """Check if a model is allowed by restriction policy.
+
+        Args:
+            provider_model_name: The model name used by the provider
+            user_model_name: The model name requested by the user
+
+        Raises:
+            ValueError: If the model is not allowed by restriction policy
+        """
+        from utils.model_restrictions import get_restriction_service
+
+        restriction_service = get_restriction_service()
+        if not restriction_service.is_allowed(ProviderType.OPENAI, provider_model_name, user_model_name):
+            raise ValueError(f"OpenAI model '{user_model_name}' is not allowed by restriction policy.")
+
     def get_provider_type(self) -> ProviderType:
         """Get the provider type."""
         return ProviderType.OPENAI
@@ -246,39 +239,41 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
         """Validate if the model name is supported and allowed."""
         resolved_name = self._resolve_model_name(model_name)

-        # First check if model is in built-in SUPPORTED_MODELS
-        if resolved_name in self.SUPPORTED_MODELS:
-            # Check if model is allowed by restrictions
-            from utils.model_restrictions import get_restriction_service
-
-            restriction_service = get_restriction_service()
-            if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
-                logger.debug(f"OpenAI model '{model_name}' -> '{resolved_name}' blocked by restrictions")
-                return False
-            return True
-
-        # Check custom models registry for user-configured OpenAI models
-        try:
-            from .openrouter_registry import OpenRouterModelRegistry
-
-            registry = OpenRouterModelRegistry()
-            config = registry.get_model_config(resolved_name)
-
-            if config and config.provider == ProviderType.OPENAI:
-                # Check if model is allowed by restrictions
-                from utils.model_restrictions import get_restriction_service
-
-                restriction_service = get_restriction_service()
-                if not restriction_service.is_allowed(ProviderType.OPENAI, config.model_name, model_name):
-                    logger.debug(f"OpenAI custom model '{model_name}' -> '{resolved_name}' blocked by restrictions")
-                    return False
-                return True
-
-        except Exception as e:
-            # Log but don't fail - registry might not be available
-            logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}")
-
-        return False
+        # First, determine which model name to check against restrictions.
+        model_to_check = None
+        is_custom_model = False
+
+        if resolved_name in self.SUPPORTED_MODELS:
+            model_to_check = resolved_name
+        else:
+            # If not a built-in model, check the custom models registry.
+            try:
+                from .openrouter_registry import OpenRouterModelRegistry
+
+                registry = OpenRouterModelRegistry()
+                config = registry.get_model_config(resolved_name)
+
+                if config and config.provider == ProviderType.OPENAI:
+                    model_to_check = config.model_name
+                    is_custom_model = True
+            except Exception as e:
+                # Log but don't fail - registry might not be available.
+                logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}")
+
+        # If no model was found (neither built-in nor custom), it's invalid.
+        if not model_to_check:
+            return False
+
+        # Now, perform the restriction check once.
+        from utils.model_restrictions import get_restriction_service
+
+        restriction_service = get_restriction_service()
+        if not restriction_service.is_allowed(ProviderType.OPENAI, model_to_check, model_name):
+            model_type = "custom " if is_custom_model else ""
+            logger.debug(f"OpenAI {model_type}model '{model_name}' -> '{resolved_name}' blocked by restrictions")
+            return False
+
+        return True

     def generate_content(
         self,
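The refactor leaves the restriction service's contract visible at every call site: is_allowed(ProviderType.OPENAI, canonical_model_name, user_supplied_name) returning a bool. A hedged sketch of a conforming test double, useful when exercising validate_model_name without a real policy; the class and its membership rule are illustrative only, not part of this commit:

class FakeRestrictionService:
    """Illustrative stand-in matching the is_allowed call shape above."""

    def __init__(self, allowed_models: set[str]):
        self.allowed_models = allowed_models

    def is_allowed(self, provider_type, model_name, user_model_name=None):
        # The real service may support patterns or env-driven config;
        # this fake only checks canonical-name membership.
        return model_name in self.allowed_models

Because the new validate_model_name resolves aliases first and performs the restriction check exactly once, a fake like this always sees the canonical name (model_to_check), never the raw alias.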
@@ -75,9 +75,3 @@ def test_issue_245_custom_openai_temperature_ignored():
     # Verify the fix: NO temperature should be sent to the API
     call_kwargs = mock_client.chat.completions.create.call_args[1]
     assert "temperature" not in call_kwargs, "Fix failed: temperature still being sent!"
-
-    print("✅ Issue #245 is FIXED! Temperature parameter correctly ignored for custom models.")
-
-
-if __name__ == "__main__":
-    test_issue_245_custom_openai_temperature_ignored()
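With the print and __main__ block removed, the test is exercised only through the test runner. For readers unfamiliar with the idiom on the assertion line, here is a self-contained illustration of the call_args[1] pattern, independent of this repo:

from unittest.mock import MagicMock

client = MagicMock()
client.chat.completions.create(model="gpt-x", messages=[])  # call without temperature

kwargs = client.chat.completions.create.call_args[1]  # kwargs of the most recent call
assert "temperature" not in kwargs  # holds: temperature was never passed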