Files
explorer-monorepo/backend/auth/auth.go
defiQUG f46bd213ba refactor: rename SolaceScanScout to Solace and update related configurations
- Updated branding from "SolaceScanScout" to "Solace" across various files including deployment scripts, API responses, and documentation.
- Changed default base URL for Playwright tests and updated security headers to reflect the new branding.
- Enhanced README and API documentation to include new authentication endpoints and product access details.

This refactor aligns the project branding and improves clarity in the API documentation.
2026-04-10 12:52:17 -07:00

632 lines
18 KiB
Go

package auth
import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"net"
	"time"

	"github.com/jackc/pgx/v5/pgxpool"
	"golang.org/x/crypto/bcrypt"
)
// Auth handles user authentication, API-key issuance and validation, and
// product-subscription bookkeeping. Every method runs its queries through the
// wrapped PostgreSQL pool (tables users, api_keys, api_key_usage_logs and
// user_product_subscriptions).
type Auth struct {
	// db is the pool all queries in this package are issued on.
	db *pgxpool.Pool
}
// NewAuth returns an Auth that issues all of its queries through db.
// The pool is stored as-is; NewAuth performs no I/O.
func NewAuth(db *pgxpool.Pool) *Auth {
	a := new(Auth)
	a.db = db
	return a
}
// User is the public view of a row in the users table, as returned by
// RegisterUser and AuthenticateUser. It deliberately has no password-hash
// field; the hash is only ever read into a local during authentication.
type User struct {
	ID        string
	Email     string
	Username  string
	CreatedAt time.Time
}
// APIKeyInfo is the listing view of one API key (see ListAPIKeys). It never
// carries the plaintext key or its stored hash.
//
// Nullable columns are normalised when read: ProductSlug is "" and Scopes is
// an empty array when NULL, the quota/usage/rate-limit counters read 0 when
// NULL, and the pointer timestamps are nil when the key has never been
// approved, never been used, or has no expiry.
type APIKeyInfo struct {
	ID                 string     `json:"id"`
	Name               string     `json:"name"`
	Tier               string     `json:"tier"`
	ProductSlug        string     `json:"productSlug"`
	Scopes             []string   `json:"scopes"`
	MonthlyQuota       int        `json:"monthlyQuota"`
	RequestsUsed       int        `json:"requestsUsed"`
	Approved           bool       `json:"approved"`
	ApprovedAt         *time.Time `json:"approvedAt"`
	RateLimitPerSecond int        `json:"rateLimitPerSecond"`
	RateLimitPerMinute int        `json:"rateLimitPerMinute"`
	LastUsedAt         *time.Time `json:"lastUsedAt"`
	ExpiresAt          *time.Time `json:"expiresAt"`
	Revoked            bool       `json:"revoked"`
	CreatedAt          time.Time  `json:"createdAt"`
}
// ValidatedAPIKey is returned by ValidateAPIKeyDetailed for a key that exists,
// is not revoked, is approved and has not expired.
//
// RequestsUsed already includes the requests counted by the validating call.
// LastUsedAt, by contrast, is the value stored *before* that call refreshed
// last_used_at (nil if the key had never been used). ExpiresAt is nil when the
// key never expires.
type ValidatedAPIKey struct {
	UserID             string     `json:"userId"`
	APIKeyID           string     `json:"apiKeyId"`
	Name               string     `json:"name"`
	Tier               string     `json:"tier"`
	ProductSlug        string     `json:"productSlug"`
	Scopes             []string   `json:"scopes"`
	MonthlyQuota       int        `json:"monthlyQuota"`
	RequestsUsed       int        `json:"requestsUsed"`
	RateLimitPerSecond int        `json:"rateLimitPerSecond"`
	RateLimitPerMinute int        `json:"rateLimitPerMinute"`
	LastUsedAt         *time.Time `json:"lastUsedAt"`
	ExpiresAt          *time.Time `json:"expiresAt"`
}
// ProductSubscription mirrors a row of user_product_subscriptions as returned
// by the subscription list, update and upsert methods. The owning user ID is
// not included.
//
// ApprovedAt, ApprovedBy and Notes are nil when the corresponding column is
// NULL; MonthlyQuota and RequestsUsed read 0 when NULL.
type ProductSubscription struct {
	ID               string     `json:"id"`
	ProductSlug      string     `json:"productSlug"`
	Tier             string     `json:"tier"`
	Status           string     `json:"status"`
	MonthlyQuota     int        `json:"monthlyQuota"`
	RequestsUsed     int        `json:"requestsUsed"`
	RequiresApproval bool       `json:"requiresApproval"`
	ApprovedAt       *time.Time `json:"approvedAt"`
	ApprovedBy       *string    `json:"approvedBy"`
	Notes            *string    `json:"notes"`
	CreatedAt        time.Time  `json:"createdAt"`
}
// APIKeyUsageLog is one row of api_key_usage_logs joined with the name of the
// key it belongs to (see ListUsageLogs and ListAllUsageLogs).
//
// KeyName, ProductSlug and MethodName read "" when NULL. LastIP is the host
// part of last_ip_address rendered as text, or nil when no address was
// recorded.
type APIKeyUsageLog struct {
	ID           int64     `json:"id"`
	APIKeyID     string    `json:"apiKeyId"`
	KeyName      string    `json:"keyName"`
	ProductSlug  string    `json:"productSlug"`
	MethodName   string    `json:"methodName"`
	RequestCount int       `json:"requestCount"`
	LastIP       *string   `json:"lastIp"`
	CreatedAt    time.Time `json:"createdAt"`
}
// ListAllSubscriptions returns product subscriptions across all users,
// newest first. When status is non-empty, only subscriptions whose status
// equals it exactly are returned.
//
// The returned slice is never nil, so it encodes as [] (not null) in JSON.
// Query, scan and row-iteration failures are all returned wrapped.
func (a *Auth) ListAllSubscriptions(ctx context.Context, status string) ([]ProductSubscription, error) {
	query := `
SELECT id, product_slug, tier, status, COALESCE(monthly_quota, 0), COALESCE(requests_used, 0),
requires_approval, approved_at, approved_by, notes, created_at
FROM user_product_subscriptions
`
	args := []any{}
	if status != "" {
		query += ` WHERE status = $1`
		args = append(args, status)
	}
	query += ` ORDER BY created_at DESC`

	rows, err := a.db.Query(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to list all subscriptions: %w", err)
	}
	defer rows.Close()

	subs := make([]ProductSubscription, 0)
	for rows.Next() {
		var sub ProductSubscription
		// Nullable columns scan straight into the pointer fields, so NULL stays nil.
		if err := rows.Scan(
			&sub.ID,
			&sub.ProductSlug,
			&sub.Tier,
			&sub.Status,
			&sub.MonthlyQuota,
			&sub.RequestsUsed,
			&sub.RequiresApproval,
			&sub.ApprovedAt,
			&sub.ApprovedBy,
			&sub.Notes,
			&sub.CreatedAt,
		); err != nil {
			return nil, fmt.Errorf("failed to scan subscription: %w", err)
		}
		subs = append(subs, sub)
	}
	// Next returns false both at the end of the result set and on failure;
	// only Err distinguishes the two, so a truncated list is not reported as
	// success.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate subscriptions: %w", err)
	}
	return subs, nil
}
// UpdateSubscriptionStatus sets the status of the subscription identified by
// subscriptionID and returns the updated row.
//
// When status is "active", approved_at is stamped with NOW() and approved_by
// is replaced with approvedBy (NULL when approvedBy is ""). Any other status
// leaves both columns as they were. A non-empty notes replaces the stored
// notes; an empty one keeps them. updated_at is always refreshed.
//
// Any failure, including no subscription matching subscriptionID (presumably
// surfaced by pgx as its no-rows error), is returned wrapped.
func (a *Auth) UpdateSubscriptionStatus(
	ctx context.Context,
	subscriptionID string,
	status string,
	approvedBy string,
	notes string,
) (*ProductSubscription, error) {
	const stmt = `
UPDATE user_product_subscriptions
SET status = $2,
approved_at = CASE WHEN $2 = 'active' THEN NOW() ELSE approved_at END,
approved_by = CASE WHEN $2 = 'active' THEN NULLIF($3, '') ELSE approved_by END,
notes = CASE WHEN NULLIF($4, '') IS NOT NULL THEN $4 ELSE notes END,
updated_at = NOW()
WHERE id = $1
RETURNING id, product_slug, tier, status, COALESCE(monthly_quota, 0), COALESCE(requests_used, 0),
requires_approval, approved_at, approved_by, notes, created_at
`
	updated := &ProductSubscription{}
	row := a.db.QueryRow(ctx, stmt, subscriptionID, status, approvedBy, notes)
	scanErr := row.Scan(
		&updated.ID,
		&updated.ProductSlug,
		&updated.Tier,
		&updated.Status,
		&updated.MonthlyQuota,
		&updated.RequestsUsed,
		&updated.RequiresApproval,
		&updated.ApprovedAt,
		&updated.ApprovedBy,
		&updated.Notes,
		&updated.CreatedAt,
	)
	if scanErr != nil {
		return nil, fmt.Errorf("failed to update subscription: %w", scanErr)
	}
	return updated, nil
}
// RegisterUser creates a user account. The password is hashed with bcrypt at
// bcrypt.DefaultCost and only the hash is stored in users.password_hash.
// It returns the inserted user; hashing and insert failures are returned
// wrapped.
func (a *Auth) RegisterUser(ctx context.Context, email, username, password string) (*User, error) {
	hash, hashErr := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if hashErr != nil {
		return nil, fmt.Errorf("failed to hash password: %w", hashErr)
	}

	const insertUser = `
INSERT INTO users (email, username, password_hash)
VALUES ($1, $2, $3)
RETURNING id, email, username, created_at
`
	created := new(User)
	if err := a.db.QueryRow(ctx, insertUser, email, username, hash).Scan(
		&created.ID, &created.Email, &created.Username, &created.CreatedAt,
	); err != nil {
		return nil, fmt.Errorf("failed to create user: %w", err)
	}
	return created, nil
}
// AuthenticateUser looks up the user by email and verifies password against
// the stored bcrypt hash. On success it records the login time and returns
// the user.
//
// Every failure (unknown email, wrong password, or a lookup error) yields the
// same "invalid credentials" error, so callers cannot tell them apart. When
// the lookup fails, a bcrypt computation of comparable cost is still
// performed. Without it, unknown emails would be answered noticeably faster
// than known ones, and response timing would reveal which emails are
// registered.
func (a *Auth) AuthenticateUser(ctx context.Context, email, password string) (*User, error) {
	var user User
	var passwordHash string
	query := `SELECT id, email, username, password_hash, created_at FROM users WHERE email = $1`
	err := a.db.QueryRow(ctx, query, email).Scan(
		&user.ID, &user.Email, &user.Username, &passwordHash, &user.CreatedAt,
	)
	if err != nil {
		// Equalise timing with the CompareHashAndPassword path below. A fixed
		// dummy input keeps the cost independent of the caller's password
		// length. The result is intentionally discarded.
		_, _ = bcrypt.GenerateFromPassword([]byte("timing-equalizer"), bcrypt.DefaultCost)
		return nil, fmt.Errorf("invalid credentials")
	}

	// Verify password
	if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password)); err != nil {
		return nil, fmt.Errorf("invalid credentials")
	}

	// Best effort: failing to record the login time must not reject a user
	// whose credentials were just verified.
	_, _ = a.db.Exec(ctx, `UPDATE users SET last_login_at = NOW(), updated_at = NOW() WHERE id = $1`, user.ID)
	return &user, nil
}
// GenerateAPIKey issues an unscoped API key for userID at the given tier.
// The key has no product slug, no scopes, a monthly quota of 0 and no expiry.
// It is created unapproved, so ValidateAPIKey rejects it until it is approved.
//
// It returns the plaintext key. Only its SHA-256 hash is persisted, so the
// plaintext cannot be recovered later.
func (a *Auth) GenerateAPIKey(ctx context.Context, userID, name string, tier string) (string, error) {
	const (
		noProduct   = ""
		noQuota     = 0
		preApproved = false
		noExpiry    = 0
	)
	var noScopes []string // nil, exactly as the original call passed
	return a.GenerateScopedAPIKey(ctx, userID, name, tier, noProduct, noScopes, noQuota, preApproved, noExpiry)
}
// GenerateScopedAPIKey creates a new API key for userID and returns its
// plaintext form: "ek_" followed by the hex encoding of 32 bytes from
// crypto/rand. Only the hex SHA-256 of the plaintext is stored.
//
// Rate limits come from tier:
//   - "pro": 20/s and 1000/min
//   - "enterprise": 100/s and 10000/min
//   - "free" or any unrecognised tier: 5/s and 100/min
//
// When approved is true, approved_at is set to NOW(). When expiresDays > 0,
// expires_at is set that many 24-hour periods from now; otherwise the key
// never expires. Randomness and insert failures are returned wrapped.
func (a *Auth) GenerateScopedAPIKey(ctx context.Context, userID, name string, tier string, productSlug string, scopes []string, monthlyQuota int, approved bool, expiresDays int) (string, error) {
	secret := make([]byte, 32)
	if _, err := rand.Read(secret); err != nil {
		return "", fmt.Errorf("failed to generate key: %w", err)
	}
	plaintext := "ek_" + hex.EncodeToString(secret)
	digest := sha256.Sum256([]byte(plaintext))
	storedHash := hex.EncodeToString(digest[:])

	// Start from the free-tier limits, which also serve as the fallback for
	// unknown tier names, and raise them for the paid tiers.
	perSecond, perMinute := 5, 100
	switch tier {
	case "pro":
		perSecond, perMinute = 20, 1000
	case "enterprise":
		perSecond, perMinute = 100, 10000
	}

	var expiresAt *time.Time
	if expiresDays > 0 {
		deadline := time.Now().Add(time.Duration(expiresDays) * 24 * time.Hour)
		expiresAt = &deadline
	}

	const insertKey = `
INSERT INTO api_keys (
user_id, key_hash, name, tier, product_slug, scopes, monthly_quota,
rate_limit_per_second, rate_limit_per_minute, approved, approved_at, expires_at
)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, CASE WHEN $10 THEN NOW() ELSE NULL END, $11)
`
	if _, err := a.db.Exec(ctx, insertKey, userID, storedHash, name, tier, productSlug, scopes, monthlyQuota, perSecond, perMinute, approved, expiresAt); err != nil {
		return "", fmt.Errorf("failed to store API key: %w", err)
	}
	return plaintext, nil
}
// ValidateAPIKey checks apiKey against the stored key hashes and returns the
// owning user's ID.
//
// Checks run in this order, failing with:
//   - "invalid API key": no key matches, or the lookup errors
//   - "API key revoked"
//   - "API key pending approval"
//   - "API key expired"
//
// On success it refreshes last_used_at as a best effort; an error from that
// update is ignored. It does not change requests_used or write usage logs.
func (a *Auth) ValidateAPIKey(ctx context.Context, apiKey string) (string, error) {
	sum := sha256.Sum256([]byte(apiKey))
	keyHash := hex.EncodeToString(sum[:])

	var (
		userID    string
		revoked   bool
		approved  bool
		expiresAt *time.Time
	)
	row := a.db.QueryRow(ctx, `SELECT user_id, revoked, approved, expires_at FROM api_keys WHERE key_hash = $1`, keyHash)
	if err := row.Scan(&userID, &revoked, &approved, &expiresAt); err != nil {
		return "", fmt.Errorf("invalid API key")
	}

	switch {
	case revoked:
		return "", fmt.Errorf("API key revoked")
	case !approved:
		return "", fmt.Errorf("API key pending approval")
	case expiresAt != nil && time.Now().After(*expiresAt):
		return "", fmt.Errorf("API key expired")
	}

	// Best effort: a failed timestamp refresh must not reject a valid key.
	_, _ = a.db.Exec(ctx, `UPDATE api_keys SET last_used_at = NOW() WHERE key_hash = $1`, keyHash)
	return userID, nil
}
// ValidateAPIKeyDetailed validates apiKey and, on success, records its use.
//
// It rejects a key with the same errors and in the same order as
// ValidateAPIKey: invalid, revoked, pending approval, expired. On success:
//   - requests_used is incremented by requestCount (treated as 1 when <= 0);
//   - last_used_at and last_ip_address are updated;
//   - a row tagged with methodName is appended to api_key_usage_logs.
//
// lastIPAddress may be a bare IPv4/IPv6 address or a "host:port" pair (for
// example an HTTP RemoteAddr). Anything that does not parse as an IP is
// recorded as NULL rather than passed to Postgres' ::inet cast. A failed cast
// aborts the whole statement, which would silently skip the usage accounting
// above. Both writes are best effort: their errors are ignored so a
// bookkeeping failure never rejects an otherwise valid key.
//
// The returned RequestsUsed includes this call's requestCount. LastUsedAt is
// the timestamp stored before this call.
func (a *Auth) ValidateAPIKeyDetailed(ctx context.Context, apiKey string, methodName string, requestCount int, lastIPAddress string) (*ValidatedAPIKey, error) {
	hashedKey := sha256.Sum256([]byte(apiKey))
	hashedKeyHex := hex.EncodeToString(hashedKey[:])
	query := `
SELECT id, user_id, COALESCE(name, ''), tier, COALESCE(product_slug, ''), COALESCE(scopes, ARRAY[]::TEXT[]),
COALESCE(monthly_quota, 0), COALESCE(requests_used, 0), approved,
COALESCE(rate_limit_per_second, 0), COALESCE(rate_limit_per_minute, 0),
last_used_at, expires_at, revoked
FROM api_keys
WHERE key_hash = $1
`
	var validated ValidatedAPIKey
	var approved, revoked bool
	var lastUsedAt, expiresAt *time.Time
	if err := a.db.QueryRow(ctx, query, hashedKeyHex).Scan(
		&validated.APIKeyID,
		&validated.UserID,
		&validated.Name,
		&validated.Tier,
		&validated.ProductSlug,
		&validated.Scopes,
		&validated.MonthlyQuota,
		&validated.RequestsUsed,
		&approved,
		&validated.RateLimitPerSecond,
		&validated.RateLimitPerMinute,
		&lastUsedAt,
		&expiresAt,
		&revoked,
	); err != nil {
		return nil, fmt.Errorf("invalid API key")
	}
	if revoked {
		return nil, fmt.Errorf("API key revoked")
	}
	if !approved {
		return nil, fmt.Errorf("API key pending approval")
	}
	if expiresAt != nil && time.Now().After(*expiresAt) {
		return nil, fmt.Errorf("API key expired")
	}
	if requestCount <= 0 {
		requestCount = 1
	}

	// Reduce lastIPAddress to something ::inet is guaranteed to accept, or ""
	// (which NULLIF turns into NULL), so the writes below cannot fail on it.
	ipParam := lastIPAddress
	if host, _, err := net.SplitHostPort(ipParam); err == nil {
		ipParam = host
	}
	if net.ParseIP(ipParam) == nil {
		ipParam = ""
	}

	// Best effort: usage bookkeeping failures must not reject a valid key.
	_, _ = a.db.Exec(ctx, `
UPDATE api_keys
SET last_used_at = NOW(),
requests_used = COALESCE(requests_used, 0) + $2,
last_ip_address = NULLIF($3, '')::inet
WHERE key_hash = $1
`, hashedKeyHex, requestCount, ipParam)
	_, _ = a.db.Exec(ctx, `
INSERT INTO api_key_usage_logs (api_key_id, product_slug, method_name, request_count, window_start, window_end, last_ip_address)
VALUES ($1, NULLIF($2, ''), NULLIF($3, ''), $4, NOW(), NOW(), NULLIF($5, '')::inet)
`, validated.APIKeyID, validated.ProductSlug, methodName, requestCount, ipParam)

	validated.RequestsUsed += requestCount
	validated.LastUsedAt = lastUsedAt
	validated.ExpiresAt = expiresAt
	return &validated, nil
}
// ListAPIKeys returns every API key owned by userID, newest first, including
// revoked and expired keys (the query does not filter on either).
//
// The returned slice is never nil, so it encodes as [] in JSON when the user
// has no keys. Query, scan and row-iteration failures are returned wrapped.
func (a *Auth) ListAPIKeys(ctx context.Context, userID string) ([]APIKeyInfo, error) {
	rows, err := a.db.Query(ctx, `
SELECT id, COALESCE(name, ''), tier, COALESCE(product_slug, ''), COALESCE(scopes, ARRAY[]::TEXT[]),
COALESCE(monthly_quota, 0), COALESCE(requests_used, 0), approved, approved_at,
COALESCE(rate_limit_per_second, 0), COALESCE(rate_limit_per_minute, 0),
last_used_at, expires_at, revoked, created_at
FROM api_keys
WHERE user_id = $1
ORDER BY created_at DESC
`, userID)
	if err != nil {
		return nil, fmt.Errorf("failed to list API keys: %w", err)
	}
	defer rows.Close()

	keys := make([]APIKeyInfo, 0)
	for rows.Next() {
		var key APIKeyInfo
		// Nullable timestamps scan straight into the pointer fields, so NULL stays nil.
		if err := rows.Scan(
			&key.ID,
			&key.Name,
			&key.Tier,
			&key.ProductSlug,
			&key.Scopes,
			&key.MonthlyQuota,
			&key.RequestsUsed,
			&key.Approved,
			&key.ApprovedAt,
			&key.RateLimitPerSecond,
			&key.RateLimitPerMinute,
			&key.LastUsedAt,
			&key.ExpiresAt,
			&key.Revoked,
			&key.CreatedAt,
		); err != nil {
			return nil, fmt.Errorf("failed to scan API key: %w", err)
		}
		keys = append(keys, key)
	}
	// Next returns false both at the end of the result set and on failure;
	// only Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate API keys: %w", err)
	}
	return keys, nil
}
// ListUsageLogs returns the most recent usage-log entries across all API keys
// owned by userID, newest first. It returns at most limit entries; a limit
// <= 0 means 20.
//
// The returned slice is never nil. Query, scan and row-iteration failures are
// returned wrapped.
func (a *Auth) ListUsageLogs(ctx context.Context, userID string, limit int) ([]APIKeyUsageLog, error) {
	if limit <= 0 {
		limit = 20
	}
	rows, err := a.db.Query(ctx, `
SELECT logs.id, logs.api_key_id, COALESCE(keys.name, ''), COALESCE(logs.product_slug, ''),
COALESCE(logs.method_name, ''), logs.request_count,
CASE WHEN logs.last_ip_address IS NOT NULL THEN host(logs.last_ip_address) ELSE NULL END,
logs.created_at
FROM api_key_usage_logs logs
INNER JOIN api_keys keys ON keys.id = logs.api_key_id
WHERE keys.user_id = $1
ORDER BY logs.created_at DESC
LIMIT $2
`, userID, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to list usage logs: %w", err)
	}
	defer rows.Close()

	entries := make([]APIKeyUsageLog, 0)
	for rows.Next() {
		var entry APIKeyUsageLog
		// LastIP is nullable and scans straight into the pointer field.
		if err := rows.Scan(
			&entry.ID,
			&entry.APIKeyID,
			&entry.KeyName,
			&entry.ProductSlug,
			&entry.MethodName,
			&entry.RequestCount,
			&entry.LastIP,
			&entry.CreatedAt,
		); err != nil {
			return nil, fmt.Errorf("failed to scan usage log: %w", err)
		}
		entries = append(entries, entry)
	}
	// Next returns false both at the end of the result set and on failure;
	// only Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate usage logs: %w", err)
	}
	return entries, nil
}
// ListAllUsageLogs returns the most recent usage-log entries across all users,
// newest first. When productSlug is non-empty, only entries for that product
// are included. It returns at most limit entries; a limit <= 0 means 50.
//
// The returned slice is never nil. Query, scan and row-iteration failures are
// returned wrapped.
func (a *Auth) ListAllUsageLogs(ctx context.Context, productSlug string, limit int) ([]APIKeyUsageLog, error) {
	if limit <= 0 {
		limit = 50
	}
	query := `
SELECT logs.id, logs.api_key_id, COALESCE(keys.name, ''), COALESCE(logs.product_slug, ''),
COALESCE(logs.method_name, ''), logs.request_count,
CASE WHEN logs.last_ip_address IS NOT NULL THEN host(logs.last_ip_address) ELSE NULL END,
logs.created_at
FROM api_key_usage_logs logs
INNER JOIN api_keys keys ON keys.id = logs.api_key_id
`
	args := []any{}
	if productSlug != "" {
		query += ` WHERE logs.product_slug = $1`
		args = append(args, productSlug)
	}
	// The LIMIT placeholder index depends on whether the filter was added.
	query += fmt.Sprintf(" ORDER BY logs.created_at DESC LIMIT $%d", len(args)+1)
	args = append(args, limit)

	rows, err := a.db.Query(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to list all usage logs: %w", err)
	}
	defer rows.Close()

	entries := make([]APIKeyUsageLog, 0)
	for rows.Next() {
		var entry APIKeyUsageLog
		// LastIP is nullable and scans straight into the pointer field.
		if err := rows.Scan(
			&entry.ID,
			&entry.APIKeyID,
			&entry.KeyName,
			&entry.ProductSlug,
			&entry.MethodName,
			&entry.RequestCount,
			&entry.LastIP,
			&entry.CreatedAt,
		); err != nil {
			return nil, fmt.Errorf("failed to scan usage log: %w", err)
		}
		entries = append(entries, entry)
	}
	// Next returns false both at the end of the result set and on failure;
	// only Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate all usage logs: %w", err)
	}
	return entries, nil
}
// RevokeAPIKey marks key keyID as revoked, but only if it belongs to userID.
//
// It returns "api key not found" when no key with that ID is owned by that
// user. The WHERE clause does not look at the current revoked flag, so
// revoking an already-revoked key still succeeds. Database errors are
// returned wrapped.
func (a *Auth) RevokeAPIKey(ctx context.Context, userID, keyID string) error {
	result, err := a.db.Exec(ctx, `UPDATE api_keys SET revoked = true WHERE id = $1 AND user_id = $2`, keyID, userID)
	switch {
	case err != nil:
		return fmt.Errorf("failed to revoke API key: %w", err)
	case result.RowsAffected() == 0:
		return fmt.Errorf("api key not found")
	default:
		return nil
	}
}
// UpsertProductSubscription creates the (userID, productSlug) subscription,
// or overwrites it if one already exists, and returns the stored row.
//
// On conflict, tier, status, monthly_quota and requires_approval are replaced.
// approved_by and notes are replaced too, becoming NULL when passed as "".
// approved_at is set to NOW() whenever status is "active". Otherwise a new
// row gets NULL and an existing row keeps its previous value. The conflict
// target presumably relies on a unique constraint over
// (user_id, product_slug); confirm in the schema.
func (a *Auth) UpsertProductSubscription(
	ctx context.Context,
	userID, productSlug, tier, status string,
	monthlyQuota int,
	requiresApproval bool,
	approvedBy string,
	notes string,
) (*ProductSubscription, error) {
	const upsert = `
INSERT INTO user_product_subscriptions (
user_id, product_slug, tier, status, monthly_quota, requires_approval, approved_at, approved_by, notes
)
VALUES ($1, $2, $3, $4, $5, $6, CASE WHEN $4 = 'active' THEN NOW() ELSE NULL END, NULLIF($7, ''), NULLIF($8, ''))
ON CONFLICT (user_id, product_slug) DO UPDATE SET
tier = EXCLUDED.tier,
status = EXCLUDED.status,
monthly_quota = EXCLUDED.monthly_quota,
requires_approval = EXCLUDED.requires_approval,
approved_at = CASE WHEN EXCLUDED.status = 'active' THEN NOW() ELSE user_product_subscriptions.approved_at END,
approved_by = NULLIF(EXCLUDED.approved_by, ''),
notes = NULLIF(EXCLUDED.notes, ''),
updated_at = NOW()
RETURNING id, product_slug, tier, status, COALESCE(monthly_quota, 0), COALESCE(requests_used, 0),
requires_approval, approved_at, approved_by, notes, created_at
`
	saved := new(ProductSubscription)
	row := a.db.QueryRow(ctx, upsert, userID, productSlug, tier, status, monthlyQuota, requiresApproval, approvedBy, notes)
	if err := row.Scan(
		&saved.ID,
		&saved.ProductSlug,
		&saved.Tier,
		&saved.Status,
		&saved.MonthlyQuota,
		&saved.RequestsUsed,
		&saved.RequiresApproval,
		&saved.ApprovedAt,
		&saved.ApprovedBy,
		&saved.Notes,
		&saved.CreatedAt,
	); err != nil {
		return nil, fmt.Errorf("failed to save subscription: %w", err)
	}
	return saved, nil
}
// ListSubscriptions returns all product subscriptions belonging to userID,
// newest first.
//
// The returned slice is never nil, so it encodes as [] in JSON when the user
// has none. Query, scan and row-iteration failures are returned wrapped.
func (a *Auth) ListSubscriptions(ctx context.Context, userID string) ([]ProductSubscription, error) {
	rows, err := a.db.Query(ctx, `
SELECT id, product_slug, tier, status, COALESCE(monthly_quota, 0), COALESCE(requests_used, 0),
requires_approval, approved_at, approved_by, notes, created_at
FROM user_product_subscriptions
WHERE user_id = $1
ORDER BY created_at DESC
`, userID)
	if err != nil {
		return nil, fmt.Errorf("failed to list subscriptions: %w", err)
	}
	defer rows.Close()

	subs := make([]ProductSubscription, 0)
	for rows.Next() {
		var sub ProductSubscription
		// Nullable columns scan straight into the pointer fields, so NULL stays nil.
		if err := rows.Scan(
			&sub.ID,
			&sub.ProductSlug,
			&sub.Tier,
			&sub.Status,
			&sub.MonthlyQuota,
			&sub.RequestsUsed,
			&sub.RequiresApproval,
			&sub.ApprovedAt,
			&sub.ApprovedBy,
			&sub.Notes,
			&sub.CreatedAt,
		); err != nil {
			return nil, fmt.Errorf("failed to scan subscription: %w", err)
		}
		subs = append(subs, sub)
	}
	// Next returns false both at the end of the result set and on failure;
	// only Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate subscriptions: %w", err)
	}
	return subs, nil
}