package paste import ( "log/slog" "net" "net/http" "strings" "sync" "time" "golang.org/x/time/rate" ) type ipLimiter struct { limiter *rate.Limiter lastSeen time.Time } type RateLimiter struct { mu sync.Mutex limiters map[string]*ipLimiter rate rate.Limit burst int } func NewRateLimiter(r rate.Limit, burst int) *RateLimiter { rl := &RateLimiter{ limiters: make(map[string]*ipLimiter), rate: r, burst: burst, } go rl.cleanup() return rl } func (rl *RateLimiter) get(ip string) *rate.Limiter { rl.mu.Lock() defer rl.mu.Unlock() entry, ok := rl.limiters[ip] if !ok { entry = &ipLimiter{limiter: rate.NewLimiter(rl.rate, rl.burst)} rl.limiters[ip] = entry } entry.lastSeen = time.Now() return entry.limiter } func (rl *RateLimiter) cleanup() { for range time.Tick(5 * time.Minute) { rl.mu.Lock() for ip, entry := range rl.limiters { if time.Since(entry.lastSeen) > 10*time.Minute { delete(rl.limiters, ip) } } rl.mu.Unlock() } } func (rl *RateLimiter) Middleware(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ip := extractIP(r) if !rl.get(ip).Allow() { slog.Warn("rate limit exceeded", "ip", ip) writeError(w, http.StatusTooManyRequests, "rate limit exceeded") return } next(w, r) } } func extractIP(r *http.Request) string { if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" { parts := strings.SplitN(fwd, ",", 2) return strings.TrimSpace(parts[0]) } host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return r.RemoteAddr } return host }