package middleware

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"strings"

	"github.com/explorer/backend/auth"
	"github.com/explorer/backend/featureflags"
)

// Request-context keys populated by RequireAuth and OptionalAuth and read by
// RequireTrack.
//
// NOTE(review): these are plain strings rather than a private key type (the
// context.WithValue docs recommend a private type, and staticcheck SA1029
// flags built-in key types). They are kept as strings because handlers
// elsewhere may look the values up with the literal strings. Switching to a
// private type would silently break those lookups.
const (
	ctxKeyUserAddress   = "user_address"
	ctxKeyUserTrack     = "user_track"
	ctxKeyAuthenticated = "authenticated"
)

// publicTrack is the track assigned to unauthenticated callers (Track 1, public).
const publicTrack = 1

var (
	// errMissingAuthHeader reports a request without an Authorization header.
	errMissingAuthHeader = errors.New("missing Authorization header")
	// errMalformedAuthHeader reports an Authorization header that is not "Bearer <token>".
	errMalformedAuthHeader = errors.New(`malformed Authorization header: expected "Bearer <token>"`)
	// errNoWalletAuth reports a middleware built without a token validator.
	errNoWalletAuth = errors.New("auth middleware has no wallet authenticator configured")
)

// AuthMiddleware handles authentication and authorization.
type AuthMiddleware struct {
	walletAuth *auth.WalletAuth
}

// NewAuthMiddleware creates a new auth middleware that validates bearer tokens
// with walletAuth.
func NewAuthMiddleware(walletAuth *auth.WalletAuth) *AuthMiddleware {
	return &AuthMiddleware{
		walletAuth: walletAuth,
	}
}

// RequireAuth is middleware that requires authentication.
//
// Requests without a valid Bearer token receive a 401 JSON error. On success,
// the caller's address and track, plus authenticated=true, are stored in the
// request context before next is invoked.
func (m *AuthMiddleware) RequireAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		address, track, err := m.extractAuth(r)
		if err != nil {
			writeUnauthorized(w)
			return
		}
		next.ServeHTTP(w, r.WithContext(withUserContext(r.Context(), address, track, true)))
	})
}

// RequireTrack is middleware that requires a specific track level.
//
// It reads the track stored in the context by RequireAuth or OptionalAuth. If
// no track is present, it treats the caller as Track 1 (public). Callers for
// which featureflags.HasAccess(track, requiredTrack) is false receive a 403
// JSON error.
func (m *AuthMiddleware) RequireTrack(requiredTrack int) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			track, ok := r.Context().Value(ctxKeyUserTrack).(int)
			if !ok {
				// Not wrapped by RequireAuth/OptionalAuth: assume anonymous access.
				track = publicTrack
			}
			if !featureflags.HasAccess(track, requiredTrack) {
				writeForbidden(w, requiredTrack)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// OptionalAuth is middleware that optionally authenticates (for Track 1 endpoints).
//
// If a valid Bearer token is present, the caller's address and track are
// stored in the context with authenticated=true. Otherwise the request
// proceeds anonymously with address "", Track 1 and authenticated=false. This
// fallback applies to any extraction failure, including an invalid or expired
// token, not just a missing header.
func (m *AuthMiddleware) OptionalAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		address, track, err := m.extractAuth(r)
		if err != nil {
			next.ServeHTTP(w, r.WithContext(withUserContext(r.Context(), "", publicTrack, false)))
			return
		}
		next.ServeHTTP(w, r.WithContext(withUserContext(r.Context(), address, track, true)))
	})
}

// withUserContext returns a child of ctx carrying the user address, track and
// authentication flag under the keys the handlers in this file read.
func withUserContext(ctx context.Context, address string, track int, authenticated bool) context.Context {
	ctx = context.WithValue(ctx, ctxKeyUserAddress, address)
	ctx = context.WithValue(ctx, ctxKeyUserTrack, track)
	return context.WithValue(ctx, ctxKeyAuthenticated, authenticated)
}

// extractAuth extracts authentication information from the request.
//
// It expects an `Authorization: Bearer <token>` header and returns the address
// and track reported by walletAuth.ValidateJWT for that token. On any failure
// it returns a non-nil error and zero values.
func (m *AuthMiddleware) extractAuth(r *http.Request) (string, int, error) {
	authHeader := r.Header.Get("Authorization")
	if authHeader == "" {
		return "", 0, errMissingAuthHeader
	}

	// The auth-scheme is case-insensitive (RFC 7235 §2.1), and the scheme may be
	// followed by one or more spaces (RFC 6750 §2.1). Fields also rejects a bare
	// "Bearer" with no token and headers carrying extra parts.
	fields := strings.Fields(authHeader)
	if len(fields) != 2 || !strings.EqualFold(fields[0], "Bearer") {
		return "", 0, errMalformedAuthHeader
	}

	// Guard against a misconfigured middleware instead of panicking mid-request.
	if m == nil || m.walletAuth == nil {
		return "", 0, errNoWalletAuth
	}

	address, track, err := m.walletAuth.ValidateJWT(fields[1])
	if err != nil {
		return "", 0, fmt.Errorf("validating bearer token: %w", err)
	}
	return address, track, nil
}

// writeUnauthorized writes a 401 Unauthorized JSON error response.
func writeUnauthorized(w http.ResponseWriter) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusUnauthorized)
	// The status line is already sent; a failed body write (e.g. client gone)
	// cannot be reported to the caller, so it is deliberately ignored.
	_, _ = w.Write([]byte(`{"error":{"code":"unauthorized","message":"Authentication required"}}`))
}

// writeForbidden writes a 403 Forbidden JSON error response that includes the
// track level the caller would have needed.
func writeForbidden(w http.ResponseWriter, requiredTrack int) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusForbidden)
	// Body write errors are ignored for the same reason as in writeUnauthorized.
	_, _ = w.Write([]byte(`{"error":{"code":"forbidden","message":"Insufficient permissions","required_track":` + fmt.Sprintf("%d", requiredTrack) + `}}`))
}