262 lines
9.4 KiB
Python
262 lines
9.4 KiB
Python
"""OpenAI LLM adapter with error handling and retry logic."""
|
|
|
|
import time
|
|
from typing import Any
|
|
|
|
from fusionagi.adapters.base import LLMAdapter
|
|
from fusionagi._logger import logger
|
|
|
|
|
|
class OpenAIAdapterError(Exception):
    """Base exception for OpenAI adapter errors."""
|
|
|
|
|
|
class OpenAIRateLimitError(OpenAIAdapterError):
    """Raised when rate limited by OpenAI API."""
|
|
|
|
|
|
class OpenAIAuthenticationError(OpenAIAdapterError):
    """Raised when authentication fails."""
|
|
|
|
|
|
class OpenAIAdapter(LLMAdapter):
    """
    OpenAI API adapter with retry logic and error handling.

    Requires the ``openai`` package and an API key (explicit, or via the
    OPENAI_API_KEY environment variable).

    Features:
    - Automatic retry with exponential backoff for transient errors
    - Proper error classification (rate limits, auth errors, etc.)
    - Structured output support via complete_structured()
    """

    def __init__(
        self,
        model: str = "gpt-4o-mini",
        api_key: str | None = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_multiplier: float = 2.0,
        max_retry_delay: float = 30.0,
        **client_kwargs: Any,
    ) -> None:
        """
        Initialize the OpenAI adapter.

        Args:
            model: Default model to use (e.g., "gpt-4o-mini", "gpt-4o").
            api_key: OpenAI API key. If None, uses OPENAI_API_KEY env var.
            max_retries: Maximum number of retry attempts for transient errors.
                Negative values are clamped to 0.
            retry_delay: Initial delay between retries in seconds.
            retry_multiplier: Multiplier for exponential backoff.
            max_retry_delay: Maximum delay between retries.
            **client_kwargs: Additional arguments passed to OpenAI client.
        """
        self._model = model
        self._api_key = api_key
        # Clamp so the retry loop in complete() always runs at least once;
        # a negative value would otherwise skip the loop entirely and raise
        # a meaningless error built from last_error=None.
        self._max_retries = max(0, max_retries)
        self._retry_delay = retry_delay
        self._retry_multiplier = retry_multiplier
        self._max_retry_delay = max_retry_delay
        self._client_kwargs = client_kwargs
        self._client: Any = None  # lazily constructed in _get_client()
        self._openai_module: Any = None  # cached module for isinstance checks

    def _get_client(self) -> Any:
        """Lazily import openai and build the client on first use.

        Raises:
            ImportError: If the openai package is not installed.
        """
        if self._client is None:
            try:
                import openai
            except ImportError as e:
                raise ImportError("Install with: pip install fusionagi[openai]") from e
            self._openai_module = openai
            self._client = openai.OpenAI(api_key=self._api_key, **self._client_kwargs)
        return self._client

    def _error_matches(self, error: Exception, class_name: str) -> bool:
        """Return True if *error* is an instance of ``openai.<class_name>``.

        Uses getattr with a default so SDK versions that lack a given
        exception class are handled gracefully. Returns False before the
        openai module has been imported.
        """
        if self._openai_module is None:
            return False
        exc_cls = getattr(self._openai_module, class_name, None)
        return exc_cls is not None and isinstance(error, exc_cls)

    def _is_retryable_error(self, error: Exception) -> bool:
        """Check if an error is retryable (transient)."""
        # Rate limits, connection failures, 5xx errors, and timeouts are
        # all transient and worth retrying with backoff.
        return any(
            self._error_matches(error, name)
            for name in (
                "RateLimitError",
                "APIConnectionError",
                "InternalServerError",
                "APITimeoutError",
            )
        )

    def _classify_error(self, error: Exception) -> Exception:
        """Convert OpenAI exceptions to adapter exceptions."""
        if self._error_matches(error, "RateLimitError"):
            return OpenAIRateLimitError(str(error))
        if self._error_matches(error, "AuthenticationError"):
            return OpenAIAuthenticationError(str(error))
        return OpenAIAdapterError(str(error))

    @staticmethod
    def _validate_messages(messages: list[dict[str, str]]) -> None:
        """Raise ValueError if any message is not a dict with role/content."""
        for i, msg in enumerate(messages):
            if not isinstance(msg, dict):
                raise ValueError(f"Message {i} must be a dict, got {type(msg).__name__}")
            if "role" not in msg:
                raise ValueError(f"Message {i} missing 'role' key")
            if "content" not in msg:
                raise ValueError(f"Message {i} missing 'content' key")

    def complete(
        self,
        messages: list[dict[str, str]],
        **kwargs: Any,
    ) -> str:
        """
        Call OpenAI chat completion with retry logic.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            **kwargs: Additional arguments for the API call (e.g., temperature);
                a 'model' key overrides the adapter's default model.

        Returns:
            The assistant's response content, or "" for an empty response.

        Raises:
            ValueError: If a message is malformed.
            OpenAIAuthenticationError: If authentication fails (never retried).
            OpenAIRateLimitError: If rate limited after all retries.
            OpenAIAdapterError: For other API errors after all retries.
        """
        if not messages:
            logger.warning("OpenAI complete called with empty messages")
            return ""

        self._validate_messages(messages)

        client = self._get_client()
        model = kwargs.get("model", self._model)
        call_kwargs = {**kwargs, "model": model}

        last_error: Exception | None = None
        delay = self._retry_delay

        for attempt in range(self._max_retries + 1):
            try:
                resp = client.chat.completions.create(
                    messages=messages,
                    **call_kwargs,
                )
                choice = resp.choices[0] if resp.choices else None
                if choice and choice.message and choice.message.content:
                    return choice.message.content
                # An empty completion is not an error; surface it as "".
                logger.debug("OpenAI empty response", extra={"model": model, "attempt": attempt})
                return ""

            except Exception as e:
                last_error = e

                # Auth failures will never succeed on retry -- fail fast.
                if self._error_matches(e, "AuthenticationError"):
                    logger.error("OpenAI authentication failed", extra={"error": str(e)})
                    raise OpenAIAuthenticationError(str(e)) from e

                if not self._is_retryable_error(e):
                    logger.error(
                        "OpenAI non-retryable error",
                        extra={"error": str(e), "error_type": type(e).__name__},
                    )
                    raise self._classify_error(e) from e

                # Transient error: back off and retry unless out of attempts.
                if attempt < self._max_retries:
                    logger.warning(
                        "OpenAI retryable error, retrying",
                        extra={
                            "error": str(e),
                            "attempt": attempt + 1,
                            "max_retries": self._max_retries,
                            "delay": delay,
                        },
                    )
                    time.sleep(delay)
                    # Exponential backoff, capped at max_retry_delay.
                    delay = min(delay * self._retry_multiplier, self._max_retry_delay)

        # All retries exhausted. last_error is always set here because the
        # loop runs at least once (max_retries is clamped to >= 0 in __init__).
        logger.error(
            "OpenAI all retries exhausted",
            extra={"error": str(last_error), "attempts": self._max_retries + 1},
        )
        raise self._classify_error(last_error) from last_error

    def complete_structured(
        self,
        messages: list[dict[str, str]],
        schema: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Any:
        """
        Call OpenAI with JSON mode for structured output.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            schema: Optional JSON schema for response validation. Informational
                only: it is embedded in the system prompt as a hint, not
                enforced server-side.
            **kwargs: Additional arguments for the API call. A caller-supplied
                'response_format' takes precedence over the default JSON mode.

        Returns:
            Parsed JSON response or None if parsing fails.
        """
        import json

        # Default to JSON mode, but do not clobber an explicit
        # response_format from the caller (the original always overwrote it).
        call_kwargs = {"response_format": {"type": "json_object"}, **kwargs}

        # Embed the schema as a prompt hint; build new message dicts rather
        # than mutating the caller's list in place.
        if schema and messages:
            schema_hint = f"\n\nRespond with JSON matching this schema: {json.dumps(schema)}"
            if messages[0].get("role") == "system":
                messages = [
                    {**messages[0], "content": messages[0]["content"] + schema_hint},
                    *messages[1:],
                ]
            else:
                messages = [
                    {"role": "system", "content": f"You must respond with valid JSON.{schema_hint}"},
                    *messages,
                ]

        raw = self.complete(messages, **call_kwargs)
        if not raw:
            return None

        try:
            return json.loads(raw)
        except json.JSONDecodeError as e:
            logger.warning(
                "OpenAI JSON parse failed",
                extra={"error": str(e), "raw_response": raw[:200]},
            )
            return None
|