mirror of
https://github.com/folbricht/routedns.git
synced 2025-12-23 10:29:45 -06:00
Fix Round-Robin rotation of cached records (#328)
This commit is contained in:
89
cache.go
89
cache.go
@@ -6,6 +6,8 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@@ -327,24 +329,91 @@ func AnswerShuffleRandom(msg *dns.Msg) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Round Robin shuffling requires keeping state as it's operating on copies
|
||||||
|
// of DNS messages so the number of shift operations needs to be remembered.
|
||||||
|
type rrShuffleRecord struct {
|
||||||
|
reads uint64
|
||||||
|
expiry time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
rrShuffleState map[lruKey]*rrShuffleRecord
|
||||||
|
rrShuffleOnce sync.Once
|
||||||
|
rrShuffleMu sync.RWMutex
|
||||||
|
)
|
||||||
|
|
||||||
// Shift the answer A/AAAA record order in an answer by one.
|
// Shift the answer A/AAAA record order in an answer by one.
|
||||||
func AnswerShuffleRoundRobin(msg *dns.Msg) {
|
func AnswerShuffleRoundRobin(msg *dns.Msg) {
|
||||||
if len(msg.Answer) < 2 {
|
if len(msg.Answer) < 2 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var last dns.RR
|
rrShuffleOnce.Do(func() {
|
||||||
var dst int
|
rrShuffleState = make(map[lruKey]*rrShuffleRecord)
|
||||||
|
|
||||||
|
// Start a cleanup job
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
time.Sleep(30 * time.Second)
|
||||||
|
rrShuffleMu.RLock()
|
||||||
|
|
||||||
|
// Build a list of expired items
|
||||||
|
var toRemove []lruKey
|
||||||
|
for k, v := range rrShuffleState {
|
||||||
|
now := time.Now()
|
||||||
|
if now.After(v.expiry) {
|
||||||
|
toRemove = append(toRemove, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rrShuffleMu.RUnlock()
|
||||||
|
|
||||||
|
// Remove the expired items
|
||||||
|
rrShuffleMu.Lock()
|
||||||
|
for _, k := range toRemove {
|
||||||
|
delete(rrShuffleState, k)
|
||||||
|
}
|
||||||
|
rrShuffleMu.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Lookup how often the results were shifted previously
|
||||||
|
key := lruKeyFromQuery(msg)
|
||||||
|
rrShuffleMu.RLock()
|
||||||
|
rec, ok := rrShuffleState[key]
|
||||||
|
rrShuffleMu.RUnlock()
|
||||||
|
var shiftBy uint64
|
||||||
|
if ok {
|
||||||
|
shiftBy = atomic.AddUint64(&rec.reads, 1)
|
||||||
|
} else {
|
||||||
|
ttl, ok := minTTL(msg)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rec = &rrShuffleRecord{
|
||||||
|
expiry: time.Now().Add(time.Duration(ttl) * time.Second),
|
||||||
|
}
|
||||||
|
rrShuffleMu.Lock()
|
||||||
|
rrShuffleState[key] = rec
|
||||||
|
rrShuffleMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a list of pointers to A/AAAA records in the message
|
||||||
|
var aRecords []*dns.RR
|
||||||
for i, rr := range msg.Answer {
|
for i, rr := range msg.Answer {
|
||||||
if rr.Header().Rrtype == dns.TypeA || rr.Header().Rrtype == dns.TypeAAAA {
|
if rr.Header().Rrtype == dns.TypeA || rr.Header().Rrtype == dns.TypeAAAA {
|
||||||
if last == nil {
|
aRecords = append(aRecords, &msg.Answer[i])
|
||||||
last = rr
|
|
||||||
} else {
|
|
||||||
msg.Answer[dst] = rr
|
|
||||||
}
|
|
||||||
dst = i
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if last != nil {
|
|
||||||
msg.Answer[dst] = last
|
// Rotate the A/AAAA record pointers
|
||||||
|
shiftBy %= uint64(len(aRecords))
|
||||||
|
shiftBy++
|
||||||
|
|
||||||
|
for i := uint64(0); i < shiftBy; i++ {
|
||||||
|
last := *aRecords[len(aRecords)-1]
|
||||||
|
for j := len(aRecords) - 1; j > 0; j-- {
|
||||||
|
*aRecords[j] = *aRecords[j-1]
|
||||||
|
}
|
||||||
|
*aRecords[0] = last
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -134,6 +134,9 @@ func TestCacheHardenBelowNXDOMAIN(t *testing.T) {
|
|||||||
|
|
||||||
func TestRoundRobinShuffle(t *testing.T) {
|
func TestRoundRobinShuffle(t *testing.T) {
|
||||||
msg := &dns.Msg{
|
msg := &dns.Msg{
|
||||||
|
Question: []dns.Question{
|
||||||
|
{Name: "example.com."},
|
||||||
|
},
|
||||||
Answer: []dns.RR{
|
Answer: []dns.RR{
|
||||||
&dns.CNAME{
|
&dns.CNAME{
|
||||||
Hdr: dns.RR_Header{
|
Hdr: dns.RR_Header{
|
||||||
@@ -161,17 +164,28 @@ func TestRoundRobinShuffle(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
// Shift the A records
|
|
||||||
AnswerShuffleRoundRobin(msg)
|
// Shift the A records once
|
||||||
|
msg1 := msg.Copy()
|
||||||
|
AnswerShuffleRoundRobin(msg1)
|
||||||
|
|
||||||
require.Equal(t, dns.TypeCNAME, msg.Answer[0].Header().Rrtype)
|
require.Equal(t, dns.TypeCNAME, msg.Answer[0].Header().Rrtype)
|
||||||
require.Equal(t, dns.TypeA, msg.Answer[1].Header().Rrtype)
|
require.Equal(t, dns.TypeA, msg.Answer[1].Header().Rrtype)
|
||||||
require.Equal(t, dns.TypeA, msg.Answer[2].Header().Rrtype)
|
require.Equal(t, dns.TypeA, msg.Answer[2].Header().Rrtype)
|
||||||
|
|
||||||
a1 := msg.Answer[1].(*dns.A)
|
a1 := msg1.Answer[1].(*dns.A)
|
||||||
a2 := msg.Answer[2].(*dns.A)
|
a2 := msg1.Answer[2].(*dns.A)
|
||||||
require.Equal(t, net.IP{0, 0, 0, 2}, a1.A)
|
require.Equal(t, net.IP{0, 0, 0, 2}, a1.A)
|
||||||
require.Equal(t, net.IP{0, 0, 0, 1}, a2.A)
|
require.Equal(t, net.IP{0, 0, 0, 1}, a2.A)
|
||||||
|
|
||||||
|
// Shift the A records again
|
||||||
|
msg2 := msg.Copy()
|
||||||
|
AnswerShuffleRoundRobin(msg2)
|
||||||
|
|
||||||
|
a1 = msg2.Answer[1].(*dns.A)
|
||||||
|
a2 = msg2.Answer[2].(*dns.A)
|
||||||
|
require.Equal(t, net.IP{0, 0, 0, 1}, a1.A)
|
||||||
|
require.Equal(t, net.IP{0, 0, 0, 2}, a2.A)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Truncated responses should not be cached
|
// Truncated responses should not be cached
|
||||||
|
|||||||
Reference in New Issue
Block a user