317 lines
8.0 KiB
Go
317 lines
8.0 KiB
Go
package session
|
|
|
|
import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/redis/go-redis/v9"
)
|
|
|
|
// Session represents a Virtual Banker session
|
|
type Session struct {
|
|
ID string
|
|
TenantID string
|
|
UserID string
|
|
EphemeralToken string
|
|
Config *TenantConfig
|
|
CreatedAt time.Time
|
|
ExpiresAt time.Time
|
|
LastActivityAt time.Time
|
|
}
|
|
|
|
// TenantConfig holds tenant-specific configuration
|
|
type TenantConfig struct {
|
|
Theme map[string]interface{} `json:"theme"`
|
|
AvatarEnabled bool `json:"avatar_enabled"`
|
|
Greeting string `json:"greeting"`
|
|
AllowedTools []string `json:"allowed_tools"`
|
|
Policy *PolicyConfig `json:"policy"`
|
|
}
|
|
|
|
// PolicyConfig holds per-tenant session policy settings.
//
// MaxSessionDuration bounds a session's lifetime. CreateSession falls back to
// 30 minutes when it is zero. time.Duration is a plain int64 with no custom
// JSON methods, so under encoding/json it is an integer count of nanoseconds.
//
// RateLimitPerMinute and RequireConsent are not read anywhere in this file.
// NOTE(review): confirm where they are enforced.
type PolicyConfig struct {
	MaxSessionDuration time.Duration `json:"max_session_duration"`
	RateLimitPerMinute int           `json:"rate_limit_per_minute"`
	RequireConsent     bool          `json:"require_consent"`
}
|
|
|
|
// Manager creates, looks up, refreshes and ends sessions.
//
// PostgreSQL holds the authoritative session and tenant rows. Redis holds
// only a "session:<id>" marker key per session, not a serialized session.
type Manager struct {
	db    *pgxpool.Pool // sessions and tenants tables
	redis *redis.Client // per-session marker keys
}
|
|
|
|
// NewManager creates a new session manager
|
|
func NewManager(db *pgxpool.Pool, redisClient *redis.Client) *Manager {
|
|
return &Manager{
|
|
db: db,
|
|
redis: redisClient,
|
|
}
|
|
}
|
|
|
|
// CreateSession creates a new session
|
|
func (m *Manager) CreateSession(ctx context.Context, tenantID, userID string, authAssertion string) (*Session, error) {
|
|
// Validate JWT/auth assertion (simplified - should validate with tenant JWKs)
|
|
if authAssertion == "" {
|
|
return nil, errors.New("auth assertion required")
|
|
}
|
|
|
|
// Load tenant config
|
|
config, err := m.loadTenantConfig(ctx, tenantID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to load tenant config: %w", err)
|
|
}
|
|
|
|
// Generate session ID
|
|
sessionID, err := generateSessionID()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate session ID: %w", err)
|
|
}
|
|
|
|
// Generate ephemeral token
|
|
ephemeralToken, err := generateEphemeralToken()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate ephemeral token: %w", err)
|
|
}
|
|
|
|
now := time.Now()
|
|
sessionDuration := config.Policy.MaxSessionDuration
|
|
if sessionDuration == 0 {
|
|
sessionDuration = 30 * time.Minute // default
|
|
}
|
|
|
|
session := &Session{
|
|
ID: sessionID,
|
|
TenantID: tenantID,
|
|
UserID: userID,
|
|
EphemeralToken: ephemeralToken,
|
|
Config: config,
|
|
CreatedAt: now,
|
|
ExpiresAt: now.Add(sessionDuration),
|
|
LastActivityAt: now,
|
|
}
|
|
|
|
// Save to database
|
|
if err := m.saveSessionToDB(ctx, session); err != nil {
|
|
return nil, fmt.Errorf("failed to save session: %w", err)
|
|
}
|
|
|
|
// Cache in Redis
|
|
if err := m.cacheSession(ctx, session); err != nil {
|
|
return nil, fmt.Errorf("failed to cache session: %w", err)
|
|
}
|
|
|
|
return session, nil
|
|
}
|
|
|
|
// GetSession returns the active session identified by sessionID.
//
// The Redis lookup is tried first. On any cache error (a miss included) or a
// nil result, the session is loaded from the database and then re-cached on a
// best-effort basis. Database errors are wrapped as "session not found".
func (m *Manager) GetSession(ctx context.Context, sessionID string) (*Session, error) {
	if cached, cacheErr := m.getSessionFromCache(ctx, sessionID); cacheErr == nil && cached != nil {
		return cached, nil
	}

	fromDB, dbErr := m.getSessionFromDB(ctx, sessionID)
	if dbErr != nil {
		return nil, fmt.Errorf("session not found: %w", dbErr)
	}

	// Re-populating the cache is only an optimization; a failure here must
	// not fail an otherwise successful lookup.
	_ = m.cacheSession(ctx, fromDB)

	return fromDB, nil
}
|
|
|
|
// RefreshToken rotates the ephemeral token of an unexpired session and returns
// the new token.
//
// The session is resolved through GetSession, and lookup errors are returned
// as-is. A session whose ExpiresAt lies in the past yields "session expired".
// The new token and a fresh LastActivityAt are persisted to the database;
// ExpiresAt is not extended. The cache is then refreshed on a best-effort
// basis.
func (m *Manager) RefreshToken(ctx context.Context, sessionID string) (string, error) {
	sess, lookupErr := m.GetSession(ctx, sessionID)
	if lookupErr != nil {
		return "", lookupErr
	}

	if expired := time.Now().After(sess.ExpiresAt); expired {
		return "", errors.New("session expired")
	}

	token, genErr := generateEphemeralToken()
	if genErr != nil {
		return "", fmt.Errorf("failed to generate token: %w", genErr)
	}

	sess.EphemeralToken, sess.LastActivityAt = token, time.Now()

	if saveErr := m.saveSessionToDB(ctx, sess); saveErr != nil {
		return "", fmt.Errorf("failed to update session: %w", saveErr)
	}

	// The database write succeeded. A cache failure is deliberately ignored
	// because GetSession falls back to the database on a miss.
	_ = m.cacheSession(ctx, sess)

	return token, nil
}
|
|
|
|
// EndSession ends a session
|
|
func (m *Manager) EndSession(ctx context.Context, sessionID string) error {
|
|
// Remove from Redis
|
|
_ = m.redis.Del(ctx, fmt.Sprintf("session:%s", sessionID))
|
|
|
|
// Mark as ended in database
|
|
query := `UPDATE sessions SET ended_at = $1 WHERE id = $2`
|
|
_, err := m.db.Exec(ctx, query, time.Now(), sessionID)
|
|
return err
|
|
}
|
|
|
|
// loadTenantConfig loads the configuration for tenantID from the tenants table.
//
// The theme, allowed_tools and policy columns are read as raw JSON and
// decoded into the matching TenantConfig fields.
//   - Empty or NULL columns leave the field at its zero value. The exception
//     is Policy, which always starts from the built-in defaults.
//   - The policy document is decoded on top of those defaults, so keys absent
//     from it keep their default values.
//   - Malformed JSON in any column is returned as an error.
//
// If the row cannot be read at all, the built-in default configuration is
// returned with a nil error. NOTE(review): this masks connection and query
// failures as well as a missing tenant. Consider returning an error for
// anything other than "no rows".
//
// NOTE(review): time.Duration has no custom JSON decoding, so
// max_session_duration must be stored as integer nanoseconds. A value stored
// in seconds would be misread, and a string such as "30m" would be rejected.
// Confirm the stored format.
func (m *Manager) loadTenantConfig(ctx context.Context, tenantID string) (*TenantConfig, error) {
	const query = `
		SELECT theme, avatar_enabled, greeting, allowed_tools, policy
		FROM tenants
		WHERE id = $1
	`

	var config TenantConfig
	var themeJSON, allowedToolsJSON, policyJSON []byte

	err := m.db.QueryRow(ctx, query, tenantID).Scan(
		&themeJSON,
		&config.AvatarEnabled,
		&config.Greeting,
		&allowedToolsJSON,
		&policyJSON,
	)
	if err != nil {
		// Fall back to defaults when the tenant row cannot be read.
		return defaultTenantConfig(), nil
	}

	if len(themeJSON) > 0 {
		if err := json.Unmarshal(themeJSON, &config.Theme); err != nil {
			return nil, fmt.Errorf("decoding theme for tenant %q: %w", tenantID, err)
		}
	}

	if len(allowedToolsJSON) > 0 {
		if err := json.Unmarshal(allowedToolsJSON, &config.AllowedTools); err != nil {
			return nil, fmt.Errorf("decoding allowed_tools for tenant %q: %w", tenantID, err)
		}
	}

	config.Policy = defaultPolicyConfig()
	if len(policyJSON) > 0 {
		// Decode over the defaults so that keys missing from the stored
		// document (or a JSON null) keep their default values.
		if err := json.Unmarshal(policyJSON, config.Policy); err != nil {
			return nil, fmt.Errorf("decoding policy for tenant %q: %w", tenantID, err)
		}
	}

	return &config, nil
}

// defaultTenantConfig returns the configuration used when a tenant row cannot
// be loaded.
func defaultTenantConfig() *TenantConfig {
	return &TenantConfig{
		Theme:         map[string]any{"primaryColor": "#0066cc"},
		AvatarEnabled: true,
		Greeting:      "Hello! How can I help you today?",
		AllowedTools:  []string{},
		Policy:        defaultPolicyConfig(),
	}
}

// defaultPolicyConfig returns a fresh copy of the built-in session policy.
func defaultPolicyConfig() *PolicyConfig {
	return &PolicyConfig{
		MaxSessionDuration: 30 * time.Minute,
		RateLimitPerMinute: 10,
		RequireConsent:     true,
	}
}
|
|
|
|
// saveSessionToDB saves session to database
|
|
func (m *Manager) saveSessionToDB(ctx context.Context, session *Session) error {
|
|
query := `
|
|
INSERT INTO sessions (id, tenant_id, user_id, ephemeral_token, created_at, expires_at, last_activity_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
|
ON CONFLICT (id) DO UPDATE SET
|
|
ephemeral_token = $4,
|
|
last_activity_at = $7
|
|
`
|
|
|
|
_, err := m.db.Exec(ctx, query,
|
|
session.ID,
|
|
session.TenantID,
|
|
session.UserID,
|
|
session.EphemeralToken,
|
|
session.CreatedAt,
|
|
session.ExpiresAt,
|
|
session.LastActivityAt,
|
|
)
|
|
return err
|
|
}
|
|
|
|
// getSessionFromDB loads the session with the given ID, provided it has not
// been ended (ended_at IS NULL), and attaches its tenant configuration.
//
// Scan and tenant-config errors are returned unwrapped. Expiry is not checked
// here.
func (m *Manager) getSessionFromDB(ctx context.Context, sessionID string) (*Session, error) {
	const selectActive = `
		SELECT id, tenant_id, user_id, ephemeral_token, created_at, expires_at, last_activity_at
		FROM sessions
		WHERE id = $1 AND ended_at IS NULL
	`

	s := new(Session)
	row := m.db.QueryRow(ctx, selectActive, sessionID)
	if scanErr := row.Scan(
		&s.ID,
		&s.TenantID,
		&s.UserID,
		&s.EphemeralToken,
		&s.CreatedAt,
		&s.ExpiresAt,
		&s.LastActivityAt,
	); scanErr != nil {
		return nil, scanErr
	}

	cfg, cfgErr := m.loadTenantConfig(ctx, s.TenantID)
	if cfgErr != nil {
		return nil, cfgErr
	}
	s.Config = cfg

	return s, nil
}
|
|
|
|
// cacheSession writes a presence marker for session to Redis under the key
// "session:<id>".
//
// The stored value is the session ID itself, not a serialized session. The
// key's TTL is the time remaining until session.ExpiresAt. Sessions that have
// already expired are not cached, and nil is returned for them.
func (m *Manager) cacheSession(ctx context.Context, session *Session) error {
	remaining := time.Until(session.ExpiresAt)
	if remaining > 0 {
		return m.redis.Set(ctx, fmt.Sprintf("session:%s", session.ID), session.ID, remaining).Err()
	}
	return nil
}
|
|
|
|
// getSessionFromCache checks the Redis marker for sessionID and, if present,
// loads the full session from the database.
//
// Redis errors are returned as-is; go-redis reports a missing key as
// redis.Nil. A marker whose value differs from sessionID yields
// "cache mismatch". Because the marker holds no session data, a hit still
// costs one database query.
func (m *Manager) getSessionFromCache(ctx context.Context, sessionID string) (*Session, error) {
	marker, err := m.redis.Get(ctx, "session:"+sessionID).Result()
	switch {
	case err != nil:
		return nil, err
	case marker != sessionID:
		return nil, errors.New("cache mismatch")
	}
	return m.getSessionFromDB(ctx, sessionID)
}
|
|
|
|
// generateSessionID generates a unique session ID
|
|
func generateSessionID() (string, error) {
|
|
b := make([]byte, 16)
|
|
if _, err := rand.Read(b); err != nil {
|
|
return "", err
|
|
}
|
|
return base64.URLEncoding.EncodeToString(b), nil
|
|
}
|
|
|
|
// generateEphemeralToken generates an ephemeral token
|
|
func generateEphemeralToken() (string, error) {
|
|
b := make([]byte, 32)
|
|
if _, err := rand.Read(b); err != nil {
|
|
return "", err
|
|
}
|
|
return base64.URLEncoding.EncodeToString(b), nil
|
|
}
|