refactor: improved retry logic and moved core logic to base class
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
@@ -405,55 +404,42 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
# DIAL-specific: Get cached client for deployment endpoint
|
||||
deployment_client = self._get_deployment_client(resolved_model)
|
||||
|
||||
# Retry logic with progressive delays
|
||||
last_exception = None
|
||||
attempt_counter = {"value": 0}
|
||||
|
||||
for attempt in range(self.MAX_RETRIES):
|
||||
try:
|
||||
# Generate completion using deployment-specific client
|
||||
response = deployment_client.chat.completions.create(**completion_params)
|
||||
def _attempt() -> ModelResponse:
|
||||
attempt_counter["value"] += 1
|
||||
response = deployment_client.chat.completions.create(**completion_params)
|
||||
|
||||
# Extract content and usage
|
||||
content = response.choices[0].message.content
|
||||
usage = self._extract_usage(response)
|
||||
content = response.choices[0].message.content
|
||||
usage = self._extract_usage(response)
|
||||
|
||||
return ModelResponse(
|
||||
content=content,
|
||||
usage=usage,
|
||||
model_name=model_name,
|
||||
friendly_name=self.FRIENDLY_NAME,
|
||||
provider=self.get_provider_type(),
|
||||
metadata={
|
||||
"finish_reason": response.choices[0].finish_reason,
|
||||
"model": response.model,
|
||||
"id": response.id,
|
||||
"created": response.created,
|
||||
},
|
||||
)
|
||||
return ModelResponse(
|
||||
content=content,
|
||||
usage=usage,
|
||||
model_name=model_name,
|
||||
friendly_name=self.FRIENDLY_NAME,
|
||||
provider=self.get_provider_type(),
|
||||
metadata={
|
||||
"finish_reason": response.choices[0].finish_reason,
|
||||
"model": response.model,
|
||||
"id": response.id,
|
||||
"created": response.created,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
try:
|
||||
return self._run_with_retries(
|
||||
operation=_attempt,
|
||||
max_attempts=self.MAX_RETRIES,
|
||||
delays=self.RETRY_DELAYS,
|
||||
log_prefix=f"DIAL API ({resolved_model})",
|
||||
)
|
||||
except Exception as exc:
|
||||
attempts = max(attempt_counter["value"], 1)
|
||||
if attempts == 1:
|
||||
raise ValueError(f"DIAL API error for model {model_name}: {exc}") from exc
|
||||
|
||||
# Check if this is a retryable error
|
||||
is_retryable = self._is_error_retryable(e)
|
||||
|
||||
if not is_retryable:
|
||||
# Non-retryable error, raise immediately
|
||||
raise ValueError(f"DIAL API error for model {model_name}: {str(e)}")
|
||||
|
||||
# If this isn't the last attempt and error is retryable, wait and retry
|
||||
if attempt < self.MAX_RETRIES - 1:
|
||||
delay = self.RETRY_DELAYS[attempt]
|
||||
logger.info(
|
||||
f"DIAL API error (attempt {attempt + 1}/{self.MAX_RETRIES}), " f"retrying in {delay}s: {str(e)}"
|
||||
)
|
||||
time.sleep(delay)
|
||||
continue
|
||||
|
||||
# All retries exhausted
|
||||
raise ValueError(
|
||||
f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
|
||||
)
|
||||
raise ValueError(f"DIAL API error for model {model_name} after {attempts} attempts: {exc}") from exc
|
||||
|
||||
def close(self) -> None:
|
||||
"""Clean up HTTP clients when provider is closed."""
|
||||
|
||||
Reference in New Issue
Block a user