289 lines
7.6 KiB
Go
289 lines
7.6 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/ethereum/go-ethereum/accounts"
|
|
"github.com/ethereum/go-ethereum/common"
|
|
"github.com/ethereum/go-ethereum/crypto"
|
|
"github.com/golang-jwt/jwt/v4"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
// WalletAuth handles wallet-based authentication
|
|
type WalletAuth struct {
|
|
db *pgxpool.Pool
|
|
jwtSecret []byte
|
|
}
|
|
|
|
// NewWalletAuth creates a new wallet auth handler
|
|
func NewWalletAuth(db *pgxpool.Pool, jwtSecret []byte) *WalletAuth {
|
|
return &WalletAuth{
|
|
db: db,
|
|
jwtSecret: jwtSecret,
|
|
}
|
|
}
|
|
|
|
// NonceRequest represents a nonce request
|
|
type NonceRequest struct {
|
|
Address string `json:"address"`
|
|
}
|
|
|
|
// NonceResponse represents a nonce response
|
|
type NonceResponse struct {
|
|
Nonce string `json:"nonce"`
|
|
ExpiresAt time.Time `json:"expires_at"`
|
|
}
|
|
|
|
// WalletAuthRequest represents a wallet authentication request
|
|
type WalletAuthRequest struct {
|
|
Address string `json:"address"`
|
|
Signature string `json:"signature"`
|
|
Nonce string `json:"nonce"`
|
|
}
|
|
|
|
// WalletAuthResponse represents a wallet authentication response
|
|
type WalletAuthResponse struct {
|
|
Token string `json:"token"`
|
|
ExpiresAt time.Time `json:"expires_at"`
|
|
Track int `json:"track"`
|
|
Permissions []string `json:"permissions"`
|
|
}
|
|
|
|
// GenerateNonce generates a random nonce for wallet authentication
|
|
func (w *WalletAuth) GenerateNonce(ctx context.Context, address string) (*NonceResponse, error) {
|
|
// Validate address format
|
|
if !common.IsHexAddress(address) {
|
|
return nil, fmt.Errorf("invalid address format")
|
|
}
|
|
|
|
// Normalize address to checksum format
|
|
addr := common.HexToAddress(address)
|
|
normalizedAddr := addr.Hex()
|
|
|
|
// Generate random nonce
|
|
nonceBytes := make([]byte, 32)
|
|
if _, err := rand.Read(nonceBytes); err != nil {
|
|
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
|
}
|
|
nonce := hex.EncodeToString(nonceBytes)
|
|
|
|
// Store nonce in database with expiration (5 minutes)
|
|
expiresAt := time.Now().Add(5 * time.Minute)
|
|
query := `
|
|
INSERT INTO wallet_nonces (address, nonce, expires_at)
|
|
VALUES ($1, $2, $3)
|
|
ON CONFLICT (address) DO UPDATE SET
|
|
nonce = EXCLUDED.nonce,
|
|
expires_at = EXCLUDED.expires_at,
|
|
created_at = NOW()
|
|
`
|
|
_, err := w.db.Exec(ctx, query, normalizedAddr, nonce, expiresAt)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to store nonce: %w", err)
|
|
}
|
|
|
|
return &NonceResponse{
|
|
Nonce: nonce,
|
|
ExpiresAt: expiresAt,
|
|
}, nil
|
|
}
|
|
|
|
// AuthenticateWallet authenticates a wallet using signature
|
|
func (w *WalletAuth) AuthenticateWallet(ctx context.Context, req *WalletAuthRequest) (*WalletAuthResponse, error) {
|
|
// Validate address format
|
|
if !common.IsHexAddress(req.Address) {
|
|
return nil, fmt.Errorf("invalid address format")
|
|
}
|
|
|
|
// Normalize address
|
|
addr := common.HexToAddress(req.Address)
|
|
normalizedAddr := addr.Hex()
|
|
|
|
// Verify nonce
|
|
var storedNonce string
|
|
var expiresAt time.Time
|
|
query := `SELECT nonce, expires_at FROM wallet_nonces WHERE address = $1`
|
|
err := w.db.QueryRow(ctx, query, normalizedAddr).Scan(&storedNonce, &expiresAt)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("nonce not found or expired")
|
|
}
|
|
|
|
if time.Now().After(expiresAt) {
|
|
return nil, fmt.Errorf("nonce expired")
|
|
}
|
|
|
|
if storedNonce != req.Nonce {
|
|
return nil, fmt.Errorf("invalid nonce")
|
|
}
|
|
|
|
// Verify signature
|
|
message := fmt.Sprintf("Sign this message to authenticate with SolaceScanScout Explorer.\n\nNonce: %s", req.Nonce)
|
|
messageHash := accounts.TextHash([]byte(message))
|
|
|
|
sigBytes, err := hex.DecodeString(req.Signature[2:]) // Remove 0x prefix
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid signature format: %w", err)
|
|
}
|
|
|
|
// Recover public key from signature
|
|
if sigBytes[64] >= 27 {
|
|
sigBytes[64] -= 27
|
|
}
|
|
|
|
pubKey, err := crypto.SigToPub(messageHash, sigBytes)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to recover public key: %w", err)
|
|
}
|
|
|
|
recoveredAddr := crypto.PubkeyToAddress(*pubKey)
|
|
if recoveredAddr.Hex() != normalizedAddr {
|
|
return nil, fmt.Errorf("signature does not match address")
|
|
}
|
|
|
|
// Get or create user and track level
|
|
track, err := w.getUserTrack(ctx, normalizedAddr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get user track: %w", err)
|
|
}
|
|
|
|
// Generate JWT token
|
|
token, expiresAt, err := w.generateJWT(normalizedAddr, track)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate token: %w", err)
|
|
}
|
|
|
|
// Delete used nonce
|
|
w.db.Exec(ctx, `DELETE FROM wallet_nonces WHERE address = $1`, normalizedAddr)
|
|
|
|
// Get permissions for track
|
|
permissions := getPermissionsForTrack(track)
|
|
|
|
return &WalletAuthResponse{
|
|
Token: token,
|
|
ExpiresAt: expiresAt,
|
|
Track: track,
|
|
Permissions: permissions,
|
|
}, nil
|
|
}
|
|
|
|
// getUserTrack gets the track level for a user address
|
|
func (w *WalletAuth) getUserTrack(ctx context.Context, address string) (int, error) {
|
|
// Check if user exists in operator_roles (Track 4)
|
|
var track int
|
|
var approved bool
|
|
query := `SELECT track_level, approved FROM operator_roles WHERE address = $1`
|
|
err := w.db.QueryRow(ctx, query, address).Scan(&track, &approved)
|
|
if err == nil && approved {
|
|
return track, nil
|
|
}
|
|
|
|
// Check if user is approved for Track 2 or 3
|
|
// For now, default to Track 1 (public)
|
|
// In production, you'd have an approval table
|
|
return 1, nil
|
|
}
|
|
|
|
// generateJWT generates a JWT token with track claim
|
|
func (w *WalletAuth) generateJWT(address string, track int) (string, time.Time, error) {
|
|
expiresAt := time.Now().Add(24 * time.Hour)
|
|
|
|
claims := jwt.MapClaims{
|
|
"address": address,
|
|
"track": track,
|
|
"exp": expiresAt.Unix(),
|
|
"iat": time.Now().Unix(),
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
tokenString, err := token.SignedString(w.jwtSecret)
|
|
if err != nil {
|
|
return "", time.Time{}, fmt.Errorf("failed to sign token: %w", err)
|
|
}
|
|
|
|
return tokenString, expiresAt, nil
|
|
}
|
|
|
|
// ValidateJWT validates a JWT token and returns the address and track
|
|
func (w *WalletAuth) ValidateJWT(tokenString string) (string, int, error) {
|
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return w.jwtSecret, nil
|
|
})
|
|
|
|
if err != nil {
|
|
return "", 0, fmt.Errorf("failed to parse token: %w", err)
|
|
}
|
|
|
|
if !token.Valid {
|
|
return "", 0, fmt.Errorf("invalid token")
|
|
}
|
|
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
return "", 0, fmt.Errorf("invalid token claims")
|
|
}
|
|
|
|
address, ok := claims["address"].(string)
|
|
if !ok {
|
|
return "", 0, fmt.Errorf("address not found in token")
|
|
}
|
|
|
|
trackFloat, ok := claims["track"].(float64)
|
|
if !ok {
|
|
return "", 0, fmt.Errorf("track not found in token")
|
|
}
|
|
|
|
track := int(trackFloat)
|
|
return address, track, nil
|
|
}
|
|
|
|
// getPermissionsForTrack lists the permission strings granted at a track level.
//
// Tracks are cumulative. Every caller, at any track value, gets the baseline
// read and WETH permissions. Track 2 and above adds full explorer reads,
// track 3 and above adds analytics, and track 4 and above adds operator
// access. Values above 4 receive the full track-4 set. The result is a fresh
// slice on every call, ordered baseline first and then by tier.
func getPermissionsForTrack(track int) []string {
	granted := []string{
		"explorer.read.blocks",
		"explorer.read.transactions",
		"explorer.read.address.basic",
		"explorer.read.bridge.status",
		"weth.wrap",
		"weth.unwrap",
	}

	// Tiers in ascending order of the minimum track that unlocks them.
	tiers := []struct {
		minTrack int
		perms    []string
	}{
		{2, []string{
			"explorer.read.address.full",
			"explorer.read.tokens",
			"explorer.read.tx_history",
			"explorer.read.internal_txs",
			"explorer.search.enhanced",
		}},
		{3, []string{
			"analytics.read.flows",
			"analytics.read.bridge",
			"analytics.read.token_distribution",
			"analytics.read.address_risk",
		}},
		{4, []string{
			"operator.read.bridge_events",
			"operator.read.validators",
			"operator.read.contracts",
			"operator.read.protocol_state",
			"operator.write.bridge_control",
		}},
	}

	for _, tier := range tiers {
		if track < tier.minTrack {
			break // tiers ascend, so nothing later can apply
		}
		granted = append(granted, tier.perms...)
	}
	return granted
}
|