mirror of
https://github.com/folbricht/routedns.git
synced 2025-12-21 09:29:56 -06:00
Fix Round-Robin rotation of cached records (#328)
This commit is contained in:
89
cache.go
89
cache.go
@@ -6,6 +6,8 @@ import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -327,24 +329,91 @@ func AnswerShuffleRandom(msg *dns.Msg) {
|
||||
})
|
||||
}
|
||||
|
||||
// Round Robin shuffling requires keeping state as it's operating on copies
|
||||
// of DNS messages so the number of shift operations needs to be remembered.
|
||||
type rrShuffleRecord struct {
|
||||
reads uint64
|
||||
expiry time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
rrShuffleState map[lruKey]*rrShuffleRecord
|
||||
rrShuffleOnce sync.Once
|
||||
rrShuffleMu sync.RWMutex
|
||||
)
|
||||
|
||||
// Shift the answer A/AAAA record order in an answer by one.
|
||||
func AnswerShuffleRoundRobin(msg *dns.Msg) {
|
||||
if len(msg.Answer) < 2 {
|
||||
return
|
||||
}
|
||||
var last dns.RR
|
||||
var dst int
|
||||
rrShuffleOnce.Do(func() {
|
||||
rrShuffleState = make(map[lruKey]*rrShuffleRecord)
|
||||
|
||||
// Start a cleanup job
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(30 * time.Second)
|
||||
rrShuffleMu.RLock()
|
||||
|
||||
// Build a list of expired items
|
||||
var toRemove []lruKey
|
||||
for k, v := range rrShuffleState {
|
||||
now := time.Now()
|
||||
if now.After(v.expiry) {
|
||||
toRemove = append(toRemove, k)
|
||||
}
|
||||
}
|
||||
rrShuffleMu.RUnlock()
|
||||
|
||||
// Remove the expired items
|
||||
rrShuffleMu.Lock()
|
||||
for _, k := range toRemove {
|
||||
delete(rrShuffleState, k)
|
||||
}
|
||||
rrShuffleMu.Unlock()
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
// Lookup how often the results were shifted previously
|
||||
key := lruKeyFromQuery(msg)
|
||||
rrShuffleMu.RLock()
|
||||
rec, ok := rrShuffleState[key]
|
||||
rrShuffleMu.RUnlock()
|
||||
var shiftBy uint64
|
||||
if ok {
|
||||
shiftBy = atomic.AddUint64(&rec.reads, 1)
|
||||
} else {
|
||||
ttl, ok := minTTL(msg)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
rec = &rrShuffleRecord{
|
||||
expiry: time.Now().Add(time.Duration(ttl) * time.Second),
|
||||
}
|
||||
rrShuffleMu.Lock()
|
||||
rrShuffleState[key] = rec
|
||||
rrShuffleMu.Unlock()
|
||||
}
|
||||
|
||||
// Build a list of pointers to A/AAAA records in the message
|
||||
var aRecords []*dns.RR
|
||||
for i, rr := range msg.Answer {
|
||||
if rr.Header().Rrtype == dns.TypeA || rr.Header().Rrtype == dns.TypeAAAA {
|
||||
if last == nil {
|
||||
last = rr
|
||||
} else {
|
||||
msg.Answer[dst] = rr
|
||||
}
|
||||
dst = i
|
||||
aRecords = append(aRecords, &msg.Answer[i])
|
||||
}
|
||||
}
|
||||
if last != nil {
|
||||
msg.Answer[dst] = last
|
||||
|
||||
// Rotate the A/AAAA record pointers
|
||||
shiftBy %= uint64(len(aRecords))
|
||||
shiftBy++
|
||||
|
||||
for i := uint64(0); i < shiftBy; i++ {
|
||||
last := *aRecords[len(aRecords)-1]
|
||||
for j := len(aRecords) - 1; j > 0; j-- {
|
||||
*aRecords[j] = *aRecords[j-1]
|
||||
}
|
||||
*aRecords[0] = last
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,6 +134,9 @@ func TestCacheHardenBelowNXDOMAIN(t *testing.T) {
|
||||
|
||||
func TestRoundRobinShuffle(t *testing.T) {
|
||||
msg := &dns.Msg{
|
||||
Question: []dns.Question{
|
||||
{Name: "example.com."},
|
||||
},
|
||||
Answer: []dns.RR{
|
||||
&dns.CNAME{
|
||||
Hdr: dns.RR_Header{
|
||||
@@ -161,17 +164,28 @@ func TestRoundRobinShuffle(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
// Shift the A records
|
||||
AnswerShuffleRoundRobin(msg)
|
||||
|
||||
// Shift the A records once
|
||||
msg1 := msg.Copy()
|
||||
AnswerShuffleRoundRobin(msg1)
|
||||
|
||||
require.Equal(t, dns.TypeCNAME, msg.Answer[0].Header().Rrtype)
|
||||
require.Equal(t, dns.TypeA, msg.Answer[1].Header().Rrtype)
|
||||
require.Equal(t, dns.TypeA, msg.Answer[2].Header().Rrtype)
|
||||
|
||||
a1 := msg.Answer[1].(*dns.A)
|
||||
a2 := msg.Answer[2].(*dns.A)
|
||||
a1 := msg1.Answer[1].(*dns.A)
|
||||
a2 := msg1.Answer[2].(*dns.A)
|
||||
require.Equal(t, net.IP{0, 0, 0, 2}, a1.A)
|
||||
require.Equal(t, net.IP{0, 0, 0, 1}, a2.A)
|
||||
|
||||
// Shift the A records again
|
||||
msg2 := msg.Copy()
|
||||
AnswerShuffleRoundRobin(msg2)
|
||||
|
||||
a1 = msg2.Answer[1].(*dns.A)
|
||||
a2 = msg2.Answer[2].(*dns.A)
|
||||
require.Equal(t, net.IP{0, 0, 0, 1}, a1.A)
|
||||
require.Equal(t, net.IP{0, 0, 0, 2}, a2.A)
|
||||
}
|
||||
|
||||
// Truncated responses should not be cached
|
||||
|
||||
Reference in New Issue
Block a user