Files
virtual-banker/backend/safety/rate_limit.go

60 lines
1.3 KiB
Go

package safety
import (
	"context"
	"fmt"
	"math/rand"
	"time"

	"github.com/redis/go-redis/v9"
)
// RateLimiter enforces request limits per key, keeping its bookkeeping in
// Redis. Build one with NewRateLimiter; the zero value has a nil client.
type RateLimiter struct {
	redis *redis.Client // client used for every rate-limit read and write
}
// NewRateLimiter returns a RateLimiter that keeps its state through
// redisClient. The client is stored as-is; nothing is sent to Redis here.
func NewRateLimiter(redisClient *redis.Client) *RateLimiter {
	limiter := new(RateLimiter)
	limiter.redis = redisClient
	return limiter
}
// Check checks if a request is within rate limits
func (r *RateLimiter) Check(ctx context.Context, key string, limit int, window time.Duration) (bool, error) {
// Use sliding window log algorithm
now := time.Now()
windowStart := now.Add(-window)
// Count requests in window
count, err := r.redis.ZCount(ctx, key, fmt.Sprintf("%d", windowStart.Unix()), fmt.Sprintf("%d", now.Unix())).Result()
if err != nil {
return false, err
}
if count >= int64(limit) {
return false, nil
}
// Add current request
_, err = r.redis.ZAdd(ctx, key, redis.Z{
Score: float64(now.Unix()),
Member: fmt.Sprintf("%d", now.UnixNano()),
}).Result()
if err != nil {
return false, err
}
// Expire old entries
r.redis.Expire(ctx, key, window)
return true, nil
}
// CheckUser runs Check against the key "ratelimit:user:<tenantID>:<userID>".
// Each (tenant, user) pair therefore gets its own budget of limit requests
// per window.
//
// NOTE(review): the IDs are joined with ':' and not escaped. IDs that contain
// ':' could map two different (tenant, user) pairs onto the same key. Confirm
// that IDs cannot contain ':'.
func (r *RateLimiter) CheckUser(ctx context.Context, tenantID, userID string, limit int, window time.Duration) (bool, error) {
	userKey := "ratelimit:user:" + tenantID + ":" + userID
	return r.Check(ctx, userKey, limit, window)
}