Files
routedns/rate-limiter.go
Ali e51f51e1bc move from logrus to slog (#422)
* Migrate from logrus to slog

* fully removing logrus

* should be working now

* Update pipeline.go

Co-authored-by: Frank Olbricht <frank.olbricht@gmail.com>

* Update response-blocklist-name.go

Co-authored-by: Frank Olbricht <frank.olbricht@gmail.com>

* added null logger

* Update pipeline.go

---------

Co-authored-by: Frank Olbricht <frank.olbricht@gmail.com>
2025-01-13 08:43:30 +01:00

124 lines
3.0 KiB
Go

package rdns
import (
"expvar"
"net"
"sync"
"time"
"github.com/miekg/dns"
)
// RateLimiter is a resolver that limits the number of queries by a client (network)
// that are passed to the upstream resolver per timeframe.
type RateLimiter struct {
	// id identifies this limiter in log entries and metric names.
	id string
	// resolver receives the queries that are within the rate limit.
	resolver Resolver
	RateLimiterOptions
	// mu guards currWinID and counters.
	mu sync.RWMutex
	// currWinID is the index of the fixed window the counters belong to
	// (Unix time in seconds divided by Window).
	currWinID int64
	// counters holds the number of queries seen in the current window, keyed
	// by the string form of the masked client source address.
	counters map[string]*uint
	metrics  *RateLimiterMetrics
}
// Compile-time assertion that *RateLimiter satisfies the Resolver interface.
var _ Resolver = (*RateLimiter)(nil)
// RateLimiterOptions configures a RateLimiter.
type RateLimiterOptions struct {
	// Number of requests allowed per time period. With 0, every query is
	// treated as exceeding the limit.
	Requests uint
	// Time period in seconds. NewRateLimiter replaces 0 with 60.
	Window uint
	// Netmask (prefix length) to identify IP4 clients. NewRateLimiter
	// replaces 0 with 24.
	Prefix4 uint8
	// Netmask (prefix length) to identify IP6 clients. NewRateLimiter
	// replaces 0 with 56.
	Prefix6 uint8
	// Alternate resolver for rate-limited requests. If nil, rate-limited
	// requests are dropped.
	LimitResolver Resolver
}
// RateLimiterMetrics holds the expvar counters updated by a RateLimiter.
type RateLimiterMetrics struct {
	// Count of queries. Incremented for every query received.
	query *expvar.Int
	// Count of queries that have exceeded the rate limit, whether they were
	// then forwarded to the limit resolver or dropped.
	exceed *expvar.Int
	// Count of dropped queries, i.e. queries that exceeded the rate limit
	// while no limit resolver was configured.
	drop *expvar.Int
}
// NewRateLimiter returns a new instance of a query rate limiter that passes
// queries to resolver while the sending client network stays within
// opt.Requests queries per opt.Window seconds.
//
// Zero values in opt are replaced with defaults: Window 60 seconds, Prefix4
// 24 bits and Prefix6 56 bits. Prefix lengths longer than the address size
// (32 bits for IPv4, 128 bits for IPv6) are capped at the address size.
// net.CIDRMask yields no mask for such lengths, and every client would
// otherwise collapse into a single shared counter.
func NewRateLimiter(id string, resolver Resolver, opt RateLimiterOptions) *RateLimiter {
	if opt.Window == 0 {
		opt.Window = 60
	}
	if opt.Prefix4 == 0 {
		opt.Prefix4 = 24
	}
	if opt.Prefix6 == 0 {
		opt.Prefix6 = 56
	}
	if opt.Prefix4 > 32 {
		opt.Prefix4 = 32
	}
	if opt.Prefix6 > 128 {
		opt.Prefix6 = 128
	}
	return &RateLimiter{
		id:                 id,
		resolver:           resolver,
		RateLimiterOptions: opt,
		metrics: &RateLimiterMetrics{
			query:  getVarInt("router", id, "query"),
			exceed: getVarInt("router", id, "exceed"),
			drop:   getVarInt("router", id, "drop"),
		},
	}
}
// Resolve a DNS query while limiting the query rate per time period.
//
// Clients are grouped into networks by masking the source IP to Prefix4
// (IPv4) or Prefix6 (IPv6) bits. Each network may send up to Requests queries
// per fixed window of Window seconds, and those are forwarded to the upstream
// resolver. Queries beyond that are counted as exceeding the limit and are
// sent to LimitResolver if one is set. Without a LimitResolver the query is
// dropped and Resolve returns a nil message and a nil error.
func (r *RateLimiter) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) {
	log := logger(r.id, q, ci)
	r.metrics.query.Add(1)

	// Apply the desired mask to the client IP to build a key that identifies
	// the client (network). To4 is non-nil for IPv4 addresses, including
	// IPv4-mapped IPv6 ones, so those use the IPv4 prefix.
	source := ci.SourceIP
	if source.To4() != nil {
		source = source.Mask(net.CIDRMask(int(r.Prefix4), 32))
	} else {
		source = source.Mask(net.CIDRMask(int(r.Prefix6), 128))
	}
	key := source.String()

	// Calculate the current (fixed) window
	windowID := time.Now().Unix() / int64(r.Window)

	r.mu.Lock()
	// If we have moved on to the next window, re-initialize the counters.
	// The nil check also covers a first query whose windowID equals the zero
	// value of currWinID (a Window longer than the time since the Unix epoch,
	// or a clock set near the epoch). Without it, the write below would hit a
	// nil map and panic.
	if windowID != r.currWinID || r.counters == nil {
		r.currWinID = windowID
		r.counters = make(map[string]*uint)
	}
	// Load the current counter for this client or make a new one
	v, ok := r.counters[key]
	if !ok {
		v = new(uint)
		r.counters[key] = v
	}
	// Reject once this client has used up its allowance for the window.
	reject := *v >= r.Requests
	*v++
	r.mu.Unlock()

	if reject {
		r.metrics.exceed.Add(1)
		if r.LimitResolver != nil {
			log.With("resolver", r.LimitResolver).Debug("rate-limit exceeded, forwarding to limit-resolver")
			return r.LimitResolver.Resolve(q, ci)
		}
		r.metrics.drop.Add(1)
		log.Debug("rate-limit reached, dropping")
		return nil, nil
	}
	log.With("resolver", r.resolver).Debug("forwarding query to resolver")
	return r.resolver.Resolve(q, ci)
}
// String returns the identifier this rate limiter was created with. The same
// identifier labels its log entries and metrics.
func (r *RateLimiter) String() (name string) {
	name = r.id
	return name
}