Files

100 lines
2.1 KiB
Go

package state
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/jackc/pgx/v5/pgxpool"
)
// StateManager reads and writes ConversationState rows in PostgreSQL
// through a pgx connection pool.
type StateManager struct {
	// db is the pool every query runs on; it is supplied by NewStateManager.
	db *pgxpool.Pool
}
// NewStateManager returns a StateManager that issues its queries on db.
// The pool is stored as given; it is neither validated nor pinged here.
func NewStateManager(db *pgxpool.Pool) *StateManager {
	manager := new(StateManager)
	manager.db = db
	return manager
}
// ConversationState is a snapshot of a session's position within a workflow,
// corresponding to one row of the conversation_states table.
type ConversationState struct {
	SessionID string         // row key; SaveState upserts on session_id
	UserID    string         // stored on insert, not overwritten on upsert
	Workflow  string         // name of the active workflow
	Step      string         // current step within Workflow
	Context   map[string]any // free-form data, persisted as JSON

	CreatedAt time.Time // stored on insert, not overwritten on upsert
	UpdatedAt time.Time // SaveState writes the current time, not this value
	ExpiresAt time.Time // stored as given; GetState does not check it
}
// SaveState upserts state into conversation_states, keyed by state.SessionID.
//
// state.Context is JSON-encoded before it is written (a nil map is stored as
// JSON null). On insert every column comes from state except updated_at,
// which is always the current time. When a row for the session already
// exists, workflow, step, context, updated_at and expires_at are overwritten;
// user_id and created_at keep their stored values.
//
// state itself is not modified, so state.UpdatedAt does not reflect the
// timestamp that was written.
//
// It returns an error if state is nil, if the context cannot be marshalled,
// or if the statement fails; database errors are wrapped with %w so callers
// can still inspect them with errors.Is / errors.As.
func (s *StateManager) SaveState(ctx context.Context, state *ConversationState) error {
	// Report a nil state as an error instead of panicking on the dereference.
	if state == nil {
		return fmt.Errorf("failed to save state: state is nil")
	}

	contextJSON, err := json.Marshal(state.Context)
	if err != nil {
		return fmt.Errorf("failed to marshal context: %w", err)
	}

	// The DO UPDATE clause reuses placeholders $3-$8 so the update path
	// writes exactly the values the insert path would have.
	const query = `
INSERT INTO conversation_states (
session_id, user_id, workflow, step, context, created_at, updated_at, expires_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (session_id) DO UPDATE SET
workflow = $3,
step = $4,
context = $5,
updated_at = $7,
expires_at = $8
`
	_, err = s.db.Exec(ctx, query,
		state.SessionID,
		state.UserID,
		state.Workflow,
		state.Step,
		contextJSON,
		state.CreatedAt,
		time.Now(),
		state.ExpiresAt,
	)
	if err != nil {
		return fmt.Errorf("failed to save state for session %q: %w", state.SessionID, err)
	}
	return nil
}
// GetState loads the conversation state stored for sessionID.
//
// The stored context column is JSON-decoded into Context. A SQL NULL column
// or a JSON null (which SaveState writes for a nil map) yields an empty,
// non-nil map, so callers can add keys without first checking for nil.
//
// expires_at is read into ExpiresAt but is not compared against the current
// time here, so an expired state is still returned.
// NOTE(review): confirm whether callers are expected to enforce expiry.
//
// Query errors are wrapped with %w; with pgx, a missing row surfaces as
// pgx.ErrNoRows and can be detected with errors.Is.
func (s *StateManager) GetState(ctx context.Context, sessionID string) (*ConversationState, error) {
	query := `
SELECT session_id, user_id, workflow, step, context, created_at, updated_at, expires_at
FROM conversation_states
WHERE session_id = $1
`
	var state ConversationState
	var contextJSON []byte
	err := s.db.QueryRow(ctx, query, sessionID).Scan(
		&state.SessionID,
		&state.UserID,
		&state.Workflow,
		&state.Step,
		&contextJSON,
		&state.CreatedAt,
		&state.UpdatedAt,
		&state.ExpiresAt,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to get state: %w", err)
	}

	// A NULL column scans as an empty slice, which json.Unmarshal rejects with
	// "unexpected end of JSON input"; treat it as "no context" instead.
	if len(contextJSON) > 0 {
		if err := json.Unmarshal(contextJSON, &state.Context); err != nil {
			return nil, fmt.Errorf("failed to unmarshal context: %w", err)
		}
	}

	// JSON null (and the NULL case above) leave Context nil; assigning into a
	// nil map panics, so hand callers an empty map instead.
	if state.Context == nil {
		state.Context = make(map[string]any)
	}
	return &state, nil
}