166 lines
3.5 KiB
Go
166 lines
3.5 KiB
Go
package migrations
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
|
|
_ "github.com/jackc/pgx/v5/stdlib"
|
|
)
|
|
|
|
// Migration is one versioned schema change, assembled from the SQL files in a
// migrations directory that share the same version prefix.
type Migration struct {
	// Version identifies the migration. It is the file name with its
	// ".up.sql" / ".down.sql" suffix removed, and it is the value recorded in
	// schema_migrations once the migration has been applied.
	Version string

	// Up and Down hold the SQL text of the version's ".up.sql" and ".down.sql"
	// files, respectively.
	// NOTE(review): only Up is executed by the visible code; Down is loaded but
	// unused here — confirm whether a rollback path exists elsewhere.
	Up, Down string
}
|
|
|
|
// Migrator applies SQL migration files to a database and records which
// versions have been applied in the schema_migrations table.
// Create one with NewMigrator.
type Migrator struct {
	db *sql.DB // database handle every migration query is executed against
}
|
|
|
|
// NewMigrator creates a new migrator
|
|
func NewMigrator(db *sql.DB) *Migrator {
|
|
return &Migrator{db: db}
|
|
}
|
|
|
|
// RunMigrations runs all pending migrations
|
|
func (m *Migrator) RunMigrations(migrationsDir string) error {
|
|
// Create migrations table if it doesn't exist
|
|
if err := m.createMigrationsTable(); err != nil {
|
|
return fmt.Errorf("failed to create migrations table: %w", err)
|
|
}
|
|
|
|
// Load migration files
|
|
migrations, err := m.loadMigrations(migrationsDir)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load migrations: %w", err)
|
|
}
|
|
|
|
// Get applied migrations
|
|
applied, err := m.getAppliedMigrations()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get applied migrations: %w", err)
|
|
}
|
|
|
|
// Run pending migrations
|
|
for _, migration := range migrations {
|
|
if applied[migration.Version] {
|
|
continue
|
|
}
|
|
|
|
if err := m.runMigration(migration); err != nil {
|
|
return fmt.Errorf("failed to run migration %s: %w", migration.Version, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Migrator) createMigrationsTable() error {
|
|
query := `
|
|
CREATE TABLE IF NOT EXISTS schema_migrations (
|
|
version VARCHAR(255) PRIMARY KEY,
|
|
applied_at TIMESTAMP DEFAULT NOW()
|
|
)
|
|
`
|
|
_, err := m.db.Exec(query)
|
|
return err
|
|
}
|
|
|
|
func (m *Migrator) loadMigrations(dir string) ([]Migration, error) {
|
|
files, err := os.ReadDir(dir)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
migrations := make(map[string]*Migration)
|
|
|
|
for _, file := range files {
|
|
if file.IsDir() {
|
|
continue
|
|
}
|
|
|
|
filename := file.Name()
|
|
if !strings.HasSuffix(filename, ".up.sql") && !strings.HasSuffix(filename, ".down.sql") {
|
|
continue
|
|
}
|
|
|
|
version := strings.TrimSuffix(filename, ".up.sql")
|
|
version = strings.TrimSuffix(version, ".down.sql")
|
|
|
|
if migrations[version] == nil {
|
|
migrations[version] = &Migration{Version: version}
|
|
}
|
|
|
|
content, err := os.ReadFile(filepath.Join(dir, filename))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if strings.HasSuffix(filename, ".up.sql") {
|
|
migrations[version].Up = string(content)
|
|
} else if strings.HasSuffix(filename, ".down.sql") {
|
|
migrations[version].Down = string(content)
|
|
}
|
|
}
|
|
|
|
// Convert to slice and sort
|
|
result := make([]Migration, 0, len(migrations))
|
|
for _, m := range migrations {
|
|
result = append(result, *m)
|
|
}
|
|
|
|
sort.Slice(result, func(i, j int) bool {
|
|
return result[i].Version < result[j].Version
|
|
})
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func (m *Migrator) getAppliedMigrations() (map[string]bool, error) {
|
|
rows, err := m.db.Query("SELECT version FROM schema_migrations")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
applied := make(map[string]bool)
|
|
for rows.Next() {
|
|
var version string
|
|
if err := rows.Scan(&version); err != nil {
|
|
return nil, err
|
|
}
|
|
applied[version] = true
|
|
}
|
|
|
|
return applied, rows.Err()
|
|
}
|
|
|
|
func (m *Migrator) runMigration(migration Migration) error {
|
|
tx, err := m.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
// Execute migration
|
|
if _, err := tx.Exec(migration.Up); err != nil {
|
|
return fmt.Errorf("failed to execute migration: %w", err)
|
|
}
|
|
|
|
// Record migration
|
|
if _, err := tx.Exec(
|
|
"INSERT INTO schema_migrations (version) VALUES ($1)",
|
|
migration.Version,
|
|
); err != nil {
|
|
return fmt.Errorf("failed to record migration: %w", err)
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|