mirror of
https://github.com/folbricht/routedns.git
synced 2025-12-17 15:45:21 -06:00
* Migrate from logrus to slog * fully removing logrus * should be working now * Update pipeline.go Co-authored-by: Frank Olbricht <frank.olbricht@gmail.com> * Update response-blocklist-name.go Co-authored-by: Frank Olbricht <frank.olbricht@gmail.com> * added null logger * Update pipeline.go --------- Co-authored-by: Frank Olbricht <frank.olbricht@gmail.com>
124 lines
3.0 KiB
Go
124 lines
3.0 KiB
Go
package rdns
|
|
|
|
import (
|
|
"expvar"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
// RateLimiter is a resolver that limits the number of queries by a client (network)
// that are passed to the upstream resolver per timeframe.
//
// A client is identified by its source address masked to Prefix4 (IPv4) or
// Prefix6 (IPv6) bits, and its queries are counted in fixed windows of Window
// seconds. Queries beyond Requests in a window are sent to LimitResolver, or
// dropped when LimitResolver is nil.
type RateLimiter struct {
	// id names this instance. It is passed to the logger, used in the metric
	// names and returned by String.
	id string
	// resolver receives the queries that are within the rate limit.
	resolver Resolver
	// RateLimiterOptions is embedded, so the options are read directly as
	// r.Requests, r.Window, r.Prefix4, r.Prefix6 and r.LimitResolver.
	RateLimiterOptions

	// mu guards currWinID and counters. Resolve holds the write lock for the
	// whole load-compare-increment of a client's counter.
	mu sync.RWMutex
	// currWinID is the index of the window the counters belong to: the Unix
	// time in seconds divided by Window.
	currWinID int64
	// counters maps the string form of a masked client address to the number
	// of queries seen from it in the current window. The map is replaced
	// (not cleared) whenever a new window starts.
	counters map[string]*uint
	// metrics holds the expvar counters of this instance.
	metrics *RateLimiterMetrics
}
|
|
|
|
// Compile-time assertion that *RateLimiter satisfies the Resolver interface.
var _ Resolver = (*RateLimiter)(nil)
|
|
|
|
// RateLimiterOptions configures a RateLimiter. NewRateLimiter replaces zero
// values of Window, Prefix4 and Prefix6 with defaults.
type RateLimiterOptions struct {
	// Requests is the number of requests allowed per client (network) and time
	// period. A query is rejected once the client's count in the current
	// period has reached this value, so 0 rejects every query.
	Requests uint

	// Window is the length of the fixed time period in seconds. Periods are
	// aligned to multiples of Window since the Unix epoch. NewRateLimiter
	// uses 60 when this is 0.
	Window uint

	// Prefix4 is the netmask length used to identify IPv4 clients.
	// NewRateLimiter uses 24 when this is 0.
	// NOTE(review): values above 32 are not validated here. net.CIDRMask
	// returns nil for them, which collapses all such clients onto a single
	// key. Confirm that the config layer validates the range.
	Prefix4 uint8

	// Prefix6 is the netmask length used to identify IPv6 clients.
	// NewRateLimiter uses 56 when this is 0.
	// NOTE(review): as with Prefix4, values above 128 are not validated here.
	Prefix6 uint8

	// LimitResolver is an alternate resolver for rate-limited requests. When
	// it is nil, rate-limited requests are dropped instead.
	LimitResolver Resolver
}
|
|
|
|
// RateLimiterMetrics holds the expvar counters maintained by a RateLimiter.
type RateLimiterMetrics struct {
	// Count of queries. Incremented once for every query passed to Resolve.
	query *expvar.Int
	// Count of queries that have exceeded the rate limit, whether they were
	// then forwarded to the limit resolver or dropped.
	exceed *expvar.Int
	// Count of dropped queries: queries over the limit while no limit
	// resolver is configured.
	drop *expvar.Int
}
|
|
|
|
// NewRateLimiterIP returns a new instance of a query rate limiter.
|
|
func NewRateLimiter(id string, resolver Resolver, opt RateLimiterOptions) *RateLimiter {
|
|
if opt.Window == 0 {
|
|
opt.Window = 60
|
|
}
|
|
if opt.Prefix4 == 0 {
|
|
opt.Prefix4 = 24
|
|
}
|
|
if opt.Prefix6 == 0 {
|
|
opt.Prefix6 = 56
|
|
}
|
|
return &RateLimiter{
|
|
id: id,
|
|
resolver: resolver,
|
|
RateLimiterOptions: opt,
|
|
metrics: &RateLimiterMetrics{
|
|
query: getVarInt("router", id, "query"),
|
|
exceed: getVarInt("router", id, "exceed"),
|
|
drop: getVarInt("router", id, "drop"),
|
|
},
|
|
}
|
|
}
|
|
|
|
// Resolve a DNS query while limiting the query rate per time period.
|
|
func (r *RateLimiter) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) {
|
|
log := logger(r.id, q, ci)
|
|
r.metrics.query.Add(1)
|
|
|
|
// Apply the desired mask to the client IP to build a key it identify the client (network)
|
|
source := ci.SourceIP
|
|
if ip4 := source.To4(); len(ip4) == net.IPv4len {
|
|
source = source.Mask(net.CIDRMask(int(r.Prefix4), 32))
|
|
} else {
|
|
source = source.Mask(net.CIDRMask(int(r.Prefix6), 128))
|
|
}
|
|
key := source.String()
|
|
|
|
// Calculate the current (fixed) window
|
|
windowID := time.Now().Unix() / int64(r.Window)
|
|
|
|
var reject bool
|
|
r.mu.Lock()
|
|
|
|
// If we have moved on to the next window, re-initialize the counters
|
|
if windowID != r.currWinID {
|
|
r.currWinID = windowID
|
|
r.counters = make(map[string]*uint)
|
|
}
|
|
|
|
// Load the current counter for this client or make a new one
|
|
v, ok := r.counters[key]
|
|
if !ok {
|
|
v = new(uint)
|
|
r.counters[key] = v
|
|
}
|
|
|
|
// Check the number of requests made in this window
|
|
if *v >= r.Requests {
|
|
reject = true
|
|
}
|
|
*v++
|
|
r.mu.Unlock()
|
|
|
|
if reject {
|
|
r.metrics.exceed.Add(1)
|
|
if r.LimitResolver != nil {
|
|
log.With("resolver", r.LimitResolver).Debug("rate-limit exceeded, forwarding to limit-resolver")
|
|
return r.LimitResolver.Resolve(q, ci)
|
|
}
|
|
r.metrics.drop.Add(1)
|
|
log.Debug("rate-limit reached, dropping")
|
|
return nil, nil
|
|
}
|
|
log.With("resolver", r.resolver).Debug("forwarding query to resolver")
|
|
return r.resolver.Resolve(q, ci)
|
|
}
|
|
|
|
func (r *RateLimiter) String() string {
|
|
return r.id
|
|
}
|