124 lines
3.6 KiB
Go
124 lines
3.6 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/explorer/backend/auth"
|
|
"github.com/explorer/backend/featureflags"
|
|
)
|
|
|
|
// AuthMiddleware handles authentication and authorization
|
|
type AuthMiddleware struct {
|
|
walletAuth *auth.WalletAuth
|
|
}
|
|
|
|
// NewAuthMiddleware creates a new auth middleware
|
|
func NewAuthMiddleware(walletAuth *auth.WalletAuth) *AuthMiddleware {
|
|
return &AuthMiddleware{
|
|
walletAuth: walletAuth,
|
|
}
|
|
}
|
|
|
|
// RequireAuth is middleware that requires authentication
|
|
func (m *AuthMiddleware) RequireAuth(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
address, track, err := m.extractAuth(r)
|
|
if err != nil {
|
|
writeUnauthorized(w)
|
|
return
|
|
}
|
|
|
|
// Add user context
|
|
ctx := context.WithValue(r.Context(), "user_address", address)
|
|
ctx = context.WithValue(ctx, "user_track", track)
|
|
ctx = context.WithValue(ctx, "authenticated", true)
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
// RequireTrack is middleware that requires a specific track level
|
|
func (m *AuthMiddleware) RequireTrack(requiredTrack int) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Extract track from context (set by RequireAuth or OptionalAuth)
|
|
track, ok := r.Context().Value("user_track").(int)
|
|
if !ok {
|
|
track = 1 // Default to Track 1 (public)
|
|
}
|
|
|
|
if !featureflags.HasAccess(track, requiredTrack) {
|
|
writeForbidden(w, requiredTrack)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// OptionalAuth is middleware that optionally authenticates (for Track 1 endpoints)
|
|
func (m *AuthMiddleware) OptionalAuth(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
address, track, err := m.extractAuth(r)
|
|
if err != nil {
|
|
// No auth provided, default to Track 1 (public)
|
|
ctx := context.WithValue(r.Context(), "user_address", "")
|
|
ctx = context.WithValue(ctx, "user_track", 1)
|
|
ctx = context.WithValue(ctx, "authenticated", false)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
|
|
// Auth provided, add user context
|
|
ctx := context.WithValue(r.Context(), "user_address", address)
|
|
ctx = context.WithValue(ctx, "user_track", track)
|
|
ctx = context.WithValue(ctx, "authenticated", true)
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
// extractAuth extracts authentication information from request
|
|
func (m *AuthMiddleware) extractAuth(r *http.Request) (string, int, error) {
|
|
// Get Authorization header
|
|
authHeader := r.Header.Get("Authorization")
|
|
if authHeader == "" {
|
|
return "", 0, http.ErrMissingFile
|
|
}
|
|
|
|
// Check for Bearer token
|
|
parts := strings.Split(authHeader, " ")
|
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
|
return "", 0, http.ErrMissingFile
|
|
}
|
|
|
|
token := parts[1]
|
|
|
|
// Validate JWT token
|
|
address, track, err := m.walletAuth.ValidateJWT(token)
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
|
|
return address, track, nil
|
|
}
|
|
|
|
// writeUnauthorized responds with HTTP 401 and a fixed JSON error body with
// code "unauthorized". A failed body write is deliberately ignored, because
// the status code has already been committed by then.
func writeUnauthorized(w http.ResponseWriter) {
	const body = `{"error":{"code":"unauthorized","message":"Authentication required"}}`
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusUnauthorized)
	_, _ = w.Write([]byte(body))
}
|
|
|
|
// writeForbidden responds with HTTP 403 and a JSON error body that reports
// the track the caller would have needed. For example:
// {"error":{"code":"forbidden","message":"Insufficient permissions","required_track":2}}.
// A failed body write is deliberately ignored, because the status code has
// already been committed by then.
func writeForbidden(w http.ResponseWriter, requiredTrack int) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusForbidden)
	_, _ = fmt.Fprintf(w,
		`{"error":{"code":"forbidden","message":"Insufficient permissions","required_track":%d}}`,
		requiredTrack)
}
|
|
|