From 525f4598ce67e2bde157294b8538905dd73cda7b Mon Sep 17 00:00:00 2001
From: Sven Lito
Date: Fri, 5 Sep 2025 11:04:02 +0700
Subject: [PATCH] refactor: address code review feedback from Gemini

- Extract restriction checking logic into reusable helper method
- Refactor validate_model_name to reduce code duplication
- Fix logging import by using existing module-level logger
- Clean up test file by removing print statement and main block
- All tests continue to pass after refactoring

---
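Notes (kept below the fold, so "git am" leaves them out of the commit
message):

The diff funnels every model lookup path in get_capabilities() through
the new _check_model_restrictions() helper, and validate_model_name()
now resolves the name first and applies the restriction policy exactly
once. A minimal sketch of the resulting call pattern follows; the
constructor argument and the "o3-mini" model name are illustrative
assumptions, not taken from this diff:

    provider = OpenAIModelProvider(api_key="...")  # hypothetical setup

    # validate_model_name() returns a bool and logs blocked models
    if provider.validate_model_name("o3-mini"):
        caps = provider.get_capabilities("o3-mini")

    # get_capabilities() instead raises ValueError (via
    # _check_model_restrictions) when the name is unknown or the
    # restriction policy blocks the model.
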
 providers/openai_provider.py   | 101 ++++++++++++++++-----------------
 tests/test_issue_245_simple.py |   6 ------
 2 files changed, 48 insertions(+), 59 deletions(-)

diff --git a/providers/openai_provider.py b/providers/openai_provider.py
index 2398113..82b10ef 100644
--- a/providers/openai_provider.py
+++ b/providers/openai_provider.py
@@ -178,12 +178,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
         """Get capabilities for a specific OpenAI model."""
         # First check if it's a key in SUPPORTED_MODELS
         if model_name in self.SUPPORTED_MODELS:
-            # Check if model is allowed by restrictions
-            from utils.model_restrictions import get_restriction_service
-
-            restriction_service = get_restriction_service()
-            if not restriction_service.is_allowed(ProviderType.OPENAI, model_name, model_name):
-                raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
+            self._check_model_restrictions(model_name, model_name)
             return self.SUPPORTED_MODELS[model_name]
 
         # Try resolving as alias
@@ -191,23 +186,13 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
 
         # Check if resolved name is a key
         if resolved_name in self.SUPPORTED_MODELS:
-            # Check if model is allowed by restrictions
-            from utils.model_restrictions import get_restriction_service
-
-            restriction_service = get_restriction_service()
-            if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
-                raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
+            self._check_model_restrictions(resolved_name, model_name)
             return self.SUPPORTED_MODELS[resolved_name]
 
         # Finally check if resolved name matches any API model name
         for key, capabilities in self.SUPPORTED_MODELS.items():
             if resolved_name == capabilities.model_name:
-                # Check if model is allowed by restrictions
-                from utils.model_restrictions import get_restriction_service
-
-                restriction_service = get_restriction_service()
-                if not restriction_service.is_allowed(ProviderType.OPENAI, key, model_name):
-                    raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
+                self._check_model_restrictions(key, model_name)
                 return capabilities
 
         # Check custom models registry for user-configured OpenAI models
@@ -218,12 +203,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
             config = registry.get_model_config(resolved_name)
 
             if config and config.provider == ProviderType.OPENAI:
-                # Check if model is allowed by restrictions
-                from utils.model_restrictions import get_restriction_service
-
-                restriction_service = get_restriction_service()
-                if not restriction_service.is_allowed(ProviderType.OPENAI, config.model_name, model_name):
-                    raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
+                self._check_model_restrictions(config.model_name, model_name)
 
                 # Update provider type to ensure consistency
                 config.provider = ProviderType.OPENAI
@@ -231,13 +211,26 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
 
         except Exception as e:
             # Log but don't fail - registry might not be available
-            import logging
-
-            logger = logging.getLogger(__name__)
             logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}")
 
         raise ValueError(f"Unsupported OpenAI model: {model_name}")
 
+    def _check_model_restrictions(self, provider_model_name: str, user_model_name: str) -> None:
+        """Check if a model is allowed by restriction policy.
+
+        Args:
+            provider_model_name: The model name used by the provider
+            user_model_name: The model name requested by the user
+
+        Raises:
+            ValueError: If the model is not allowed by restriction policy
+        """
+        from utils.model_restrictions import get_restriction_service
+
+        restriction_service = get_restriction_service()
+        if not restriction_service.is_allowed(ProviderType.OPENAI, provider_model_name, user_model_name):
+            raise ValueError(f"OpenAI model '{user_model_name}' is not allowed by restriction policy.")
+
     def get_provider_type(self) -> ProviderType:
         """Get the provider type."""
         return ProviderType.OPENAI
@@ -246,39 +239,41 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
         """Validate if the model name is supported and allowed."""
         resolved_name = self._resolve_model_name(model_name)
 
-        # First check if model is in built-in SUPPORTED_MODELS
+        # First, determine which model name to check against restrictions.
+        model_to_check = None
+        is_custom_model = False
+
         if resolved_name in self.SUPPORTED_MODELS:
-            # Check if model is allowed by restrictions
-            from utils.model_restrictions import get_restriction_service
+            model_to_check = resolved_name
+        else:
+            # If not a built-in model, check the custom models registry.
+            try:
+                from .openrouter_registry import OpenRouterModelRegistry
 
-            restriction_service = get_restriction_service()
-            if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
-                logger.debug(f"OpenAI model '{model_name}' -> '{resolved_name}' blocked by restrictions")
-                return False
-            return True
+                registry = OpenRouterModelRegistry()
+                config = registry.get_model_config(resolved_name)
 
-        # Check custom models registry for user-configured OpenAI models
-        try:
-            from .openrouter_registry import OpenRouterModelRegistry
+                if config and config.provider == ProviderType.OPENAI:
+                    model_to_check = config.model_name
+                    is_custom_model = True
+            except Exception as e:
+                # Log but don't fail - registry might not be available.
+                logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}")
 
-            registry = OpenRouterModelRegistry()
-            config = registry.get_model_config(resolved_name)
+        # If no model was found (neither built-in nor custom), it's invalid.
+        if not model_to_check:
+            return False
 
-            if config and config.provider == ProviderType.OPENAI:
-                # Check if model is allowed by restrictions
-                from utils.model_restrictions import get_restriction_service
+        # Now, perform the restriction check once.
+        from utils.model_restrictions import get_restriction_service
 
-                restriction_service = get_restriction_service()
-                if not restriction_service.is_allowed(ProviderType.OPENAI, config.model_name, model_name):
-                    logger.debug(f"OpenAI custom model '{model_name}' -> '{resolved_name}' blocked by restrictions")
-                    return False
-                return True
+        restriction_service = get_restriction_service()
+        if not restriction_service.is_allowed(ProviderType.OPENAI, model_to_check, model_name):
+            model_type = "custom " if is_custom_model else ""
+            logger.debug(f"OpenAI {model_type}model '{model_name}' -> '{resolved_name}' blocked by restrictions")
+            return False
 
-        except Exception as e:
-            # Log but don't fail - registry might not be available
-            logger.debug(f"Could not check custom models registry for '{resolved_name}': {e}")
-
-        return False
+        return True
 
     def generate_content(
         self,
diff --git a/tests/test_issue_245_simple.py b/tests/test_issue_245_simple.py
index ff2928a..bd58ce8 100644
--- a/tests/test_issue_245_simple.py
+++ b/tests/test_issue_245_simple.py
@@ -75,9 +75,3 @@ def test_issue_245_custom_openai_temperature_ignored():
     # Verify the fix: NO temperature should be sent to the API
     call_kwargs = mock_client.chat.completions.create.call_args[1]
     assert "temperature" not in call_kwargs, "Fix failed: temperature still being sent!"
-
-    print("✅ Issue #245 is FIXED! Temperature parameter correctly ignored for custom models.")
-
-
-if __name__ == "__main__":
-    test_issue_245_custom_openai_temperature_ignored()
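
Since the test module no longer has a __main__ block, run it under a
test runner rather than as a script; a typical invocation (assumed, not
taken from this repo's tooling) is:

    pytest tests/test_issue_245_simple.py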