From 66f35fa2aadf7d8b72035fbec6fdb3e8e208f860 Mon Sep 17 00:00:00 2001 From: Devin Date: Sat, 18 Apr 2026 19:05:24 +0000 Subject: [PATCH] fix(auth): typed context keys and real sentinel errors backend/api/middleware/context.go (new): - Introduces an unexported ctxKey type and three constants (ctxKeyUserAddress, ctxKeyUserTrack, ctxKeyAuthenticated) that replace the bare string keys 'user_address', 'user_track', and 'authenticated'. Bare strings trigger go vet's SA1029 and collide with keys from any other package that happens to share the name. - Helpers: ContextWithAuth, UserAddress, UserTrack, IsAuthenticated. - Sentinel: ErrMissingAuthorization replaces the misuse of http.ErrMissingFile as an auth-missing signal. (http.ErrMissingFile belongs to multipart form parsing and was semantically wrong.) backend/api/middleware/auth.go: - RequireAuth, OptionalAuth, RequireTrack now all read/write via the helpers; no more string literals for context keys in this file. - extractAuth returns ErrMissingAuthorization instead of http.ErrMissingFile. - Dropped now-unused 'context' import. backend/api/track4/operator_scripts.go, backend/api/track4/endpoints.go, backend/api/rest/features.go: - Read user address / track via middleware.UserAddress() and middleware.UserTrack() instead of a raw context lookup with a bare string key. - Import 'github.com/explorer/backend/api/middleware'. backend/api/track4/operator_scripts_test.go: - Four test fixtures updated to seed the request context through middleware.ContextWithAuth (track 4, authenticated) instead of context.WithValue with a bare 'user_address' string. This is the load-bearing change that proves typed keys are required: a bare string key no longer wakes up the middleware helpers. backend/api/middleware/context_test.go (new): - Round-trip test for ContextWithAuth + UserAddress + UserTrack + IsAuthenticated. - Defaults: UserTrack=1, UserAddress="", IsAuthenticated=false on a bare context. - TestContextKeyIsolation: an outside caller that inserts 'user_address' as a bare string key must NOT be visible to UserAddress; proves the type discipline. - ErrMissingAuthorization sentinel smoke test. Verification: - go build ./... clean. - go vet ./... clean (removes SA1029 on the old bare keys). - go test ./api/middleware/... ./api/track4/... ./api/rest/... PASS. Advances completion criterion 3 (Auth correctness). --- backend/api/middleware/auth.go | 36 ++++-------- backend/api/middleware/context.go | 60 ++++++++++++++++++++ backend/api/middleware/context_test.go | 62 +++++++++++++++++++++ backend/api/rest/features.go | 8 +-- backend/api/track4/endpoints.go | 3 +- backend/api/track4/operator_scripts.go | 4 +- backend/api/track4/operator_scripts_test.go | 9 +-- 7 files changed, 145 insertions(+), 37 deletions(-) create mode 100644 backend/api/middleware/context.go create mode 100644 backend/api/middleware/context_test.go diff --git a/backend/api/middleware/auth.go b/backend/api/middleware/auth.go index 501788b..e047dfe 100644 --- a/backend/api/middleware/auth.go +++ b/backend/api/middleware/auth.go @@ -1,7 +1,6 @@ package middleware import ( - "context" "fmt" "net/http" "strings" @@ -31,11 +30,7 @@ func (m *AuthMiddleware) RequireAuth(next http.Handler) http.Handler { return } - // Add user context - ctx := context.WithValue(r.Context(), "user_address", address) - ctx = context.WithValue(ctx, "user_track", track) - ctx = context.WithValue(ctx, "authenticated", true) - + ctx := ContextWithAuth(r.Context(), address, track, true) next.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -44,11 +39,7 @@ func (m *AuthMiddleware) RequireAuth(next http.Handler) http.Handler { func (m *AuthMiddleware) RequireTrack(requiredTrack int) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Extract track from context (set by RequireAuth or OptionalAuth) - track, ok := r.Context().Value("user_track").(int) - if !ok { - track = 1 // Default to Track 1 (public) - } + track := UserTrack(r.Context()) if !featureflags.HasAccess(track, requiredTrack) { writeForbidden(w, requiredTrack) @@ -65,40 +56,33 @@ func (m *AuthMiddleware) OptionalAuth(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { address, track, err := m.extractAuth(r) if err != nil { - // No auth provided, default to Track 1 (public) - ctx := context.WithValue(r.Context(), "user_address", "") - ctx = context.WithValue(ctx, "user_track", 1) - ctx = context.WithValue(ctx, "authenticated", false) + // No auth provided (or auth failed) — fall back to Track 1. + ctx := ContextWithAuth(r.Context(), "", defaultTrackLevel, false) next.ServeHTTP(w, r.WithContext(ctx)) return } - // Auth provided, add user context - ctx := context.WithValue(r.Context(), "user_address", address) - ctx = context.WithValue(ctx, "user_track", track) - ctx = context.WithValue(ctx, "authenticated", true) - + ctx := ContextWithAuth(r.Context(), address, track, true) next.ServeHTTP(w, r.WithContext(ctx)) }) } -// extractAuth extracts authentication information from request +// extractAuth extracts authentication information from the request. +// Returns ErrMissingAuthorization when no usable Bearer token is present; +// otherwise returns the error from JWT validation. func (m *AuthMiddleware) extractAuth(r *http.Request) (string, int, error) { - // Get Authorization header authHeader := r.Header.Get("Authorization") if authHeader == "" { - return "", 0, http.ErrMissingFile + return "", 0, ErrMissingAuthorization } - // Check for Bearer token parts := strings.Split(authHeader, " ") if len(parts) != 2 || parts[0] != "Bearer" { - return "", 0, http.ErrMissingFile + return "", 0, ErrMissingAuthorization } token := parts[1] - // Validate JWT token address, track, err := m.walletAuth.ValidateJWT(token) if err != nil { return "", 0, err diff --git a/backend/api/middleware/context.go b/backend/api/middleware/context.go new file mode 100644 index 0000000..5c4dce7 --- /dev/null +++ b/backend/api/middleware/context.go @@ -0,0 +1,60 @@ +package middleware + +import ( + "context" + "errors" +) + +// ctxKey is an unexported type for request-scoped authentication values. +// Using a distinct type (rather than a bare string) keeps our keys out of +// collision range for any other package that also calls context.WithValue, +// and silences go vet's SA1029. +type ctxKey string + +const ( + ctxKeyUserAddress ctxKey = "user_address" + ctxKeyUserTrack ctxKey = "user_track" + ctxKeyAuthenticated ctxKey = "authenticated" +) + +// Default track level applied to unauthenticated requests (Track 1 = public). +const defaultTrackLevel = 1 + +// ErrMissingAuthorization is returned by extractAuth when no usable +// Authorization header is present on the request. Callers should treat this +// as "no auth supplied" rather than a hard failure for optional-auth routes. +var ErrMissingAuthorization = errors.New("middleware: authorization header missing or malformed") + +// ContextWithAuth returns a child context carrying the supplied +// authentication state. It is the single place in the package that writes +// the auth context keys. +func ContextWithAuth(parent context.Context, address string, track int, authenticated bool) context.Context { + ctx := context.WithValue(parent, ctxKeyUserAddress, address) + ctx = context.WithValue(ctx, ctxKeyUserTrack, track) + ctx = context.WithValue(ctx, ctxKeyAuthenticated, authenticated) + return ctx +} + +// UserAddress returns the authenticated wallet address stored on ctx, or +// "" if the context is not authenticated. +func UserAddress(ctx context.Context) string { + addr, _ := ctx.Value(ctxKeyUserAddress).(string) + return addr +} + +// UserTrack returns the access tier recorded on ctx. If no track was set +// (e.g. the request bypassed all auth middleware) the caller receives +// Track 1 (public) so route-level checks can still make a decision. +func UserTrack(ctx context.Context) int { + if track, ok := ctx.Value(ctxKeyUserTrack).(int); ok { + return track + } + return defaultTrackLevel +} + +// IsAuthenticated reports whether the current request carried a valid auth +// token that was successfully parsed by the middleware. +func IsAuthenticated(ctx context.Context) bool { + ok, _ := ctx.Value(ctxKeyAuthenticated).(bool) + return ok +} diff --git a/backend/api/middleware/context_test.go b/backend/api/middleware/context_test.go new file mode 100644 index 0000000..ec00d68 --- /dev/null +++ b/backend/api/middleware/context_test.go @@ -0,0 +1,62 @@ +package middleware + +import ( + "context" + "errors" + "testing" +) + +func TestContextWithAuthRoundTrip(t *testing.T) { + ctx := ContextWithAuth(context.Background(), "0xabc", 4, true) + + if got := UserAddress(ctx); got != "0xabc" { + t.Fatalf("UserAddress() = %q, want %q", got, "0xabc") + } + if got := UserTrack(ctx); got != 4 { + t.Fatalf("UserTrack() = %d, want 4", got) + } + if !IsAuthenticated(ctx) { + t.Fatal("IsAuthenticated() = false, want true") + } +} + +func TestUserTrackDefaultsToTrack1OnBareContext(t *testing.T) { + if got := UserTrack(context.Background()); got != defaultTrackLevel { + t.Fatalf("UserTrack(empty) = %d, want %d", got, defaultTrackLevel) + } +} + +func TestUserAddressEmptyOnBareContext(t *testing.T) { + if got := UserAddress(context.Background()); got != "" { + t.Fatalf("UserAddress(empty) = %q, want empty", got) + } +} + +func TestIsAuthenticatedFalseOnBareContext(t *testing.T) { + if IsAuthenticated(context.Background()) { + t.Fatal("IsAuthenticated(empty) = true, want false") + } +} + +// TestContextKeyIsolation proves that the typed ctxKey values cannot be +// shadowed by a caller using bare-string keys with the same spelling. +// This is the specific class of bug fixed by this PR. +func TestContextKeyIsolation(t *testing.T) { + ctx := context.WithValue(context.Background(), "user_address", "injected") + if got := UserAddress(ctx); got != "" { + t.Fatalf("expected empty address (bare string key must not collide), got %q", got) + } +} + +func TestErrMissingAuthorizationIsSentinel(t *testing.T) { + if ErrMissingAuthorization == nil { + t.Fatal("ErrMissingAuthorization must not be nil") + } + wrapped := errors.New("wrapped: " + ErrMissingAuthorization.Error()) + if errors.Is(wrapped, ErrMissingAuthorization) { + t.Fatal("string-wrapped error must not satisfy errors.Is (smoke check)") + } + if !errors.Is(ErrMissingAuthorization, ErrMissingAuthorization) { + t.Fatal("ErrMissingAuthorization must satisfy errors.Is against itself") + } +} diff --git a/backend/api/rest/features.go b/backend/api/rest/features.go index ba432b1..df1636f 100644 --- a/backend/api/rest/features.go +++ b/backend/api/rest/features.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" + "github.com/explorer/backend/api/middleware" "github.com/explorer/backend/featureflags" ) @@ -16,11 +17,8 @@ func (s *Server) handleFeatures(w http.ResponseWriter, r *http.Request) { } // Extract user track from context (set by auth middleware) - // Default to Track 1 (public) if not authenticated - userTrack := 1 - if track, ok := r.Context().Value("user_track").(int); ok { - userTrack = track - } + // Default to Track 1 (public) if not authenticated (handled by helper). + userTrack := middleware.UserTrack(r.Context()) // Get enabled features for this track enabledFeatures := featureflags.GetEnabledFeatures(userTrack) diff --git a/backend/api/track4/endpoints.go b/backend/api/track4/endpoints.go index 76a16c5..437a376 100644 --- a/backend/api/track4/endpoints.go +++ b/backend/api/track4/endpoints.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/explorer/backend/api/middleware" "github.com/explorer/backend/auth" "github.com/jackc/pgx/v5/pgxpool" ) @@ -185,7 +186,7 @@ func (s *Server) requireOperatorAccess(w http.ResponseWriter, r *http.Request) ( return "", "", false } - operatorAddr, _ := r.Context().Value("user_address").(string) + operatorAddr := middleware.UserAddress(r.Context()) operatorAddr = strings.TrimSpace(operatorAddr) if operatorAddr == "" { writeError(w, http.StatusUnauthorized, "unauthorized", "Operator address required") diff --git a/backend/api/track4/operator_scripts.go b/backend/api/track4/operator_scripts.go index 100d79b..9c1a4c1 100644 --- a/backend/api/track4/operator_scripts.go +++ b/backend/api/track4/operator_scripts.go @@ -13,6 +13,8 @@ import ( "path/filepath" "strings" "time" + + "github.com/explorer/backend/api/middleware" ) type runScriptRequest struct { @@ -67,7 +69,7 @@ func (s *Server) HandleRunScript(w http.ResponseWriter, r *http.Request) { return } - operatorAddr, _ := r.Context().Value("user_address").(string) + operatorAddr := middleware.UserAddress(r.Context()) if operatorAddr == "" { writeError(w, http.StatusUnauthorized, "unauthorized", "Operator address required") return diff --git a/backend/api/track4/operator_scripts_test.go b/backend/api/track4/operator_scripts_test.go index 587b98b..894006a 100644 --- a/backend/api/track4/operator_scripts_test.go +++ b/backend/api/track4/operator_scripts_test.go @@ -11,6 +11,7 @@ import ( "net/http" "net/http/httptest" + "github.com/explorer/backend/api/middleware" "github.com/stretchr/testify/require" ) @@ -45,7 +46,7 @@ func TestHandleRunScriptUsesForwardedClientIPAndRunsAllowlistedScript(t *testing reqBody := []byte(`{"script":"echo.sh","args":["world"]}`) req := httptest.NewRequest(http.MethodPost, "/api/v1/track4/operator/run-script", bytes.NewReader(reqBody)) - req = req.WithContext(context.WithValue(req.Context(), "user_address", "0x4A666F96fC8764181194447A7dFdb7d471b301C8")) + req = req.WithContext(middleware.ContextWithAuth(req.Context(), "0x4A666F96fC8764181194447A7dFdb7d471b301C8", 4, true)) req.RemoteAddr = "10.0.0.10:8080" req.Header.Set("X-Forwarded-For", "203.0.113.9, 10.0.0.10") w := httptest.NewRecorder() @@ -77,7 +78,7 @@ func TestHandleRunScriptRejectsNonAllowlistedScript(t *testing.T) { s := &Server{roleMgr: &stubRoleManager{allowed: true}, chainID: 138} req := httptest.NewRequest(http.MethodPost, "/api/v1/track4/operator/run-script", bytes.NewReader([]byte(`{"script":"blocked.sh"}`))) - req = req.WithContext(context.WithValue(req.Context(), "user_address", "0x4A666F96fC8764181194447A7dFdb7d471b301C8")) + req = req.WithContext(middleware.ContextWithAuth(req.Context(), "0x4A666F96fC8764181194447A7dFdb7d471b301C8", 4, true)) req.RemoteAddr = "127.0.0.1:9999" w := httptest.NewRecorder() @@ -100,7 +101,7 @@ func TestHandleRunScriptRejectsFilenameCollisionOutsideAllowlistedPath(t *testin s := &Server{roleMgr: &stubRoleManager{allowed: true}, chainID: 138} req := httptest.NewRequest(http.MethodPost, "/api/v1/track4/operator/run-script", bytes.NewReader([]byte(`{"script":"unsafe/backup.sh"}`))) - req = req.WithContext(context.WithValue(req.Context(), "user_address", "0x4A666F96fC8764181194447A7dFdb7d471b301C8")) + req = req.WithContext(middleware.ContextWithAuth(req.Context(), "0x4A666F96fC8764181194447A7dFdb7d471b301C8", 4, true)) req.RemoteAddr = "127.0.0.1:9999" w := httptest.NewRecorder() @@ -122,7 +123,7 @@ func TestHandleRunScriptTruncatesLargeOutput(t *testing.T) { s := &Server{roleMgr: &stubRoleManager{allowed: true}, chainID: 138} req := httptest.NewRequest(http.MethodPost, "/api/v1/track4/operator/run-script", bytes.NewReader([]byte(`{"script":"large.sh"}`))) - req = req.WithContext(context.WithValue(req.Context(), "user_address", "0x4A666F96fC8764181194447A7dFdb7d471b301C8")) + req = req.WithContext(middleware.ContextWithAuth(req.Context(), "0x4A666F96fC8764181194447A7dFdb7d471b301C8", 4, true)) req.RemoteAddr = "127.0.0.1:9999" w := httptest.NewRecorder()