refactor: address code review feedback from Gemini

- Extract restriction checking logic into reusable helper method
- Refactor validate_model_name to reduce code duplication
- Fix logging import by using existing module-level logger
- Clean up test file by removing print statement and main block
- All tests continue to pass after refactoring
This commit is contained in:
Sven Lito
2025-09-05 11:04:02 +07:00
parent fab1f24475
commit 525f4598ce
2 changed files with 48 additions and 59 deletions

View File

@@ -178,12 +178,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
"""Get capabilities for a specific OpenAI model.""" """Get capabilities for a specific OpenAI model."""
# First check if it's a key in SUPPORTED_MODELS # First check if it's a key in SUPPORTED_MODELS
if model_name in self.SUPPORTED_MODELS: if model_name in self.SUPPORTED_MODELS:
# Check if model is allowed by restrictions self._check_model_restrictions(model_name, model_name)
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENAI, model_name, model_name):
raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
return self.SUPPORTED_MODELS[model_name] return self.SUPPORTED_MODELS[model_name]
# Try resolving as alias # Try resolving as alias
@@ -191,23 +186,13 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
# Check if resolved name is a key # Check if resolved name is a key
if resolved_name in self.SUPPORTED_MODELS: if resolved_name in self.SUPPORTED_MODELS:
# Check if model is allowed by restrictions self._check_model_restrictions(resolved_name, model_name)
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
return self.SUPPORTED_MODELS[resolved_name] return self.SUPPORTED_MODELS[resolved_name]
# Finally check if resolved name matches any API model name # Finally check if resolved name matches any API model name
for key, capabilities in self.SUPPORTED_MODELS.items(): for key, capabilities in self.SUPPORTED_MODELS.items():
if resolved_name == capabilities.model_name: if resolved_name == capabilities.model_name:
# Check if model is allowed by restrictions self._check_model_restrictions(key, model_name)
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENAI, key, model_name):
raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
return capabilities return capabilities
# Check custom models registry for user-configured OpenAI models # Check custom models registry for user-configured OpenAI models
@@ -218,12 +203,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
config = registry.get_model_config(resolved_name) config = registry.get_model_config(resolved_name)
if config and config.provider == ProviderType.OPENAI: if config and config.provider == ProviderType.OPENAI:
# Check if model is allowed by restrictions self._check_model_restrictions(config.model_name, model_name)
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENAI, config.model_name, model_name):
raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
# Update provider type to ensure consistency # Update provider type to ensure consistency
config.provider = ProviderType.OPENAI config.provider = ProviderType.OPENAI
@@ -231,13 +211,26 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
except Exception as e: except Exception as e:
# Log but don't fail - registry might not be available # Log but don't fail - registry might not be available
import logging
logger = logging.getLogger(__name__)
logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}") logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}")
raise ValueError(f"Unsupported OpenAI model: {model_name}") raise ValueError(f"Unsupported OpenAI model: {model_name}")
def _check_model_restrictions(self, provider_model_name: str, user_model_name: str) -> None:
"""Check if a model is allowed by restriction policy.
Args:
provider_model_name: The model name used by the provider
user_model_name: The model name requested by the user
Raises:
ValueError: If the model is not allowed by restriction policy
"""
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENAI, provider_model_name, user_model_name):
raise ValueError(f"OpenAI model '{user_model_name}' is not allowed by restriction policy.")
def get_provider_type(self) -> ProviderType: def get_provider_type(self) -> ProviderType:
"""Get the provider type.""" """Get the provider type."""
return ProviderType.OPENAI return ProviderType.OPENAI
@@ -246,39 +239,41 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
"""Validate if the model name is supported and allowed.""" """Validate if the model name is supported and allowed."""
resolved_name = self._resolve_model_name(model_name) resolved_name = self._resolve_model_name(model_name)
# First check if model is in built-in SUPPORTED_MODELS # First, determine which model name to check against restrictions.
model_to_check = None
is_custom_model = False
if resolved_name in self.SUPPORTED_MODELS: if resolved_name in self.SUPPORTED_MODELS:
# Check if model is allowed by restrictions model_to_check = resolved_name
from utils.model_restrictions import get_restriction_service else:
# If not a built-in model, check the custom models registry.
try:
from .openrouter_registry import OpenRouterModelRegistry
restriction_service = get_restriction_service() registry = OpenRouterModelRegistry()
if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name): config = registry.get_model_config(resolved_name)
logger.debug(f"OpenAI model '{model_name}' -> '{resolved_name}' blocked by restrictions")
return False
return True
# Check custom models registry for user-configured OpenAI models if config and config.provider == ProviderType.OPENAI:
try: model_to_check = config.model_name
from .openrouter_registry import OpenRouterModelRegistry is_custom_model = True
except Exception as e:
# Log but don't fail - registry might not be available.
logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}")
registry = OpenRouterModelRegistry() # If no model was found (neither built-in nor custom), it's invalid.
config = registry.get_model_config(resolved_name) if not model_to_check:
return False
if config and config.provider == ProviderType.OPENAI: # Now, perform the restriction check once.
# Check if model is allowed by restrictions from utils.model_restrictions import get_restriction_service
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service() restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENAI, config.model_name, model_name): if not restriction_service.is_allowed(ProviderType.OPENAI, model_to_check, model_name):
logger.debug(f"OpenAI custom model '{model_name}' -> '{resolved_name}' blocked by restrictions") model_type = "custom " if is_custom_model else ""
return False logger.debug(f"OpenAI {model_type}model '{model_name}' -> '{resolved_name}' blocked by restrictions")
return True return False
except Exception as e: return True
# Log but don't fail - registry might not be available
logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}")
return False
def generate_content( def generate_content(
self, self,

View File

@@ -75,9 +75,3 @@ def test_issue_245_custom_openai_temperature_ignored():
# Verify the fix: NO temperature should be sent to the API # Verify the fix: NO temperature should be sent to the API
call_kwargs = mock_client.chat.completions.create.call_args[1] call_kwargs = mock_client.chat.completions.create.call_args[1]
assert "temperature" not in call_kwargs, "Fix failed: temperature still being sent!" assert "temperature" not in call_kwargs, "Fix failed: temperature still being sent!"
print("✅ Issue #245 is FIXED! Temperature parameter correctly ignored for custom models.")
if __name__ == "__main__":
test_issue_245_custom_openai_temperature_ignored()