diff --git a/cache.go b/cache.go index 197c56a..c387750 100644 --- a/cache.go +++ b/cache.go @@ -6,6 +6,8 @@ import ( "math" "math/rand" "strings" + "sync" + "sync/atomic" "time" "github.com/miekg/dns" @@ -327,24 +329,91 @@ func AnswerShuffleRandom(msg *dns.Msg) { }) } +// Round Robin shuffling requires keeping state as it's operating on copies +// of DNS messages so the number of shift operations needs to be remembered. +type rrShuffleRecord struct { + reads uint64 + expiry time.Time +} + +var ( + rrShuffleState map[lruKey]*rrShuffleRecord + rrShuffleOnce sync.Once + rrShuffleMu sync.RWMutex +) + // Shift the answer A/AAAA record order in an answer by one. func AnswerShuffleRoundRobin(msg *dns.Msg) { if len(msg.Answer) < 2 { return } - var last dns.RR - var dst int + rrShuffleOnce.Do(func() { + rrShuffleState = make(map[lruKey]*rrShuffleRecord) + + // Start a cleanup job + go func() { + for { + time.Sleep(30 * time.Second) + rrShuffleMu.RLock() + + // Build a list of expired items + var toRemove []lruKey + for k, v := range rrShuffleState { + now := time.Now() + if now.After(v.expiry) { + toRemove = append(toRemove, k) + } + } + rrShuffleMu.RUnlock() + + // Remove the expired items + rrShuffleMu.Lock() + for _, k := range toRemove { + delete(rrShuffleState, k) + } + rrShuffleMu.Unlock() + } + }() + }) + + // Lookup how often the results were shifted previously + key := lruKeyFromQuery(msg) + rrShuffleMu.RLock() + rec, ok := rrShuffleState[key] + rrShuffleMu.RUnlock() + var shiftBy uint64 + if ok { + shiftBy = atomic.AddUint64(&rec.reads, 1) + } else { + ttl, ok := minTTL(msg) + if !ok { + return + } + rec = &rrShuffleRecord{ + expiry: time.Now().Add(time.Duration(ttl) * time.Second), + } + rrShuffleMu.Lock() + rrShuffleState[key] = rec + rrShuffleMu.Unlock() + } + + // Build a list of pointers to A/AAAA records in the message + var aRecords []*dns.RR for i, rr := range msg.Answer { if rr.Header().Rrtype == dns.TypeA || rr.Header().Rrtype == dns.TypeAAAA { - if last == nil { - last = rr - } else { - msg.Answer[dst] = rr - } - dst = i + aRecords = append(aRecords, &msg.Answer[i]) } } - if last != nil { - msg.Answer[dst] = last + + // Rotate the A/AAAA record pointers + shiftBy %= uint64(len(aRecords)) + shiftBy++ + + for i := uint64(0); i < shiftBy; i++ { + last := *aRecords[len(aRecords)-1] + for j := len(aRecords) - 1; j > 0; j-- { + *aRecords[j] = *aRecords[j-1] + } + *aRecords[0] = last } } diff --git a/cache_test.go b/cache_test.go index 11ea167..8a9d5fb 100644 --- a/cache_test.go +++ b/cache_test.go @@ -134,6 +134,9 @@ func TestCacheHardenBelowNXDOMAIN(t *testing.T) { func TestRoundRobinShuffle(t *testing.T) { msg := &dns.Msg{ + Question: []dns.Question{ + {Name: "example.com."}, + }, Answer: []dns.RR{ &dns.CNAME{ Hdr: dns.RR_Header{ @@ -161,17 +164,28 @@ func TestRoundRobinShuffle(t *testing.T) { }, }, } - // Shift the A records - AnswerShuffleRoundRobin(msg) + + // Shift the A records once + msg1 := msg.Copy() + AnswerShuffleRoundRobin(msg1) require.Equal(t, dns.TypeCNAME, msg.Answer[0].Header().Rrtype) require.Equal(t, dns.TypeA, msg.Answer[1].Header().Rrtype) require.Equal(t, dns.TypeA, msg.Answer[2].Header().Rrtype) - a1 := msg.Answer[1].(*dns.A) - a2 := msg.Answer[2].(*dns.A) + a1 := msg1.Answer[1].(*dns.A) + a2 := msg1.Answer[2].(*dns.A) require.Equal(t, net.IP{0, 0, 0, 2}, a1.A) require.Equal(t, net.IP{0, 0, 0, 1}, a2.A) + + // Shift the A records again + msg2 := msg.Copy() + AnswerShuffleRoundRobin(msg2) + + a1 = msg2.Answer[1].(*dns.A) + a2 = msg2.Answer[2].(*dns.A) + require.Equal(t, net.IP{0, 0, 0, 1}, a1.A) + require.Equal(t, net.IP{0, 0, 0, 2}, a2.A) } // Truncated responses should not be cached