60 lines
1.3 KiB
Go
60 lines
1.3 KiB
Go
package safety
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
// RateLimiter implements rate limiting
|
|
type RateLimiter struct {
|
|
redis *redis.Client
|
|
}
|
|
|
|
// NewRateLimiter creates a new rate limiter
|
|
func NewRateLimiter(redisClient *redis.Client) *RateLimiter {
|
|
return &RateLimiter{
|
|
redis: redisClient,
|
|
}
|
|
}
|
|
|
|
// Check checks if a request is within rate limits
|
|
func (r *RateLimiter) Check(ctx context.Context, key string, limit int, window time.Duration) (bool, error) {
|
|
// Use sliding window log algorithm
|
|
now := time.Now()
|
|
windowStart := now.Add(-window)
|
|
|
|
// Count requests in window
|
|
count, err := r.redis.ZCount(ctx, key, fmt.Sprintf("%d", windowStart.Unix()), fmt.Sprintf("%d", now.Unix())).Result()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
if count >= int64(limit) {
|
|
return false, nil
|
|
}
|
|
|
|
// Add current request
|
|
_, err = r.redis.ZAdd(ctx, key, redis.Z{
|
|
Score: float64(now.Unix()),
|
|
Member: fmt.Sprintf("%d", now.UnixNano()),
|
|
}).Result()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
// Expire old entries
|
|
r.redis.Expire(ctx, key, window)
|
|
|
|
return true, nil
|
|
}
|
|
|
|
// CheckUser checks rate limit for a user
|
|
func (r *RateLimiter) CheckUser(ctx context.Context, tenantID, userID string, limit int, window time.Duration) (bool, error) {
|
|
key := fmt.Sprintf("ratelimit:user:%s:%s", tenantID, userID)
|
|
return r.Check(ctx, key, limit, window)
|
|
}
|
|
|