add paste functionality
Some checks failed
Build and Deploy Website / build (push) Failing after 2m46s
Some checks failed
Build and Deploy Website / build (push) Failing after 2m46s
This commit is contained in:
82
Backend/paste/ratelimit.go
Normal file
82
Backend/paste/ratelimit.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package paste
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type ipLimiter struct {
|
||||
limiter *rate.Limiter
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
limiters map[string]*ipLimiter
|
||||
rate rate.Limit
|
||||
burst int
|
||||
}
|
||||
|
||||
func NewRateLimiter(r rate.Limit, burst int) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
limiters: make(map[string]*ipLimiter),
|
||||
rate: r,
|
||||
burst: burst,
|
||||
}
|
||||
go rl.cleanup()
|
||||
return rl
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) get(ip string) *rate.Limiter {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
entry, ok := rl.limiters[ip]
|
||||
if !ok {
|
||||
entry = &ipLimiter{limiter: rate.NewLimiter(rl.rate, rl.burst)}
|
||||
rl.limiters[ip] = entry
|
||||
}
|
||||
entry.lastSeen = time.Now()
|
||||
return entry.limiter
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
for range time.Tick(5 * time.Minute) {
|
||||
rl.mu.Lock()
|
||||
for ip, entry := range rl.limiters {
|
||||
if time.Since(entry.lastSeen) > 10*time.Minute {
|
||||
delete(rl.limiters, ip)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) Middleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ip := extractIP(r)
|
||||
if !rl.get(ip).Allow() {
|
||||
slog.Warn("rate limit exceeded", "ip", ip)
|
||||
writeError(w, http.StatusTooManyRequests, "rate limit exceeded")
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func extractIP(r *http.Request) string {
|
||||
if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
|
||||
parts := strings.SplitN(fwd, ",", 2)
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
Reference in New Issue
Block a user