Fix Round-Robin rotation of cached records (#328)

This commit is contained in:
Frank Olbricht
2023-09-04 14:26:18 +02:00
committed by GitHub
parent 95cd90eb16
commit 4c1e011dbe
2 changed files with 97 additions and 14 deletions

View File

@@ -6,6 +6,8 @@ import (
"math"
"math/rand"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
@@ -327,24 +329,91 @@ func AnswerShuffleRandom(msg *dns.Msg) {
})
}
// Round Robin shuffling requires keeping state as it's operating on copies
// of DNS messages so the number of shift operations needs to be remembered.
type rrShuffleRecord struct {
reads uint64
expiry time.Time
}
var (
rrShuffleState map[lruKey]*rrShuffleRecord
rrShuffleOnce sync.Once
rrShuffleMu sync.RWMutex
)
// Shift the answer A/AAAA record order in an answer by one.
func AnswerShuffleRoundRobin(msg *dns.Msg) {
if len(msg.Answer) < 2 {
return
}
var last dns.RR
var dst int
rrShuffleOnce.Do(func() {
rrShuffleState = make(map[lruKey]*rrShuffleRecord)
// Start a cleanup job
go func() {
for {
time.Sleep(30 * time.Second)
rrShuffleMu.RLock()
// Build a list of expired items
var toRemove []lruKey
for k, v := range rrShuffleState {
now := time.Now()
if now.After(v.expiry) {
toRemove = append(toRemove, k)
}
}
rrShuffleMu.RUnlock()
// Remove the expired items
rrShuffleMu.Lock()
for _, k := range toRemove {
delete(rrShuffleState, k)
}
rrShuffleMu.Unlock()
}
}()
})
// Lookup how often the results were shifted previously
key := lruKeyFromQuery(msg)
rrShuffleMu.RLock()
rec, ok := rrShuffleState[key]
rrShuffleMu.RUnlock()
var shiftBy uint64
if ok {
shiftBy = atomic.AddUint64(&rec.reads, 1)
} else {
ttl, ok := minTTL(msg)
if !ok {
return
}
rec = &rrShuffleRecord{
expiry: time.Now().Add(time.Duration(ttl) * time.Second),
}
rrShuffleMu.Lock()
rrShuffleState[key] = rec
rrShuffleMu.Unlock()
}
// Build a list of pointers to A/AAAA records in the message
var aRecords []*dns.RR
for i, rr := range msg.Answer {
if rr.Header().Rrtype == dns.TypeA || rr.Header().Rrtype == dns.TypeAAAA {
if last == nil {
last = rr
} else {
msg.Answer[dst] = rr
}
dst = i
aRecords = append(aRecords, &msg.Answer[i])
}
}
if last != nil {
msg.Answer[dst] = last
// Rotate the A/AAAA record pointers
shiftBy %= uint64(len(aRecords))
shiftBy++
for i := uint64(0); i < shiftBy; i++ {
last := *aRecords[len(aRecords)-1]
for j := len(aRecords) - 1; j > 0; j-- {
*aRecords[j] = *aRecords[j-1]
}
*aRecords[0] = last
}
}

View File

@@ -134,6 +134,9 @@ func TestCacheHardenBelowNXDOMAIN(t *testing.T) {
func TestRoundRobinShuffle(t *testing.T) {
msg := &dns.Msg{
Question: []dns.Question{
{Name: "example.com."},
},
Answer: []dns.RR{
&dns.CNAME{
Hdr: dns.RR_Header{
@@ -161,17 +164,28 @@ func TestRoundRobinShuffle(t *testing.T) {
},
},
}
// Shift the A records
AnswerShuffleRoundRobin(msg)
// Shift the A records once
msg1 := msg.Copy()
AnswerShuffleRoundRobin(msg1)
require.Equal(t, dns.TypeCNAME, msg.Answer[0].Header().Rrtype)
require.Equal(t, dns.TypeA, msg.Answer[1].Header().Rrtype)
require.Equal(t, dns.TypeA, msg.Answer[2].Header().Rrtype)
a1 := msg.Answer[1].(*dns.A)
a2 := msg.Answer[2].(*dns.A)
a1 := msg1.Answer[1].(*dns.A)
a2 := msg1.Answer[2].(*dns.A)
require.Equal(t, net.IP{0, 0, 0, 2}, a1.A)
require.Equal(t, net.IP{0, 0, 0, 1}, a2.A)
// Shift the A records again
msg2 := msg.Copy()
AnswerShuffleRoundRobin(msg2)
a1 = msg2.Answer[1].(*dns.A)
a2 = msg2.Answer[2].(*dns.A)
require.Equal(t, net.IP{0, 0, 0, 1}, a1.A)
require.Equal(t, net.IP{0, 0, 0, 2}, a2.A)
}
// Truncated responses should not be cached