Files
virtual-banker/backend/session/session.go

317 lines
8.0 KiB
Go

package session
import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/redis/go-redis/v9"
)
// Session is one Virtual Banker session. All fields except Config are
// persisted as a row of the sessions table; Config is attached from the
// owning tenant's configuration when the session is created or loaded.
type Session struct {
	// ID, TenantID and UserID identify the session and who it belongs to.
	ID, TenantID, UserID string
	// EphemeralToken is the session's current token; RefreshToken replaces it.
	EphemeralToken string
	// Config is the tenant configuration in effect for this session.
	Config *TenantConfig
	// CreatedAt, ExpiresAt and LastActivityAt bound the session lifetime and
	// record the most recent create/refresh.
	CreatedAt, ExpiresAt, LastActivityAt time.Time
}
// TenantConfig is the per-tenant configuration attached to every session.
// The struct tags define its JSON representation.
type TenantConfig struct {
	// Theme holds free-form theming keys (the built-in default sets "primaryColor").
	Theme map[string]any `json:"theme"`
	// AvatarEnabled reports whether the avatar is enabled for the tenant.
	AvatarEnabled bool `json:"avatar_enabled"`
	// Greeting is the tenant's greeting text.
	Greeting string `json:"greeting"`
	// AllowedTools presumably lists tool names the tenant permits — confirm with consumers.
	AllowedTools []string `json:"allowed_tools"`
	// Policy carries session-level policy such as the maximum session duration.
	Policy *PolicyConfig `json:"policy"`
}
// PolicyConfig is a tenant's session policy.
//
// MaxSessionDuration is a time.Duration, so encoding/json represents it as an
// integer number of nanoseconds.
type PolicyConfig struct {
	// MaxSessionDuration caps how long a new session stays valid; zero means
	// "use the caller's default".
	MaxSessionDuration time.Duration `json:"max_session_duration"`
	// RateLimitPerMinute is the tenant's per-minute rate limit.
	RateLimitPerMinute int `json:"rate_limit_per_minute"`
	// RequireConsent reports whether consent is required.
	RequireConsent bool `json:"require_consent"`
}
// Manager creates, looks up, refreshes and ends sessions. PostgreSQL holds the
// sessions and tenants rows; Redis holds per-session entries under the key
// "session:<id>".
type Manager struct {
	db    *pgxpool.Pool // sessions and tenants tables
	redis *redis.Client // "session:<id>" cache entries with a TTL
}
// NewManager returns a Manager that uses db for persistence and redisClient
// for caching. Neither argument is checked for nil.
func NewManager(db *pgxpool.Pool, redisClient *redis.Client) *Manager {
	m := new(Manager)
	m.db, m.redis = db, redisClient
	return m
}
// CreateSession creates a session for userID within tenantID, writes it to the
// database and caches it in Redis.
//
// authAssertion must be non-empty (it is not otherwise verified yet; see the
// TODO below). The session lifetime is the tenant policy's MaxSessionDuration.
// It falls back to 30 minutes when the policy is missing or the duration is
// not positive.
//
// The database row is the source of truth. Caching is therefore best-effort:
// a Redis failure after the row has been written does not fail the call.
// Failing it would leave a valid session row the caller never learns about.
func (m *Manager) CreateSession(ctx context.Context, tenantID, userID string, authAssertion string) (*Session, error) {
	// TODO(security): authAssertion is only checked for presence. It must be
	// verified (signature, issuer, audience, expiry) against the tenant's JWKs
	// before it can be trusted for authentication.
	if authAssertion == "" {
		return nil, errors.New("auth assertion required")
	}

	config, err := m.loadTenantConfig(ctx, tenantID)
	if err != nil {
		return nil, fmt.Errorf("failed to load tenant config: %w", err)
	}

	sessionID, err := generateSessionID()
	if err != nil {
		return nil, fmt.Errorf("failed to generate session ID: %w", err)
	}

	ephemeralToken, err := generateEphemeralToken()
	if err != nil {
		return nil, fmt.Errorf("failed to generate ephemeral token: %w", err)
	}

	// Guard against a nil Policy and against non-positive durations, which
	// would otherwise panic or create an already-expired session.
	sessionDuration := 30 * time.Minute
	if config.Policy != nil && config.Policy.MaxSessionDuration > 0 {
		sessionDuration = config.Policy.MaxSessionDuration
	}

	now := time.Now()
	session := &Session{
		ID:             sessionID,
		TenantID:       tenantID,
		UserID:         userID,
		EphemeralToken: ephemeralToken,
		Config:         config,
		CreatedAt:      now,
		ExpiresAt:      now.Add(sessionDuration),
		LastActivityAt: now,
	}

	if err := m.saveSessionToDB(ctx, session); err != nil {
		return nil, fmt.Errorf("failed to save session: %w", err)
	}

	// Best-effort: on a cache miss GetSession falls back to the database.
	_ = m.cacheSession(ctx, session)

	return session, nil
}
// GetSession returns the live session with the given ID.
//
// It tries the Redis cache first and falls back to the database on any cache
// error or miss. A session loaded from the database is re-cached on a
// best-effort basis.
//
// Sessions whose ExpiresAt has passed are rejected with "session expired".
// This applies even when the database row has not been ended. Without the
// check, an expired session would be served indefinitely from the database
// once its cache entry's TTL lapsed.
func (m *Manager) GetSession(ctx context.Context, sessionID string) (*Session, error) {
	session, err := m.getSessionFromCache(ctx, sessionID)
	if err != nil || session == nil {
		session, err = m.getSessionFromDB(ctx, sessionID)
		if err != nil {
			return nil, fmt.Errorf("session not found: %w", err)
		}
		// Best-effort: a Redis failure must not hide a valid DB session.
		_ = m.cacheSession(ctx, session)
	}

	// The DB query filters only on ended_at, so expiry is enforced here.
	if time.Now().After(session.ExpiresAt) {
		return nil, errors.New("session expired")
	}
	return session, nil
}
// RefreshToken replaces the session's ephemeral token with a freshly generated
// one and returns it. It also bumps LastActivityAt.
//
// The session must exist (via GetSession) and must not be past ExpiresAt.
// ExpiresAt itself is not extended. The database write must succeed; the
// cache update afterwards is best-effort.
func (m *Manager) RefreshToken(ctx context.Context, sessionID string) (string, error) {
	s, err := m.GetSession(ctx, sessionID)
	switch {
	case err != nil:
		return "", err
	case time.Now().After(s.ExpiresAt):
		return "", errors.New("session expired")
	}

	token, genErr := generateEphemeralToken()
	if genErr != nil {
		return "", fmt.Errorf("failed to generate token: %w", genErr)
	}

	s.EphemeralToken, s.LastActivityAt = token, time.Now()

	if err = m.saveSessionToDB(ctx, s); err != nil {
		return "", fmt.Errorf("failed to update session: %w", err)
	}
	// The DB row is authoritative; a failed cache write is intentionally ignored.
	_ = m.cacheSession(ctx, s)

	return token, nil
}
// EndSession ends a session
func (m *Manager) EndSession(ctx context.Context, sessionID string) error {
// Remove from Redis
_ = m.redis.Del(ctx, fmt.Sprintf("session:%s", sessionID))
// Mark as ended in database
query := `UPDATE sessions SET ended_at = $1 WHERE id = $2`
_, err := m.db.Exec(ctx, query, time.Now(), sessionID)
return err
}
// defaultPolicyConfig returns the session policy used when a tenant has no
// stored policy. Stored policy fields are decoded on top of it.
func defaultPolicyConfig() *PolicyConfig {
	return &PolicyConfig{
		MaxSessionDuration: 30 * time.Minute,
		RateLimitPerMinute: 10,
		RequireConsent:     true,
	}
}

// defaultTenantConfig returns the configuration used when the tenant row
// cannot be read.
func defaultTenantConfig() *TenantConfig {
	return &TenantConfig{
		Theme:         map[string]interface{}{"primaryColor": "#0066cc"},
		AvatarEnabled: true,
		Greeting:      "Hello! How can I help you today?",
		AllowedTools:  []string{},
		Policy:        defaultPolicyConfig(),
	}
}

// loadTenantConfig reads tenantID's configuration from the tenants table.
//
// The theme, allowed_tools and policy columns are decoded as JSON. An empty or
// NULL column leaves the corresponding field at its default: Theme nil,
// AllowedTools empty, Policy equal to defaultPolicyConfig(). Keys present in
// the stored policy override the defaults one field at a time.
//
// If the row cannot be scanned for any reason, the default configuration is
// returned with a nil error. Malformed JSON in any column is returned as an
// error.
func (m *Manager) loadTenantConfig(ctx context.Context, tenantID string) (*TenantConfig, error) {
	query := `
SELECT theme, avatar_enabled, greeting, allowed_tools, policy
FROM tenants
WHERE id = $1
`
	var config TenantConfig
	var themeJSON, allowedToolsJSON, policyJSON []byte
	err := m.db.QueryRow(ctx, query, tenantID).Scan(
		&themeJSON,
		&config.AvatarEnabled,
		&config.Greeting,
		&allowedToolsJSON,
		&policyJSON,
	)
	if err != nil {
		// NOTE(review): every scan error falls back to defaults, not just
		// "tenant not found". That includes connection failures and NULLs in
		// avatar_enabled/greeting. Consider distinguishing pgx.ErrNoRows.
		return defaultTenantConfig(), nil
	}

	if len(themeJSON) > 0 {
		if err := json.Unmarshal(themeJSON, &config.Theme); err != nil {
			return nil, fmt.Errorf("invalid theme for tenant %q: %w", tenantID, err)
		}
	}

	if len(allowedToolsJSON) > 0 {
		if err := json.Unmarshal(allowedToolsJSON, &config.AllowedTools); err != nil {
			return nil, fmt.Errorf("invalid allowed_tools for tenant %q: %w", tenantID, err)
		}
	}
	if config.AllowedTools == nil {
		// Match the default config, which uses an empty (not nil) slice.
		config.AllowedTools = []string{}
	}

	// Decode onto the defaults so a partial stored policy keeps defaults for
	// the keys it omits. A JSON null leaves the defaults untouched.
	// NOTE(review): encoding/json reads max_session_duration as a count of
	// nanoseconds (time.Duration). Confirm stored policies use that unit; a
	// value in seconds would produce a near-zero session lifetime.
	policy := defaultPolicyConfig()
	if len(policyJSON) > 0 {
		if err := json.Unmarshal(policyJSON, policy); err != nil {
			return nil, fmt.Errorf("invalid policy for tenant %q: %w", tenantID, err)
		}
	}
	config.Policy = policy

	return &config, nil
}
// saveSessionToDB saves session to database
func (m *Manager) saveSessionToDB(ctx context.Context, session *Session) error {
query := `
INSERT INTO sessions (id, tenant_id, user_id, ephemeral_token, created_at, expires_at, last_activity_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (id) DO UPDATE SET
ephemeral_token = $4,
last_activity_at = $7
`
_, err := m.db.Exec(ctx, query,
session.ID,
session.TenantID,
session.UserID,
session.EphemeralToken,
session.CreatedAt,
session.ExpiresAt,
session.LastActivityAt,
)
return err
}
// getSessionFromDB loads a session whose ended_at is NULL and attaches its
// tenant configuration.
//
// A missing or already-ended row surfaces as the Scan error. Expiry
// (ExpiresAt) is not checked here.
func (m *Manager) getSessionFromDB(ctx context.Context, sessionID string) (*Session, error) {
	const selectSQL = `
SELECT id, tenant_id, user_id, ephemeral_token, created_at, expires_at, last_activity_at
FROM sessions
WHERE id = $1 AND ended_at IS NULL
`
	s := new(Session)
	row := m.db.QueryRow(ctx, selectSQL, sessionID)
	if err := row.Scan(
		&s.ID,
		&s.TenantID,
		&s.UserID,
		&s.EphemeralToken,
		&s.CreatedAt,
		&s.ExpiresAt,
		&s.LastActivityAt,
	); err != nil {
		return nil, err
	}

	cfg, err := m.loadTenantConfig(ctx, s.TenantID)
	if err != nil {
		return nil, err
	}
	s.Config = cfg
	return s, nil
}
// cacheSession writes an entry for session to Redis under "session:<id>". The
// entry's TTL equals the session's remaining lifetime.
//
// Only the session ID is stored as the value, not a serialized session.
// Sessions that have already expired are skipped, and nil is returned for them.
func (m *Manager) cacheSession(ctx context.Context, session *Session) error {
	remaining := time.Until(session.ExpiresAt)
	if remaining > 0 {
		return m.redis.Set(ctx, fmt.Sprintf("session:%s", session.ID), session.ID, remaining).Err()
	}
	return nil
}
// getSessionFromCache checks Redis for a "session:<id>" entry. When the entry
// is present and holds the same ID, the full session is loaded from the
// database.
//
// A missing key returns the Redis error unchanged (go-redis reports it as
// redis.Nil). An entry holding a different value yields "cache mismatch".
func (m *Manager) getSessionFromCache(ctx context.Context, sessionID string) (*Session, error) {
	cached, err := m.redis.Get(ctx, "session:"+sessionID).Result()
	switch {
	case err != nil:
		return nil, err
	case cached != sessionID:
		return nil, errors.New("cache mismatch")
	default:
		return m.getSessionFromDB(ctx, sessionID)
	}
}
// randomBase64URL reads size bytes from crypto/rand and returns them encoded
// with padded, URL-safe base64 (base64.URLEncoding).
func randomBase64URL(size int) (string, error) {
	buf := make([]byte, size)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}
	return base64.URLEncoding.EncodeToString(buf), nil
}

// generateSessionID returns a random session identifier. It encodes 16 bytes
// (128 bits) of crypto/rand output as padded URL-safe base64, giving a
// 24-character string.
func generateSessionID() (string, error) {
	return randomBase64URL(16)
}

// generateEphemeralToken returns a random session token. It encodes 32 bytes
// (256 bits) of crypto/rand output as padded URL-safe base64, giving a
// 44-character string.
func generateEphemeralToken() (string, error) {
	return randomBase64URL(32)
}