Files

124 lines
3.6 KiB
Go

package middleware
import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"strings"

	"github.com/explorer/backend/auth"
	"github.com/explorer/backend/featureflags"
)
// AuthMiddleware groups the HTTP middlewares that authenticate callers from a
// bearer token and gate handlers by track level.
type AuthMiddleware struct {
	// walletAuth validates bearer tokens and yields the caller's address and track.
	walletAuth *auth.WalletAuth
}
// NewAuthMiddleware returns an AuthMiddleware that delegates token validation
// to walletAuth. The pointer is stored as-is; no nil check is done here, so a
// nil walletAuth is only noticed when a request reaches token validation.
func NewAuthMiddleware(walletAuth *auth.WalletAuth) *AuthMiddleware {
	mw := new(AuthMiddleware)
	mw.walletAuth = walletAuth
	return mw
}
// RequireAuth wraps next so that it only runs for requests that carry a valid
// bearer token in the Authorization header (see extractAuth). Any request that
// fails extraction or validation gets a 401 JSON error, and next is not called.
//
// On success the request context gains three values for downstream handlers:
// "user_address" (string), "user_track" (int) and "authenticated" (true).
// NOTE(review): these are plain string keys, which staticcheck (SA1029) flags
// as collision-prone. Moving to an unexported key type would require updating
// every reader of these values, so they are kept as-is.
func (m *AuthMiddleware) RequireAuth(next http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		address, track, err := m.extractAuth(r)
		if err != nil {
			writeUnauthorized(w)
			return
		}

		ctx := r.Context()
		ctx = context.WithValue(ctx, "user_address", address)
		ctx = context.WithValue(ctx, "user_track", track)
		ctx = context.WithValue(ctx, "authenticated", true)

		next.ServeHTTP(w, r.WithContext(ctx))
	}
	return http.HandlerFunc(serve)
}
// RequireTrack returns a middleware that lets a request through only when the
// caller's track has access to requiredTrack according to
// featureflags.HasAccess. Otherwise it responds with 403 via writeForbidden and
// does not call next.
//
// The caller's track is read from the "user_track" context value, normally set
// by RequireAuth or OptionalAuth. If no int is stored under that key (for
// example because neither of them ran), the caller is treated as Track 1
// (public).
func (m *AuthMiddleware) RequireTrack(requiredTrack int) func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			callerTrack := 1 // Track 1 (public) unless the context says otherwise
			if t, found := r.Context().Value("user_track").(int); found {
				callerTrack = t
			}

			if featureflags.HasAccess(callerTrack, requiredTrack) {
				next.ServeHTTP(w, r)
				return
			}
			writeForbidden(w, requiredTrack)
		})
	}
	return wrap
}
// OptionalAuth wraps next so that it always runs, with or without credentials.
// It is intended for Track 1 endpoints. It stores the same context values as
// RequireAuth:
//
//   - valid bearer token: "user_address"=address, "user_track"=track,
//     "authenticated"=true
//   - any extraction/validation failure (header missing, malformed, or a token
//     rejected by validation): "user_address"="", "user_track"=1,
//     "authenticated"=false
//
// An invalid token is therefore silently treated as anonymous (public) access
// rather than rejected.
func (m *AuthMiddleware) OptionalAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		address, track, err := m.extractAuth(r)
		authenticated := err == nil
		if !authenticated {
			// Fall back to the public identity (Track 1).
			address, track = "", 1
		}

		ctx := context.WithValue(r.Context(), "user_address", address)
		ctx = context.WithValue(ctx, "user_track", track)
		ctx = context.WithValue(ctx, "authenticated", authenticated)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// errNoBearerToken reports that a request has no usable
// "Authorization: Bearer <token>" header.
var errNoBearerToken = errors.New("missing or malformed bearer token")

// extractAuth reads the bearer token from the request's Authorization header
// and validates it with walletAuth.ValidateJWT. It returns the caller's
// address and track.
//
// The header must look like "<scheme> <token>", where:
//   - scheme equals "Bearer" compared case-insensitively (HTTP auth schemes
//     are case-insensitive per RFC 7235 §2.1);
//   - token is non-empty and contains no spaces or tabs.
//
// Whitespace surrounding the header value and the token is ignored.
//
// Errors: errNoBearerToken when the header is absent or malformed. Otherwise
// it returns the ValidateJWT error wrapped with %w, so errors.Is/As still match
// it. On any error the address is "" and the track is 0.
func (m *AuthMiddleware) extractAuth(r *http.Request) (string, int, error) {
	header := strings.TrimSpace(r.Header.Get("Authorization"))

	scheme, token, found := strings.Cut(header, " ")
	if !found || !strings.EqualFold(scheme, "Bearer") {
		return "", 0, errNoBearerToken
	}

	// Tolerate extra spaces between scheme and token. Still reject an empty
	// token or one that is really several space-separated parts.
	token = strings.TrimSpace(token)
	if token == "" || strings.ContainsAny(token, " \t") {
		return "", 0, errNoBearerToken
	}

	address, track, err := m.walletAuth.ValidateJWT(token)
	if err != nil {
		return "", 0, fmt.Errorf("validating bearer token: %w", err)
	}
	return address, track, nil
}
// writeUnauthorized sends a 401 Unauthorized response with Content-Type
// application/json and a fixed JSON error body whose code is "unauthorized".
// Any error from writing the body is ignored.
func writeUnauthorized(w http.ResponseWriter) {
	const payload = `{"error":{"code":"unauthorized","message":"Authentication required"}}`

	hdr := w.Header()
	hdr.Set("Content-Type", "application/json") // must precede WriteHeader
	w.WriteHeader(http.StatusUnauthorized)
	w.Write([]byte(payload))
}
// writeForbidden sends a 403 Forbidden response with Content-Type
// application/json. The JSON error body has code "forbidden" and carries
// requiredTrack as a number under "required_track".
// Any error from writing the body is ignored.
func writeForbidden(w http.ResponseWriter, requiredTrack int) {
	w.Header().Set("Content-Type", "application/json") // must precede WriteHeader
	w.WriteHeader(http.StatusForbidden)
	fmt.Fprintf(w, `{"error":{"code":"forbidden","message":"Insufficient permissions","required_track":%d}}`, requiredTrack)
}