Files
explorer-monorepo/backend/api/rest/ai_runtime.go
2026-03-27 14:12:14 -07:00

293 lines
7.7 KiB
Go

package rest
import (
	"encoding/json"
	"errors"
	"log"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"
)
// AIRateLimiter is an in-memory sliding-window rate limiter. For every key it
// remembers when recent requests were admitted and rejects new ones once limit
// of them fall inside the window. All state is guarded by mu, and the zero
// value is ready to use.
type AIRateLimiter struct {
	mu      sync.Mutex
	entries map[string][]time.Time // admission timestamps per key, in admission order

	// maxWindow is the longest window Allow has been called with so far. A key
	// is purged only when all of its timestamps are older than this, so a purge
	// never drops a timestamp that any window seen so far would still count.
	maxWindow time.Duration

	// nextSweep is the earliest time at which Allow purges idle keys again.
	nextSweep time.Time
}

// NewAIRateLimiter returns an empty limiter.
func NewAIRateLimiter() *AIRateLimiter {
	return &AIRateLimiter{
		entries: make(map[string][]time.Time),
	}
}

// Allow reports whether one more request for key fits within limit requests
// per window, and records the request when it does.
//
// It returns (true, 0) when the request is admitted. When the request is
// rejected it returns false together with the time remaining until the oldest
// counted request leaves the window (never negative). A limit <= 0 disables
// limiting: the request is admitted and nothing is recorded.
//
// Keys that have had no request within the longest window seen so far are
// removed periodically, so the map does not grow with every distinct key ever
// observed. A later call that uses a window longer than any previous one will
// therefore not see timestamps that were already purged.
func (l *AIRateLimiter) Allow(key string, limit int, window time.Duration) (bool, time.Duration) {
	if limit <= 0 {
		return true, 0
	}
	now := time.Now()
	cutoff := now.Add(-window)

	l.mu.Lock()
	defer l.mu.Unlock()

	if l.entries == nil {
		// Zero-value limiter: writing to a nil map would panic.
		l.entries = make(map[string][]time.Time)
	}
	if window > l.maxWindow {
		l.maxWindow = window
	}
	if !now.Before(l.nextSweep) {
		l.purgeIdle(now)
	}

	// Filter in place: kept reuses the backing array of the stored slice.
	timestamps := l.entries[key]
	kept := timestamps[:0]
	for _, ts := range timestamps {
		if ts.After(cutoff) {
			kept = append(kept, ts)
		}
	}

	if len(kept) >= limit {
		l.entries[key] = kept
		retryAfter := kept[0].Add(window).Sub(now)
		if retryAfter < 0 {
			retryAfter = 0
		}
		return false, retryAfter
	}

	l.entries[key] = append(kept, now)
	return true, 0
}

// purgeIdle deletes every key that has no timestamp newer than now-maxWindow
// and schedules the next purge. The caller must hold l.mu.
func (l *AIRateLimiter) purgeIdle(now time.Time) {
	// Lower bound on the purge interval so that tiny or non-positive windows do
	// not trigger a full scan on every call.
	const minInterval = time.Second

	horizon := now.Add(-l.maxWindow)
	for key, timestamps := range l.entries {
		live := false
		for _, ts := range timestamps {
			if ts.After(horizon) {
				live = true
				break
			}
		}
		if !live {
			delete(l.entries, key)
		}
	}

	interval := l.maxWindow
	if interval < minInterval {
		interval = minInterval
	}
	l.nextSweep = now.Add(interval)
}
// AIMetrics accumulates in-memory counters and a short history of recent AI
// endpoint requests. All fields are guarded by mu; use Record to update them
// and Snapshot to read them.
type AIMetrics struct {
	mu               sync.Mutex
	ContextRequests  int64               `json:"contextRequests"`
	ChatRequests     int64               `json:"chatRequests"`
	RateLimited      int64               `json:"rateLimited"`
	UpstreamFailures int64               `json:"upstreamFailures"`
	LastRequestAt    string              `json:"lastRequestAt,omitempty"`
	LastErrorCode    string              `json:"lastErrorCode,omitempty"`
	StatusCounts     map[string]int64    `json:"statusCounts"`
	ErrorCounts      map[string]int64    `json:"errorCounts"`
	LastDurationsMs  map[string]float64  `json:"lastDurationsMs"`
	LastRequests     []map[string]string `json:"lastRequests"`
}

// NewAIMetrics returns an AIMetrics with all maps and the request history
// initialized and empty.
func NewAIMetrics() *AIMetrics {
	return &AIMetrics{
		StatusCounts:    make(map[string]int64),
		ErrorCounts:     make(map[string]int64),
		LastDurationsMs: make(map[string]float64),
		LastRequests:    []map[string]string{},
	}
}

// Record registers one finished request.
//
// endpoint selects the per-endpoint counter ("context" or "chat"; other values
// only affect the keyed maps). statusCode is the HTTP status that was sent,
// duration the handling time, errorCode an optional error code ("" for none)
// and clientIP the caller's address as stored in the request history.
//
// An errorCode of "rate_limited" increments RateLimited, and any errorCode
// starting with "upstream_" increments UpstreamFailures. The history keeps the
// 12 most recent requests, newest first.
func (m *AIMetrics) Record(endpoint string, statusCode int, duration time.Duration, errorCode, clientIP string) {
	const maxRecent = 12

	statusText := http.StatusText(statusCode)

	m.mu.Lock()
	defer m.mu.Unlock()

	// Tolerate an AIMetrics that was not built by NewAIMetrics: writing to a
	// nil map would panic.
	if m.StatusCounts == nil {
		m.StatusCounts = make(map[string]int64)
	}
	if m.ErrorCounts == nil {
		m.ErrorCounts = make(map[string]int64)
	}
	if m.LastDurationsMs == nil {
		m.LastDurationsMs = make(map[string]float64)
	}

	switch endpoint {
	case "context":
		m.ContextRequests++
	case "chat":
		m.ChatRequests++
	}
	if errorCode == "rate_limited" {
		m.RateLimited++
	}
	if strings.HasPrefix(errorCode, "upstream_") {
		m.UpstreamFailures++
	}

	m.StatusCounts[endpoint+":"+statusText]++
	if errorCode != "" {
		m.ErrorCounts[errorCode]++
		m.LastErrorCode = errorCode
	}
	m.LastRequestAt = time.Now().UTC().Format(time.RFC3339)
	// Whole milliseconds only: sub-millisecond durations are recorded as 0.
	m.LastDurationsMs[endpoint] = float64(duration.Milliseconds())

	entry := map[string]string{
		"endpoint": endpoint,
		"status":   statusText,
		// Numeric code, e.g. "200"; "status" carries the reason phrase.
		"statusCode": strconv.Itoa(statusCode),
		"clientIp":   clientIP,
		"at":         m.LastRequestAt,
		"errorCode":  errorCode,
	}
	recent := append([]map[string]string{entry}, m.LastRequests...)
	if len(recent) > maxRecent {
		recent = recent[:maxRecent]
	}
	m.LastRequests = recent
}

// Snapshot returns a deep copy of the current metrics keyed by the same names
// as the struct's JSON tags. The returned maps and slices share nothing with
// the receiver, so callers may read or modify them without holding the lock.
func (m *AIMetrics) Snapshot() map[string]any {
	m.mu.Lock()
	defer m.mu.Unlock()

	copyCounts := func(src map[string]int64) map[string]int64 {
		dst := make(map[string]int64, len(src))
		for name, count := range src {
			dst[name] = count
		}
		return dst
	}

	durations := make(map[string]float64, len(m.LastDurationsMs))
	for endpoint, ms := range m.LastDurationsMs {
		durations[endpoint] = ms
	}

	// Always a non-nil slice so that it encodes as [] rather than null.
	recent := make([]map[string]string, 0, len(m.LastRequests))
	for _, request := range m.LastRequests {
		clone := make(map[string]string, len(request))
		for field, value := range request {
			clone[field] = value
		}
		recent = append(recent, clone)
	}

	return map[string]any{
		"contextRequests":  m.ContextRequests,
		"chatRequests":     m.ChatRequests,
		"rateLimited":      m.RateLimited,
		"upstreamFailures": m.UpstreamFailures,
		"lastRequestAt":    m.LastRequestAt,
		"lastErrorCode":    m.LastErrorCode,
		"statusCounts":     copyCounts(m.StatusCounts),
		"errorCounts":      copyCounts(m.ErrorCounts),
		"lastDurationsMs":  durations,
		"lastRequests":     recent,
	}
}
// clientIPAddress returns the best available client address for r.
//
// It prefers the first non-empty entry of X-Forwarded-For, then X-Real-IP,
// then the host part of r.RemoteAddr. If RemoteAddr cannot be split into
// host and port it is returned as is (trimmed). Header values are returned
// verbatim after trimming; they are not validated as IP addresses.
//
// NOTE(review): both headers are supplied by the client unless a fronting
// proxy overwrites them, so this value is only as trustworthy as that proxy.
func clientIPAddress(r *http.Request) string {
	// X-Forwarded-For is a comma-separated hop list. Skip empty elements such
	// as the leading one in ", 10.0.0.1" instead of returning "" for them.
	for _, hop := range strings.Split(r.Header.Get("X-Forwarded-For"), ",") {
		if hop = strings.TrimSpace(hop); hop != "" {
			return hop
		}
	}

	if realIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); realIP != "" {
		return realIP
	}

	remote := strings.TrimSpace(r.RemoteAddr)
	if host, _, err := net.SplitHostPort(remote); err == nil && host != "" {
		return host
	}
	return remote
}
// explorerAIContextRateLimit returns the rate-limit policy for the "context"
// endpoint: the number of requests permitted and the window they are counted
// over.
func explorerAIContextRateLimit() (int, time.Duration) {
	const requestsPerWindow = 60
	return requestsPerWindow, time.Minute
}
// explorerAIChatRateLimit returns the rate-limit policy for the "chat"
// endpoint: the number of requests permitted and the window they are counted
// over.
func explorerAIChatRateLimit() (int, time.Duration) {
	const requestsPerWindow = 12
	return requestsPerWindow, time.Minute
}
// allowAIRequest applies the per-client rate limit of the given AI endpoint to
// r. The limiter key is "<endpoint>:<client address>", so each client is
// counted separately per endpoint. It returns the limiter's verdict and
// retry-after duration unchanged.
func (s *Server) allowAIRequest(r *http.Request, endpoint string) (bool, time.Duration) {
	// Endpoints other than "context" and "chat" keep a zero request budget.
	// NOTE(review): presumably s.aiLimiter is an *AIRateLimiter, whose Allow
	// admits every request when the limit is <= 0 — confirm.
	var (
		maxRequests int
		per         = time.Minute
	)
	if endpoint == "context" {
		maxRequests, per = explorerAIContextRateLimit()
	} else if endpoint == "chat" {
		maxRequests, per = explorerAIChatRateLimit()
	}

	bucket := endpoint + ":" + clientIPAddress(r)
	return s.aiLimiter.Allow(bucket, maxRequests, per)
}
// logAIRequest writes one key=value line describing a finished AI request to
// the standard logger: endpoint, numeric HTTP status, duration in whole
// milliseconds, client address, model and error code.
func (s *Server) logAIRequest(endpoint string, statusCode int, duration time.Duration, clientIP, model, errorCode string) {
	log.Printf(
		"AI endpoint=%s status=%d duration_ms=%d client_ip=%s model=%s error_code=%s",
		endpoint, statusCode, duration.Milliseconds(), clientIP, model, errorCode,
	)
}
// handleAIMetrics serves the AI runtime metrics as JSON. Only GET is accepted;
// any other method is answered through writeMethodNotAllowed. The response
// contains the generation time (UTC, RFC 3339), the configured rate limits of
// the "context" and "chat" endpoints, and a snapshot of s.aiMetrics.
func (s *Server) handleAIMetrics(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeMethodNotAllowed(w)
		return
	}

	// describe renders one rate-limit policy in the shape used by the response.
	describe := func(requests int, window time.Duration) map[string]any {
		return map[string]any{
			"requests": requests,
			"window":   window.String(),
		}
	}

	ctxRequests, ctxWindow := explorerAIContextRateLimit()
	chatRequests, chatWindow := explorerAIChatRateLimit()

	writeJSON(w, http.StatusOK, map[string]any{
		"generatedAt": time.Now().UTC().Format(time.RFC3339),
		"rateLimits": map[string]any{
			"context": describe(ctxRequests, ctxWindow),
			"chat":    describe(chatRequests, chatWindow),
		},
		"metrics": s.aiMetrics.Snapshot(),
	})
}
// writeJSON sends payload as a JSON response with the given status code and a
// Content-Type of application/json.
//
// The payload is encoded before anything is written, so a payload that cannot
// be encoded results in a plain 500 response (and a log line) instead of the
// requested status with an empty body. Failures while writing the body, for
// example because the client went away, are logged.
func writeJSON(w http.ResponseWriter, statusCode int, payload any) {
	body, err := json.Marshal(payload)
	if err != nil {
		log.Printf("writeJSON: encoding %T payload: %v", payload, err)
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	// json.Encoder.Encode terminates every value with a newline; keep emitting
	// it so the response bytes stay the same as before.
	if _, err := w.Write(append(body, '\n')); err != nil {
		log.Printf("writeJSON: writing response body: %v", err)
	}
}
// AIUpstreamError describes a failed call to the upstream AI service.
type AIUpstreamError struct {
	StatusCode int    // status code associated with the failure
	Code       string // machine-readable code; see mapAIUpstreamError for the recognized values
	Message    string // short human-readable summary
	Details    string // optional additional detail
}

// Error implements the error interface. It returns Message, followed by
// ": " and Details when Details is set. A nil receiver yields "".
func (e *AIUpstreamError) Error() string {
	if e == nil {
		return ""
	}
	if e.Details != "" {
		return e.Message + ": " + e.Details
	}
	return e.Message
}

// mapAIUpstreamError translates err into the values of an error response:
// HTTP status code, error code, message and details.
//
// A nil err maps to http.StatusOK with empty strings. An *AIUpstreamError
// anywhere in err's chain is mapped by its Code; recognized codes get a fixed
// status and message, any other code gets http.StatusBadGateway with the
// error's own Message. Every other error maps to http.StatusBadGateway with
// code "bad_gateway" and err.Error() as details.
func mapAIUpstreamError(err error) (int, string, string, string) {
	if err == nil {
		return http.StatusOK, "", "", ""
	}

	// errors.As also finds an upstream error that a caller wrapped with extra
	// context. It reports true for a typed-nil *AIUpstreamError as well, hence
	// the explicit nil check before dereferencing.
	var upstreamErr *AIUpstreamError
	if !errors.As(err, &upstreamErr) || upstreamErr == nil {
		return http.StatusBadGateway, "bad_gateway", "explorer ai request failed", err.Error()
	}

	switch upstreamErr.Code {
	case "upstream_quota_exhausted":
		return http.StatusServiceUnavailable, upstreamErr.Code, "explorer ai upstream quota exhausted", upstreamErr.Details
	case "upstream_auth_failed":
		return http.StatusBadGateway, upstreamErr.Code, "explorer ai upstream authentication failed", upstreamErr.Details
	case "upstream_timeout":
		return http.StatusGatewayTimeout, upstreamErr.Code, "explorer ai upstream timed out", upstreamErr.Details
	case "upstream_bad_response":
		return http.StatusBadGateway, upstreamErr.Code, "explorer ai upstream returned an invalid response", upstreamErr.Details
	default:
		return http.StatusBadGateway, upstreamErr.Code, upstreamErr.Message, upstreamErr.Details
	}
}
// writeErrorDetailed sends a JSON error response with the given HTTP status.
// The body is an ErrorResponse whose Error field carries code, message and
// details.
func writeErrorDetailed(w http.ResponseWriter, statusCode int, code, message, details string) {
	body := ErrorResponse{
		Error: ErrorDetail{
			Code:    code,
			Message: message,
			Details: details,
		},
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	// Best effort: the status line is already sent, so an encode or write
	// failure here cannot be reported to the client any more.
	_ = json.NewEncoder(w).Encode(body)
}