Files
explorer-monorepo/backend/auth/wallet_auth.go

289 lines
7.6 KiB
Go

package auth
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"time"
"github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/golang-jwt/jwt/v4"
"github.com/jackc/pgx/v5/pgxpool"
)
// WalletAuth handles wallet-based authentication: it issues per-address login
// nonces, verifies signatures over those nonces, and issues and validates the
// resulting JWTs.
type WalletAuth struct {
	// db is the pool used to read/write the wallet_nonces and operator_roles tables.
	db *pgxpool.Pool
	// jwtSecret is the HMAC key used to sign (HS256) and verify tokens.
	jwtSecret []byte
}
// NewWalletAuth builds a WalletAuth that persists nonces through db and signs
// and verifies tokens with jwtSecret. The secret slice is stored as given
// (not copied), so callers should not mutate it afterwards.
func NewWalletAuth(db *pgxpool.Pool, jwtSecret []byte) *WalletAuth {
	wa := new(WalletAuth)
	wa.db = db
	wa.jwtSecret = jwtSecret
	return wa
}
// Wire types for the wallet login flow. Field order and JSON tags define the
// API payloads and must stay stable.
type (
	// NonceRequest is the body of a request for a fresh login nonce.
	NonceRequest struct {
		Address string `json:"address"`
	}

	// NonceResponse carries the issued nonce and the instant it stops being accepted.
	NonceResponse struct {
		Nonce     string    `json:"nonce"`
		ExpiresAt time.Time `json:"expires_at"`
	}

	// WalletAuthRequest is the body of a login attempt: the wallet address,
	// its hex-encoded signature, and the nonce that was signed.
	WalletAuthRequest struct {
		Address   string `json:"address"`
		Signature string `json:"signature"`
		Nonce     string `json:"nonce"`
	}

	// WalletAuthResponse is returned on successful login: the JWT, its expiry,
	// the caller's track level and the permissions that track grants.
	WalletAuthResponse struct {
		Token       string    `json:"token"`
		ExpiresAt   time.Time `json:"expires_at"`
		Track       int       `json:"track"`
		Permissions []string  `json:"permissions"`
	}
)
// GenerateNonce creates a 32-byte random nonce (hex-encoded) for address,
// stores it in wallet_nonces under the checksummed address with a 5-minute
// lifetime, and returns it. Requesting a new nonce replaces any previous one
// for the same address.
//
// It fails if address is not a hex address, if the random source fails, or if
// the database write fails.
func (w *WalletAuth) GenerateNonce(ctx context.Context, address string) (*NonceResponse, error) {
	if !common.IsHexAddress(address) {
		return nil, fmt.Errorf("invalid address format")
	}
	checksummed := common.HexToAddress(address).Hex()

	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return nil, fmt.Errorf("failed to generate nonce: %w", err)
	}

	resp := &NonceResponse{
		Nonce:     hex.EncodeToString(raw[:]),
		ExpiresAt: time.Now().Add(5 * time.Minute),
	}

	// Upsert so that at most one live nonce exists per address.
	const upsert = `
INSERT INTO wallet_nonces (address, nonce, expires_at)
VALUES ($1, $2, $3)
ON CONFLICT (address) DO UPDATE SET
nonce = EXCLUDED.nonce,
expires_at = EXCLUDED.expires_at,
created_at = NOW()
`
	if _, err := w.db.Exec(ctx, upsert, checksummed, resp.Nonce, resp.ExpiresAt); err != nil {
		return nil, fmt.Errorf("failed to store nonce: %w", err)
	}
	return resp, nil
}
// AuthenticateWallet exchanges a signed login nonce for a JWT.
//
// The flow is:
//  1. Look up the nonce GenerateNonce stored for req.Address. Require it to be
//     unexpired and equal to req.Nonce.
//  2. Hash the login message with accounts.TextHash, using the text
//     "Sign this message to authenticate with SolaceScanScout Explorer.\n\nNonce: <nonce>".
//     Recover the signer from req.Signature and require it to equal
//     req.Address. The signature is hex with an optional 0x prefix, 65 bytes
//     R||S||V.
//  3. Consume the nonce with a compare-and-delete, so the same signed message
//     can be redeemed at most once.
//  4. Resolve the address's track, sign a JWT, and return it with the track's
//     permissions.
//
// Once step 3 succeeds the nonce is gone, so any later failure requires the
// client to request a fresh nonce.
func (w *WalletAuth) AuthenticateWallet(ctx context.Context, req *WalletAuthRequest) (*WalletAuthResponse, error) {
	if req == nil {
		return nil, fmt.Errorf("missing authentication request")
	}
	if !common.IsHexAddress(req.Address) {
		return nil, fmt.Errorf("invalid address format")
	}
	// Checksummed form, matching the key GenerateNonce stored under.
	normalizedAddr := common.HexToAddress(req.Address).Hex()

	// 1. Verify the nonce.
	var storedNonce string
	var nonceExpiresAt time.Time
	query := `SELECT nonce, expires_at FROM wallet_nonces WHERE address = $1`
	if err := w.db.QueryRow(ctx, query, normalizedAddr).Scan(&storedNonce, &nonceExpiresAt); err != nil {
		// Deliberately generic: the caller learns only that no usable nonce exists.
		return nil, fmt.Errorf("nonce not found or expired")
	}
	if time.Now().After(nonceExpiresAt) {
		return nil, fmt.Errorf("nonce expired")
	}
	if storedNonce != req.Nonce {
		return nil, fmt.Errorf("invalid nonce")
	}

	// 2. Verify the signature.
	message := fmt.Sprintf("Sign this message to authenticate with SolaceScanScout Explorer.\n\nNonce: %s", req.Nonce)
	messageHash := accounts.TextHash([]byte(message))

	// Accept the signature with or without a 0x/0X prefix. Slicing
	// unconditionally panicked on short input and corrupted unprefixed input.
	sigHex := req.Signature
	if len(sigHex) >= 2 && (sigHex[:2] == "0x" || sigHex[:2] == "0X") {
		sigHex = sigHex[2:]
	}
	sigBytes, err := hex.DecodeString(sigHex)
	if err != nil {
		return nil, fmt.Errorf("invalid signature format: %w", err)
	}
	// Enforce the exact length before touching sigBytes[64]; a shorter slice
	// would otherwise cause an index-out-of-range panic.
	const signatureLength = 65 // R (32) || S (32) || V (1)
	if len(sigBytes) != signatureLength {
		return nil, fmt.Errorf("invalid signature length: got %d bytes, want %d", len(sigBytes), signatureLength)
	}
	// Normalize a 27/28-style V to the 0/1 recovery id crypto.SigToPub expects.
	if sigBytes[64] >= 27 {
		sigBytes[64] -= 27
	}
	pubKey, err := crypto.SigToPub(messageHash, sigBytes)
	if err != nil {
		return nil, fmt.Errorf("failed to recover public key: %w", err)
	}
	if crypto.PubkeyToAddress(*pubKey).Hex() != normalizedAddr {
		return nil, fmt.Errorf("signature does not match address")
	}

	// 3. Consume the nonce before issuing anything.
	// Matching on the nonce value as well as the address makes this a
	// compare-and-delete. If a concurrent request already consumed it, or a
	// newer GenerateNonce replaced it, no row is affected and this attempt is
	// rejected. That closes the replay window the old post-issue,
	// error-ignoring delete left open.
	tag, err := w.db.Exec(ctx, `DELETE FROM wallet_nonces WHERE address = $1 AND nonce = $2`, normalizedAddr, storedNonce)
	if err != nil {
		return nil, fmt.Errorf("failed to consume nonce: %w", err)
	}
	if tag.RowsAffected() == 0 {
		return nil, fmt.Errorf("nonce already used")
	}

	// 4. Resolve the track and issue the token.
	track, err := w.getUserTrack(ctx, normalizedAddr)
	if err != nil {
		return nil, fmt.Errorf("failed to get user track: %w", err)
	}
	token, tokenExpiresAt, err := w.generateJWT(normalizedAddr, track)
	if err != nil {
		return nil, fmt.Errorf("failed to generate token: %w", err)
	}

	return &WalletAuthResponse{
		Token:       token,
		ExpiresAt:   tokenExpiresAt,
		Track:       track,
		Permissions: getPermissionsForTrack(track),
	}, nil
}
// getUserTrack resolves the track level for a checksummed address.
//
// An approved row in operator_roles yields that row's track_level. In every
// other case the address gets the public track 1 with a nil error: no row, an
// unapproved row, or any query/scan error. Tracks 2 and 3 are not yet backed by
// an approval table. The error return is currently always nil.
func (w *WalletAuth) getUserTrack(ctx context.Context, address string) (int, error) {
	const publicTrack = 1

	var (
		level    int
		approved bool
	)
	row := w.db.QueryRow(ctx, `SELECT track_level, approved FROM operator_roles WHERE address = $1`, address)
	if scanErr := row.Scan(&level, &approved); scanErr != nil || !approved {
		return publicTrack, nil
	}
	return level, nil
}
// generateJWT signs an HS256 JWT for address that carries its track level.
//
// The token carries "address", "track", "iat" and "exp" claims and is valid for
// 24 hours. Both timestamps derive from a single instant, so exp-iat is
// exactly the token lifetime. It returns the signed token and its expiry
// time. It refuses to sign when no secret is configured: an empty HMAC key
// would make tokens trivially forgeable.
func (w *WalletAuth) generateJWT(address string, track int) (string, time.Time, error) {
	if len(w.jwtSecret) == 0 {
		return "", time.Time{}, fmt.Errorf("jwt secret is not configured")
	}

	now := time.Now()
	expiresAt := now.Add(24 * time.Hour)
	claims := jwt.MapClaims{
		"address": address,
		"track":   track,
		"exp":     expiresAt.Unix(),
		"iat":     now.Unix(),
	}

	tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(w.jwtSecret)
	if err != nil {
		return "", time.Time{}, fmt.Errorf("failed to sign token: %w", err)
	}
	return tokenString, expiresAt, nil
}
// ValidateJWT verifies tokenString and returns the wallet address and track
// level it carries.
//
// The token must:
//   - be HMAC-signed with w.jwtSecret;
//   - carry an unexpired "exp" claim;
//   - contain a string "address" claim and a numeric "track" claim.
//
// Validation always fails when no secret is configured. Accepting an empty
// HMAC key would let anyone mint valid tokens.
func (w *WalletAuth) ValidateJWT(tokenString string) (string, int, error) {
	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		if len(w.jwtSecret) == 0 {
			return nil, fmt.Errorf("jwt secret is not configured")
		}
		return w.jwtSecret, nil
	})
	if err != nil {
		return "", 0, fmt.Errorf("failed to parse token: %w", err)
	}
	if !token.Valid {
		return "", 0, fmt.Errorf("invalid token")
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return "", 0, fmt.Errorf("invalid token claims")
	}
	// jwt/v4's MapClaims.Valid checks "exp" only when it is present. Every
	// token generateJWT issues has one, so require it here to reject tokens
	// that would otherwise never expire.
	if !claims.VerifyExpiresAt(time.Now().Unix(), true) {
		return "", 0, fmt.Errorf("token missing or expired exp claim")
	}

	address, ok := claims["address"].(string)
	if !ok {
		return "", 0, fmt.Errorf("address not found in token")
	}
	// encoding/json decodes JSON numbers into float64 inside MapClaims.
	trackFloat, ok := claims["track"].(float64)
	if !ok {
		return "", 0, fmt.Errorf("track not found in token")
	}
	return address, int(trackFloat), nil
}
// getPermissionsForTrack lists the permissions granted at a given track level.
//
// Every caller receives the public explorer and WETH permissions, whatever the
// track value. Each higher tier (2, 3, 4) adds its own permissions on top of
// all lower tiers, in ascending tier order. A fresh slice is returned on every
// call, so callers may modify it freely.
func getPermissionsForTrack(track int) []string {
	granted := []string{
		"explorer.read.blocks",
		"explorer.read.transactions",
		"explorer.read.address.basic",
		"explorer.read.bridge.status",
		"weth.wrap",
		"weth.unwrap",
	}

	// Tiers must stay sorted by minTrack: the loop stops at the first tier the
	// caller has not reached.
	tiers := []struct {
		minTrack int
		perms    []string
	}{
		{2, []string{
			"explorer.read.address.full",
			"explorer.read.tokens",
			"explorer.read.tx_history",
			"explorer.read.internal_txs",
			"explorer.search.enhanced",
		}},
		{3, []string{
			"analytics.read.flows",
			"analytics.read.bridge",
			"analytics.read.token_distribution",
			"analytics.read.address_risk",
		}},
		{4, []string{
			"operator.read.bridge_events",
			"operator.read.validators",
			"operator.read.contracts",
			"operator.read.protocol_state",
			"operator.write.bridge_control",
		}},
	}

	for _, tier := range tiers {
		if track < tier.minTrack {
			break
		}
		granted = append(granted, tier.perms...)
	}
	return granted
}