266 lines
7.5 KiB
Python
266 lines
7.5 KiB
Python
"""OpenAI-compatible API routes for Cursor Composer and other consumers."""
|
|
|
|
import asyncio
|
|
import json
|
|
import uuid
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from typing import Any
|
|
|
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request
|
|
from starlette.responses import StreamingResponse
|
|
|
|
from fusionagi.api.dependencies import (
|
|
ensure_initialized,
|
|
get_event_bus,
|
|
get_orchestrator,
|
|
get_safety_pipeline,
|
|
get_openai_bridge_config,
|
|
verify_openai_bridge_auth,
|
|
)
|
|
from fusionagi.api.openai_compat.translators import (
|
|
messages_to_prompt,
|
|
final_response_to_openai,
|
|
estimate_usage,
|
|
)
|
|
from fusionagi.core import run_dvadasa
|
|
from fusionagi.schemas.commands import parse_user_input
|
|
|
|
# Maximum number of final_answer characters carried by one streamed SSE delta.
_STREAM_CHUNK_SIZE: int = 50

# Router that the OpenAI-compatible endpoints below are registered on.
router = APIRouter(tags=["openai-compat"])
|
|
|
|
|
|
def _openai_error(status_code: int, message: str, error_type: str) -> HTTPException:
    """Build (but do not raise) an HTTPException carrying an OpenAI-style error body.

    Callers are expected to ``raise`` the returned exception themselves.

    Args:
        status_code: HTTP status code for the error response.
        message: Human-readable error description.
        error_type: OpenAI error ``type`` string, e.g. ``"invalid_request_error"``
            or ``"internal_error"``.

    Returns:
        An ``HTTPException`` whose ``detail`` is
        ``{"error": {"message": message, "type": error_type}}``.
    """
    # NOTE(review): FastAPI's default HTTPException handler serializes this as
    # {"detail": {"error": {...}}}. Unless a custom exception handler elsewhere
    # unwraps ``detail``, clients see the OpenAI error nested under "detail".
    # Confirm that this is intended.
    return HTTPException(
        status_code=status_code,
        detail={"error": {"message": message, "type": error_type}},
    )
|
|
|
|
|
|
def _ensure_openai_init() -> None:
    """Make sure the shared orchestrator and its dependencies are set up.

    Local alias for ``ensure_initialized`` so the OpenAI-compat routes have a
    single initialization hook.
    """
    ensure_initialized()
    return None
|
|
|
|
|
|
async def _verify_auth_dep(authorization: str | None = Header(default=None)) -> None:
    """FastAPI dependency that guards the OpenAI bridge routes.

    FastAPI fills ``authorization`` from the ``Authorization`` request header,
    or ``None`` when the header is absent. The value is passed unchanged to
    ``verify_openai_bridge_auth``, and any rejection is whatever that helper
    raises.
    """
    header_value = authorization
    verify_openai_bridge_auth(header_value)
|
|
|
|
|
|
@router.get("/models", dependencies=[Depends(_verify_auth_dep)])
async def list_models() -> dict[str, Any]:
    """
    List available models in the OpenAI ``GET /models`` response shape.

    Exactly one model is advertised. Its id comes from the bridge config's
    ``model_id``, and its ``created`` field is a fixed Unix timestamp.
    """
    model_id = get_openai_bridge_config().model_id
    model_card: dict[str, Any] = dict(
        id=model_id,
        object="model",
        created=1704067200,  # fixed value: 2024-01-01T00:00:00Z
        owned_by="fusionagi",
    )
    return {"object": "list", "data": [model_card]}
|
|
|
|
|
|
@router.post(
    "/chat/completions",
    dependencies=[Depends(_verify_auth_dep)],
    response_model=None,
)
async def create_chat_completion(request: Request):
    """
    Create a chat completion (OpenAI-compatible ``POST /chat/completions``).

    Two modes are supported:
    - Sync: used when ``"stream"`` is absent or not exactly ``true``.
    - Streaming: used when ``"stream": true``.

    Steps:
    1. Parse and validate the JSON body.
    2. Flatten ``messages`` into a single prompt.
    3. Apply the optional safety pre-check.
    4. Submit an orchestrator task.
    5. Either return a ``StreamingResponse`` of SSE chunks, or run Dvādaśa and
       return the translated completion object.

    Raises:
        HTTPException: built by ``_openai_error``. The statuses are:
            400 for an unparsable or non-object body, missing or content-free
            ``messages``, or a failed input/output safety check;
            503 when no orchestrator is available;
            500 when Dvādaśa raises or returns no response.
    """
    _ensure_openai_init()

    try:
        body = await request.json()
    except Exception as e:
        raise _openai_error(
            400, f"Invalid JSON body: {e}", "invalid_request_error"
        ) from e

    # Valid JSON that is not an object (e.g. a list or a number) would otherwise
    # crash on body.get() and surface as an unstructured 500.
    if not isinstance(body, dict):
        raise _openai_error(
            400, "Request body must be a JSON object", "invalid_request_error"
        )

    messages = body.get("messages")
    if not messages or not isinstance(messages, list):
        raise _openai_error(
            400,
            "messages is required and must be a non-empty array",
            "invalid_request_error",
        )

    from fusionagi.api.openai_compat.translators import _extract_content

    no_content_msg = (
        "messages must contain at least one user or assistant message with content"
    )
    if not any(_extract_content(m).strip() for m in messages):
        raise _openai_error(400, no_content_msg, "invalid_request_error")

    prompt = messages_to_prompt(messages)
    # Re-check after translation in case the flattened prompt is still blank.
    if not prompt.strip():
        raise _openai_error(400, no_content_msg, "invalid_request_error")

    pipeline = get_safety_pipeline()
    if pipeline:
        pre_result = pipeline.pre_check(prompt)
        if not pre_result.allowed:
            raise _openai_error(
                400,
                pre_result.reason or "Input moderation failed",
                "invalid_request_error",
            )

    orch = get_orchestrator()
    bus = get_event_bus()
    if not orch:
        raise _openai_error(503, "Service not initialized", "internal_error")

    cfg = get_openai_bridge_config()
    request_model = body.get("model") or cfg.model_id
    stream = body.get("stream", False) is True

    # The task goal is truncated to the first 200 characters of the prompt.
    task_id = orch.submit_task(goal=prompt[:200])
    parsed = parse_user_input(prompt)

    if stream:
        return StreamingResponse(
            _stream_chat_completion(
                orch=orch,
                bus=bus,
                task_id=task_id,
                prompt=prompt,
                parsed=parsed,
                request_model=request_model,
                messages=messages,
                pipeline=pipeline,
                cfg=cfg,
            ),
            media_type="text/event-stream",
        )

    # Sync path. run_dvadasa is blocking, so run it in a worker thread (as the
    # streaming path does) so that one slow completion does not stall every
    # other request on this event loop.
    try:
        final = await asyncio.to_thread(
            run_dvadasa,
            orchestrator=orch,
            task_id=task_id,
            user_prompt=prompt,
            parsed=parsed,
            event_bus=bus,
            timeout_per_head=cfg.timeout_per_head,
        )
    except Exception as e:
        # Report the failure as an OpenAI-style error, consistent with the
        # streaming path, instead of an unstructured 500.
        raise _openai_error(500, f"Dvādaśa failed: {e}", "internal_error") from e

    if not final:
        raise _openai_error(500, "Dvādaśa failed to produce response", "internal_error")

    if pipeline:
        post_result = pipeline.post_check(final.final_answer)
        if not post_result.passed:
            raise _openai_error(
                400,
                f"Output scan failed: {', '.join(post_result.flags)}",
                "invalid_request_error",
            )

    return final_response_to_openai(
        final=final,
        task_id=task_id,
        request_model=request_model,
        messages=messages,
    )
|
|
|
|
|
|
async def _stream_chat_completion(
    orch: Any,
    bus: Any,
    task_id: str,
    prompt: str,
    parsed: Any,
    request_model: str,
    messages: list[dict[str, Any]],
    pipeline: Any,
    cfg: Any,
):
    """
    Async generator that runs Dvādaśa and streams ``final_answer`` as SSE chunks.

    Event sequence on success:
      1. A role chunk: ``delta={"role": "assistant", "content": ""}``.
      2. One ``chat.completion.chunk`` per ``_STREAM_CHUNK_SIZE`` characters
         of the answer.
      3. A terminal chunk with ``finish_reason="stop"`` and estimated usage.
      4. ``data: [DONE]``.

    On failure a single ``data: {"error": {...}}`` event is emitted and the
    stream ends without ``[DONE]``. This happens when Dvādaśa raises, returns
    nothing, or the output scan rejects the answer. Errors are reported in-band
    because the response status has already been sent once the body is
    streaming.

    Args:
        orch, bus, task_id, prompt, parsed: forwarded to ``run_dvadasa``.
        request_model: model name echoed in every chunk.
        messages: original request messages, used for usage estimation.
        pipeline: optional safety pipeline; its ``post_check`` runs on the answer.
        cfg: bridge config; ``cfg.timeout_per_head`` is passed to Dvādaśa.
    """
    import time

    def _sse(payload: dict[str, Any]) -> str:
        # Format one Server-Sent Events "data:" frame.
        return f"data: {json.dumps(payload)}\n\n"

    def _sse_error(message: str, error_type: str) -> str:
        return _sse({"error": {"message": message, "type": error_type}})

    try:
        # run_dvadasa blocks, so run it on the loop's default thread pool
        # rather than creating a dedicated executor for every request.
        final = await asyncio.to_thread(
            run_dvadasa,
            orchestrator=orch,
            task_id=task_id,
            user_prompt=prompt,
            parsed=parsed,
            event_bus=bus,
            timeout_per_head=cfg.timeout_per_head,
        )
    except Exception as e:
        yield _sse_error(str(e), "internal_error")
        return

    if not final:
        yield _sse_error("Dvādaśa failed", "internal_error")
        return

    if pipeline:
        post_result = pipeline.post_check(final.final_answer)
        if not post_result.passed:
            # Report the same detail (the flags) as the non-streaming path.
            yield _sse_error(
                f"Output scan failed: {', '.join(post_result.flags)}",
                "invalid_request_error",
            )
            return

    # Slicing already handles ids shorter than 24 characters.
    chat_id = f"chatcmpl-{task_id[:24]}"
    # OpenAI chunks carry a Unix timestamp, shared by every chunk of one stream.
    created = int(time.time())

    def _chunk(delta: dict[str, Any], finish_reason: str | None) -> dict[str, Any]:
        return {
            "id": chat_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": request_model,
            "choices": [
                {"index": 0, "delta": delta, "finish_reason": finish_reason}
            ],
        }

    # Announce the assistant role first, as OpenAI streams do.
    yield _sse(_chunk({"role": "assistant", "content": ""}, None))

    text = final.final_answer or ""
    for start in range(0, len(text), _STREAM_CHUNK_SIZE):
        piece = text[start : start + _STREAM_CHUNK_SIZE]
        yield _sse(_chunk({"content": piece}, None))

    # Final chunk: an empty delta, finish_reason "stop", and estimated usage.
    final_chunk = _chunk({}, "stop")
    final_chunk["usage"] = estimate_usage(messages, text)
    yield _sse(final_chunk)
    yield "data: [DONE]\n\n"
|