100 lines
2.1 KiB
Go
100 lines
2.1 KiB
Go
package state
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
// StateManager manages conversation state
|
|
type StateManager struct {
|
|
db *pgxpool.Pool
|
|
}
|
|
|
|
// NewStateManager creates a new state manager
|
|
func NewStateManager(db *pgxpool.Pool) *StateManager {
|
|
return &StateManager{db: db}
|
|
}
|
|
|
|
// ConversationState is one persisted conversation snapshot, stored as a row
// of the conversation_states table keyed by SessionID.
type ConversationState struct {
	// Identifiers, persisted in the session_id and user_id columns.
	SessionID string
	UserID    string

	// Workflow and Step are persisted verbatim in the workflow and step columns.
	Workflow string
	Step     string

	// Context holds free-form data and is stored as JSON. After a round trip
	// through the database, JSON numbers decode as float64 and nested objects
	// as map[string]any (standard encoding/json behavior).
	Context map[string]any

	// Row timestamps, persisted in created_at, updated_at and expires_at.
	// NOTE(review): ExpiresAt is stored as-is; the save/load methods in this
	// file do not filter on it — confirm where expiry is enforced.
	CreatedAt time.Time
	UpdatedAt time.Time
	ExpiresAt time.Time
}
|
|
|
|
// SaveState saves conversation state
|
|
func (s *StateManager) SaveState(ctx context.Context, state *ConversationState) error {
|
|
contextJSON, err := json.Marshal(state.Context)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal context: %w", err)
|
|
}
|
|
|
|
query := `
|
|
INSERT INTO conversation_states (
|
|
session_id, user_id, workflow, step, context, created_at, updated_at, expires_at
|
|
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
|
ON CONFLICT (session_id) DO UPDATE SET
|
|
workflow = $3,
|
|
step = $4,
|
|
context = $5,
|
|
updated_at = $7,
|
|
expires_at = $8
|
|
`
|
|
|
|
_, err = s.db.Exec(ctx, query,
|
|
state.SessionID,
|
|
state.UserID,
|
|
state.Workflow,
|
|
state.Step,
|
|
contextJSON,
|
|
state.CreatedAt,
|
|
time.Now(),
|
|
state.ExpiresAt,
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
// GetState gets conversation state
|
|
func (s *StateManager) GetState(ctx context.Context, sessionID string) (*ConversationState, error) {
|
|
query := `
|
|
SELECT session_id, user_id, workflow, step, context, created_at, updated_at, expires_at
|
|
FROM conversation_states
|
|
WHERE session_id = $1
|
|
`
|
|
|
|
var state ConversationState
|
|
var contextJSON []byte
|
|
|
|
err := s.db.QueryRow(ctx, query, sessionID).Scan(
|
|
&state.SessionID,
|
|
&state.UserID,
|
|
&state.Workflow,
|
|
&state.Step,
|
|
&contextJSON,
|
|
&state.CreatedAt,
|
|
&state.UpdatedAt,
|
|
&state.ExpiresAt,
|
|
)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get state: %w", err)
|
|
}
|
|
|
|
if err := json.Unmarshal(contextJSON, &state.Context); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal context: %w", err)
|
|
}
|
|
|
|
return &state, nil
|
|
}
|
|
|