Files

166 lines
3.5 KiB
Go

package migrations
import (
"database/sql"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
_ "github.com/jackc/pgx/v5/stdlib"
)
// Migration is a single versioned schema change loaded from a pair of
// <version>.up.sql / <version>.down.sql files.
type Migration struct {
	// Version identifies the migration; RunMigrations records it in
	// schema_migrations once Up has been applied.
	Version string
	// Up holds the SQL that applies the change; Down holds the SQL meant to
	// revert it (RunMigrations never executes Down).
	Up, Down string
}
// Migrator applies SQL migrations to a database and tracks which versions
// have already been applied in the schema_migrations table.
// Construct it with NewMigrator.
type Migrator struct {
	db *sql.DB // handle every migration query runs against
}
// NewMigrator returns a Migrator that issues all of its statements through db.
// The handle is stored as-is; the Migrator never closes it, so the caller
// remains responsible for its lifetime.
func NewMigrator(db *sql.DB) *Migrator {
	mg := new(Migrator)
	mg.db = db
	return mg
}
// RunMigrations brings the database up to date with the migration files in
// migrationsDir.
//
// It ensures the schema_migrations table exists, loads the migration files,
// reads the set of versions already recorded, and then applies every
// unrecorded migration in the order loadMigrations returns them. Each
// migration is applied by runMigration in its own transaction. The first
// failure aborts the run and is returned; migrations applied before it stay
// applied.
func (m *Migrator) RunMigrations(migrationsDir string) error {
	err := m.createMigrationsTable()
	if err != nil {
		return fmt.Errorf("failed to create migrations table: %w", err)
	}

	all, err := m.loadMigrations(migrationsDir)
	if err != nil {
		return fmt.Errorf("failed to load migrations: %w", err)
	}

	done, err := m.getAppliedMigrations()
	if err != nil {
		return fmt.Errorf("failed to get applied migrations: %w", err)
	}

	for i := range all {
		next := all[i]
		// Versions already present in schema_migrations are skipped.
		if done[next.Version] {
			continue
		}
		if runErr := m.runMigration(next); runErr != nil {
			return fmt.Errorf("failed to run migration %s: %w", next.Version, runErr)
		}
	}
	return nil
}
// createMigrationsTable creates the schema_migrations bookkeeping table
// (version primary key plus an applied_at timestamp defaulting to NOW()).
// The statement uses IF NOT EXISTS, so it is safe to call on every run.
func (m *Migrator) createMigrationsTable() error {
	const ddl = `
CREATE TABLE IF NOT EXISTS schema_migrations (
version VARCHAR(255) PRIMARY KEY,
applied_at TIMESTAMP DEFAULT NOW()
)
`
	if _, execErr := m.db.Exec(ddl); execErr != nil {
		return execErr
	}
	return nil
}
// loadMigrations reads the migration scripts stored directly in dir and pairs
// them by version. A file named "<version>.up.sql" supplies Migration.Up and
// "<version>.down.sql" supplies Migration.Down. Subdirectories and files with
// any other suffix are ignored.
//
// The result is sorted by Version using plain string comparison, so versions
// must sort lexicographically in the order they should be applied
// (zero-padded numbers or timestamps; e.g. "10_x" sorts before "2_x").
//
// It returns an error if dir or any migration file cannot be read, or if a
// version has a .down.sql file but no .up.sql file. Such a version used to be
// returned with an empty Up. RunMigrations would then execute an empty
// script and record the version as applied, so a later-added .up.sql would
// never run.
func (m *Migrator) loadMigrations(dir string) ([]Migration, error) {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil, err
	}

	byVersion := make(map[string]*Migration)
	hasUp := make(map[string]bool)
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		name := entry.Name()
		version, isUp, ok := splitMigrationFilename(name)
		if !ok {
			continue
		}

		content, readErr := os.ReadFile(filepath.Join(dir, name))
		if readErr != nil {
			return nil, readErr
		}

		mig, exists := byVersion[version]
		if !exists {
			mig = &Migration{Version: version}
			byVersion[version] = mig
		}
		if isUp {
			mig.Up = string(content)
			hasUp[version] = true
		} else {
			mig.Down = string(content)
		}
	}

	// "mig" rather than "m" so the receiver is not shadowed.
	result := make([]Migration, 0, len(byVersion))
	for _, mig := range byVersion {
		result = append(result, *mig)
	}
	sort.Slice(result, func(i, j int) bool {
		return result[i].Version < result[j].Version
	})

	// Checked after sorting so the reported version is deterministic when
	// several are incomplete.
	for _, mig := range result {
		if !hasUp[mig.Version] {
			return nil, fmt.Errorf("migration %s has a .down.sql file but no .up.sql file", mig.Version)
		}
	}
	return result, nil
}

// splitMigrationFilename extracts the version from a migration file name.
// isUp is true for "<version>.up.sql" and false for "<version>.down.sql".
// ok is false when name has neither suffix or when nothing precedes the
// suffix. Only one suffix is stripped, so "a.down.sql.up.sql" yields the
// version "a.down.sql".
func splitMigrationFilename(name string) (version string, isUp bool, ok bool) {
	switch {
	case strings.HasSuffix(name, ".up.sql"):
		version, isUp = strings.TrimSuffix(name, ".up.sql"), true
	case strings.HasSuffix(name, ".down.sql"):
		version = strings.TrimSuffix(name, ".down.sql")
	default:
		return "", false, false
	}
	return version, isUp, version != ""
}
// getAppliedMigrations returns the set of versions currently recorded in
// schema_migrations, each mapped to true. On a query or scan failure it
// returns a nil map. If iteration itself reports an error (rows.Err), that
// error is returned alongside whatever versions were collected so far.
func (m *Migrator) getAppliedMigrations() (map[string]bool, error) {
	rows, err := m.db.Query("SELECT version FROM schema_migrations")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	seen := map[string]bool{}
	for rows.Next() {
		var v string
		if scanErr := rows.Scan(&v); scanErr != nil {
			return nil, scanErr
		}
		seen[v] = true
	}
	return seen, rows.Err()
}
// runMigration executes migration.Up and records migration.Version in
// schema_migrations, both inside one transaction that is committed only if
// both statements succeed. If either fails, the deferred Rollback discards
// the transaction and the wrapped error is returned.
func (m *Migrator) runMigration(migration Migration) error {
	const recordVersion = "INSERT INTO schema_migrations (version) VALUES ($1)"

	tx, err := m.db.Begin()
	if err != nil {
		return err
	}
	// After a successful Commit this Rollback is a no-op that returns
	// sql.ErrTxDone, so its result is deliberately ignored.
	defer tx.Rollback()

	if _, upErr := tx.Exec(migration.Up); upErr != nil {
		return fmt.Errorf("failed to execute migration: %w", upErr)
	}
	if _, recErr := tx.Exec(recordVersion, migration.Version); recErr != nil {
		return fmt.Errorf("failed to record migration: %w", recErr)
	}
	return tx.Commit()
}