// File: explorer-monorepo/backend/auth/auth.go (Go, 151 lines, 3.7 KiB)

package auth
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"time"
"github.com/jackc/pgx/v5/pgxpool"
"golang.org/x/crypto/bcrypt"
)
// Auth bundles the user-account and API-key operations of this package.
// It holds no state of its own beyond the connection pool that every
// method uses for its database queries.
type Auth struct {
	db *pgxpool.Pool // shared pool for all queries issued by Auth's methods
}
// NewAuth returns an Auth whose methods run their queries against db.
// The pool is stored as-is and is not validated here, so a nil pool is
// accepted and would only fail once a method issues a query.
func NewAuth(db *pgxpool.Pool) *Auth {
	a := &Auth{}
	a.db = db
	return a
}
// User holds the account fields returned by RegisterUser and
// AuthenticateUser. It carries no password material: AuthenticateUser
// reads password_hash into a separate local and never copies it here.
type User struct {
	ID        string    // users.id, scanned as a string
	Email     string    // users.email
	Username  string    // users.username
	CreatedAt time.Time // users.created_at
}
// RegisterUser inserts a users row for email and username and returns the
// record as stored, including the database-assigned id and created_at.
// The plaintext password is never written. It is hashed with bcrypt at
// bcrypt.DefaultCost, and only that hash goes into password_hash.
//
// Errors from hashing or from the INSERT are returned wrapped. For example,
// a constraint violation would surface here if the schema defines one.
func (a *Auth) RegisterUser(ctx context.Context, email, username, password string) (*User, error) {
	hash, hashErr := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if hashErr != nil {
		return nil, fmt.Errorf("failed to hash password: %w", hashErr)
	}

	const insertUser = `
INSERT INTO users (email, username, password_hash)
VALUES ($1, $2, $3)
RETURNING id, email, username, created_at
`
	created := new(User)
	row := a.db.QueryRow(ctx, insertUser, email, username, hash)
	if scanErr := row.Scan(&created.ID, &created.Email, &created.Username, &created.CreatedAt); scanErr != nil {
		return nil, fmt.Errorf("failed to create user: %w", scanErr)
	}
	return created, nil
}
// AuthenticateUser loads the user with the given email and checks password
// against the stored bcrypt hash. On success it returns the user record,
// which does not include the hash.
//
// Every failure returns the same "invalid credentials" error, so callers
// cannot tell the cases apart. The failures are: no matching row, any other
// query/scan error, or a wrong password. The response time must not reveal
// that either, so a full bcrypt comparison is performed even when no user
// row could be loaded.
func (a *Auth) AuthenticateUser(ctx context.Context, email, password string) (*User, error) {
	// dummyHash is a well-formed bcrypt hash at cost 10, which equals
	// bcrypt.DefaultCost, the cost RegisterUser hashes with. Comparing
	// against it when the lookup fails costs about as much as a real check.
	// Without it, an unknown email would return noticeably faster than a
	// wrong password, which lets attackers enumerate accounts by latency.
	const dummyHash = "$2a$10$N9qo8uLOickgx2ZMRZoMyeIjZAgcfl7p92ldGxad68LJZdL17lhWy"

	var user User
	var passwordHash string
	query := `SELECT id, email, username, password_hash, created_at FROM users WHERE email = $1`
	err := a.db.QueryRow(ctx, query, email).Scan(
		&user.ID, &user.Email, &user.Username, &passwordHash, &user.CreatedAt,
	)
	if err != nil {
		// Spend the same bcrypt work as a real check; the outcome is irrelevant.
		_ = bcrypt.CompareHashAndPassword([]byte(dummyHash), []byte(password))
		return nil, fmt.Errorf("invalid credentials")
	}

	if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password)); err != nil {
		return nil, fmt.Errorf("invalid credentials")
	}
	return &user, nil
}
// GenerateAPIKey creates a new API key for userID, records it in api_keys
// under the given display name, and returns the plaintext key.
//
// The plaintext is "ek_" followed by the hex encoding of 32 bytes from
// crypto/rand. It is returned only here. The database stores just its
// SHA-256 hex digest, so the key cannot be read back later.
//
// tier selects the rate limits: "free", "pro" or "enterprise". Any other
// value, including "", falls back to the free tier, and "free" is what gets
// stored. This keeps the persisted tier consistent with the persisted limits.
func (a *Auth) GenerateAPIKey(ctx context.Context, userID, name string, tier string) (string, error) {
	keyBytes := make([]byte, 32)
	if _, err := rand.Read(keyBytes); err != nil {
		return "", fmt.Errorf("failed to generate key: %w", err)
	}
	apiKey := "ek_" + hex.EncodeToString(keyBytes)

	// Store only a digest. A fast hash suffices here, unlike for user
	// passwords, because the key carries 256 bits of randomness.
	hashedKey := sha256.Sum256([]byte(apiKey))
	hashedKeyHex := hex.EncodeToString(hashedKey[:])

	storedTier, perSecond, perMinute := apiKeyTierLimits(tier)

	query := `
INSERT INTO api_keys (user_id, key_hash, name, tier, rate_limit_per_second, rate_limit_per_minute)
VALUES ($1, $2, $3, $4, $5, $6)
`
	if _, err := a.db.Exec(ctx, query, userID, hashedKeyHex, name, storedTier, perSecond, perMinute); err != nil {
		return "", fmt.Errorf("failed to store API key: %w", err)
	}
	return apiKey, nil
}

// apiKeyTierLimits maps a requested tier to the tier name to persist and its
// per-second and per-minute request limits. Unrecognized tiers are
// normalized to "free" so the stored name always matches the limits.
func apiKeyTierLimits(tier string) (storedTier string, perSecond, perMinute int) {
	switch tier {
	case "pro":
		return "pro", 20, 1000
	case "enterprise":
		return "enterprise", 100, 10000
	default: // "free" and anything unrecognized
		return "free", 5, 100
	}
}
// ValidateAPIKey resolves a plaintext API key to the ID of the user that
// owns it. The key is hashed with SHA-256 and hex-encoded, which is the form
// GenerateAPIKey stores, and then looked up by key_hash.
//
// It returns "invalid API key" when the key is empty or the lookup fails for
// any reason, including database errors. It returns "API key revoked" when
// the matching row is marked revoked. On success, last_used_at is refreshed
// on a best-effort basis.
func (a *Auth) ValidateAPIKey(ctx context.Context, apiKey string) (string, error) {
	// Generated keys always start with "ek_", so an empty key can never
	// match. Skip the database round trip.
	if apiKey == "" {
		return "", fmt.Errorf("invalid API key")
	}

	hashedKey := sha256.Sum256([]byte(apiKey))
	hashedKeyHex := hex.EncodeToString(hashedKey[:])

	var userID string
	var revoked bool
	query := `SELECT user_id, revoked FROM api_keys WHERE key_hash = $1`
	if err := a.db.QueryRow(ctx, query, hashedKeyHex).Scan(&userID, &revoked); err != nil {
		return "", fmt.Errorf("invalid API key")
	}
	if revoked {
		return "", fmt.Errorf("API key revoked")
	}

	// Usage tracking is best-effort. A failed timestamp update must not
	// reject an otherwise valid key, so the result is deliberately discarded.
	_, _ = a.db.Exec(ctx, `UPDATE api_keys SET last_used_at = NOW() WHERE key_hash = $1`, hashedKeyHex)
	return userID, nil
}