mirror of
https://github.com/folbricht/routedns.git
synced 2026-02-09 10:28:28 -06:00
Redis cache: use binary wire format with pooled buffers (#473)
* Redis cache: use binary wire format with pooled buffers * Changes based on PR comments * Fix binary wire format message being corrupted
This commit is contained in:
123
cache-redis.go
123
cache-redis.go
@@ -2,11 +2,13 @@ package rdns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"expvar"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -33,6 +35,102 @@ type RedisBackendOptions struct {
|
||||
|
||||
var _ CacheBackend = (*redisBackend)(nil)
|
||||
|
||||
// Buffer pool for dns.Msg.PackBuffer to minimize allocations.
|
||||
var packBufPool = sync.Pool{
|
||||
New: func() any {
|
||||
b := make([]byte, 0, 2048)
|
||||
return &b
|
||||
},
|
||||
}
|
||||
|
||||
const (
|
||||
binaryFormatVersion = 1
|
||||
headerSize = 10
|
||||
flagPrefetchBit = 1 << 0
|
||||
)
|
||||
|
||||
// encodeCacheAnswer encodes a cacheAnswer into a compact binary format:
|
||||
// - byte 0: version (1)
|
||||
// - byte 1: flags (bit0: prefetchEligible)
|
||||
// - bytes 2..9: timestamp (uint64 seconds from Unix epoch, big endian)
|
||||
// - bytes 10..N: dns.Msg wire bytes
|
||||
func encodeCacheAnswer(item *cacheAnswer) ([]byte, error) {
|
||||
bufPtr := packBufPool.Get().(*[]byte)
|
||||
buf := *bufPtr
|
||||
|
||||
defer func() {
|
||||
*bufPtr = buf[:0]
|
||||
packBufPool.Put(bufPtr)
|
||||
}()
|
||||
|
||||
if cap(buf) == 0 {
|
||||
buf = make([]byte, 0, 2048)
|
||||
}
|
||||
|
||||
// Pack DNS message first into the scratch buffer
|
||||
buf = buf[:cap(buf)]
|
||||
dnsWire, err := item.Msg.PackBuffer(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pack DNS message: %w", err)
|
||||
}
|
||||
|
||||
// Keep the (potentially grown) buffer for cleanup
|
||||
buf = dnsWire
|
||||
|
||||
// Allocate result with header + DNS wire bytes
|
||||
result := make([]byte, headerSize+len(dnsWire))
|
||||
|
||||
// Write header
|
||||
result[0] = binaryFormatVersion
|
||||
|
||||
var flags byte
|
||||
if item.PrefetchEligible {
|
||||
flags |= flagPrefetchBit
|
||||
}
|
||||
result[1] = flags
|
||||
|
||||
timestamp := uint64(item.Timestamp.Unix())
|
||||
binary.BigEndian.PutUint64(result[2:10], timestamp)
|
||||
|
||||
// Copy DNS wire bytes after header
|
||||
copy(result[headerSize:], dnsWire)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// decodeCacheAnswer decodes a binary-encoded cacheAnswer.
|
||||
// Returns an error if the format is invalid or unsupported.
|
||||
func decodeCacheAnswer(b []byte) (*cacheAnswer, error) {
|
||||
if len(b) < headerSize {
|
||||
return nil, fmt.Errorf("binary data too short: %d bytes", len(b))
|
||||
}
|
||||
|
||||
// Check version
|
||||
version := b[0]
|
||||
if version != binaryFormatVersion {
|
||||
return nil, fmt.Errorf("unsupported binary format version: %d", version)
|
||||
}
|
||||
|
||||
// Parse flags
|
||||
flags := b[1]
|
||||
prefetchEligible := (flags & flagPrefetchBit) != 0
|
||||
|
||||
// Parse timestamp
|
||||
timestamp := int64(binary.BigEndian.Uint64(b[2:10]))
|
||||
|
||||
// Unpack DNS message
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(b[headerSize:]); err != nil {
|
||||
return nil, fmt.Errorf("failed to unpack DNS message: %w", err)
|
||||
}
|
||||
|
||||
return &cacheAnswer{
|
||||
Timestamp: time.Unix(timestamp, 0),
|
||||
PrefetchEligible: prefetchEligible,
|
||||
Msg: msg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewRedisBackend(opt RedisBackendOptions) *redisBackend {
|
||||
b := &redisBackend{
|
||||
client: redis.NewClient(&opt.RedisOptions),
|
||||
@@ -58,14 +156,15 @@ func (b *redisBackend) Store(query *dns.Msg, item *cacheAnswer) {
|
||||
}
|
||||
|
||||
func (b *redisBackend) storeSync(query *dns.Msg, item *cacheAnswer, ttl time.Duration) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
key := b.keyFromQuery(query)
|
||||
value, err := json.Marshal(item)
|
||||
value, err := encodeCacheAnswer(item)
|
||||
if err != nil {
|
||||
Log.Error("failed to marshal cache record", "error", err)
|
||||
Log.Error("failed to encode cache record", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
if err := b.client.Set(ctx, key, value, ttl).Err(); err != nil {
|
||||
Log.Error("failed to write to redis", "error", err)
|
||||
}
|
||||
@@ -89,7 +188,9 @@ func (b *redisBackend) Lookup(q *dns.Msg) (*dns.Msg, bool, bool) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
key := b.keyFromQuery(q)
|
||||
value, err := b.client.Get(ctx, key).Result()
|
||||
|
||||
// Fetch raw bytes to avoid string conversion overhead
|
||||
valueBytes, err := b.client.Get(ctx, key).Bytes()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) { // Return a cache-miss if there's no such key
|
||||
return nil, false, false
|
||||
@@ -97,10 +198,16 @@ func (b *redisBackend) Lookup(q *dns.Msg) (*dns.Msg, bool, bool) {
|
||||
Log.Error("failed to read from redis", "error", err)
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
// Try binary decode first, with JSON fallback for backward compatibility
|
||||
var a *cacheAnswer
|
||||
if err := json.Unmarshal([]byte(value), &a); err != nil {
|
||||
Log.Error("failed to unmarshal cache record from redis", "error", err)
|
||||
return nil, false, false
|
||||
a, err = decodeCacheAnswer(valueBytes)
|
||||
if err != nil {
|
||||
// Fallback to JSON for backward compatibility with existing cached entries
|
||||
if jsonErr := json.Unmarshal(valueBytes, &a); jsonErr != nil {
|
||||
Log.Error("failed to decode cache record from redis", "binary_error", err, "json_error", jsonErr)
|
||||
return nil, false, false
|
||||
}
|
||||
}
|
||||
|
||||
answer := a.Msg
|
||||
|
||||
296
cache-redis_test.go
Normal file
296
cache-redis_test.go
Normal file
@@ -0,0 +1,296 @@
|
||||
package rdns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestEncodeDecode(t *testing.T) {
|
||||
// Create a test DNS message
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("example.com.", dns.TypeA)
|
||||
msg.Response = true
|
||||
msg.Rcode = dns.RcodeSuccess
|
||||
|
||||
// Add an answer
|
||||
rr, err := dns.NewRR("example.com. 300 IN A 192.0.2.1")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create RR: %v", err)
|
||||
}
|
||||
msg.Answer = append(msg.Answer, rr)
|
||||
|
||||
// Create a cacheAnswer
|
||||
now := time.Now()
|
||||
original := &cacheAnswer{
|
||||
Timestamp: now,
|
||||
PrefetchEligible: true,
|
||||
Msg: msg,
|
||||
}
|
||||
|
||||
// Encode
|
||||
encoded, err := encodeCacheAnswer(original)
|
||||
if err != nil {
|
||||
t.Fatalf("encodeCacheAnswer failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify format
|
||||
if len(encoded) < headerSize {
|
||||
t.Fatalf("encoded data too short: %d bytes", len(encoded))
|
||||
}
|
||||
|
||||
// Check version byte
|
||||
if encoded[0] != binaryFormatVersion {
|
||||
t.Errorf("version byte = %d, want %d", encoded[0], binaryFormatVersion)
|
||||
}
|
||||
|
||||
// Check flags byte
|
||||
expectedFlags := byte(flagPrefetchBit)
|
||||
if encoded[1] != expectedFlags {
|
||||
t.Errorf("flags byte = %d, want %d", encoded[1], expectedFlags)
|
||||
}
|
||||
|
||||
// Decode
|
||||
decoded, err := decodeCacheAnswer(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("decodeCacheAnswer failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify fields
|
||||
if decoded.Timestamp.Unix() != original.Timestamp.Unix() {
|
||||
t.Errorf("timestamp = %v, want %v", decoded.Timestamp, original.Timestamp)
|
||||
}
|
||||
|
||||
if decoded.PrefetchEligible != original.PrefetchEligible {
|
||||
t.Errorf("prefetchEligible = %v, want %v", decoded.PrefetchEligible, original.PrefetchEligible)
|
||||
}
|
||||
|
||||
// Verify DNS message
|
||||
if len(decoded.Msg.Answer) != len(original.Msg.Answer) {
|
||||
t.Errorf("answer count = %d, want %d", len(decoded.Msg.Answer), len(original.Msg.Answer))
|
||||
}
|
||||
|
||||
if decoded.Msg.Question[0].Name != original.Msg.Question[0].Name {
|
||||
t.Errorf("question name = %s, want %s", decoded.Msg.Question[0].Name, original.Msg.Question[0].Name)
|
||||
}
|
||||
|
||||
if decoded.Msg.Question[0].Qtype != original.Msg.Question[0].Qtype {
|
||||
t.Errorf("question type = %d, want %d", decoded.Msg.Question[0].Qtype, original.Msg.Question[0].Qtype)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeDecodeNoPrefetch(t *testing.T) {
|
||||
// Create a test DNS message
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("test.example.", dns.TypeAAAA)
|
||||
msg.Response = true
|
||||
|
||||
// Create a cacheAnswer with prefetch disabled
|
||||
original := &cacheAnswer{
|
||||
Timestamp: time.Unix(1234567890, 0),
|
||||
PrefetchEligible: false,
|
||||
Msg: msg,
|
||||
}
|
||||
|
||||
// Encode
|
||||
encoded, err := encodeCacheAnswer(original)
|
||||
if err != nil {
|
||||
t.Fatalf("encodeCacheAnswer failed: %v", err)
|
||||
}
|
||||
|
||||
// Check flags byte (should be 0)
|
||||
if encoded[1] != 0 {
|
||||
t.Errorf("flags byte = %d, want 0", encoded[1])
|
||||
}
|
||||
|
||||
// Decode
|
||||
decoded, err := decodeCacheAnswer(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("decodeCacheAnswer failed: %v", err)
|
||||
}
|
||||
|
||||
if decoded.PrefetchEligible != false {
|
||||
t.Errorf("prefetchEligible = %v, want false", decoded.PrefetchEligible)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeInvalidData(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
}{
|
||||
{"too short", []byte{0x01, 0x00}},
|
||||
{"wrong version", []byte{0x99, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
|
||||
{"invalid DNS", []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := decodeCacheAnswer(tt.data)
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeDecodePooling(t *testing.T) {
|
||||
// Test that pooling works correctly across multiple encode operations
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("pool.test.", dns.TypeA)
|
||||
msg.Response = true
|
||||
|
||||
item := &cacheAnswer{
|
||||
Timestamp: time.Now(),
|
||||
PrefetchEligible: true,
|
||||
Msg: msg,
|
||||
}
|
||||
|
||||
// Encode multiple times to test pool reuse
|
||||
for i := 0; i < 100; i++ {
|
||||
encoded, err := encodeCacheAnswer(item)
|
||||
if err != nil {
|
||||
t.Fatalf("iteration %d: encodeCacheAnswer failed: %v", i, err)
|
||||
}
|
||||
|
||||
// Verify first byte is always version
|
||||
if encoded[0] != binaryFormatVersion {
|
||||
t.Errorf("iteration %d: version byte = %d, want %d", i, encoded[0], binaryFormatVersion)
|
||||
}
|
||||
|
||||
// Decode to verify correctness
|
||||
decoded, err := decodeCacheAnswer(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("iteration %d: decodeCacheAnswer failed: %v", i, err)
|
||||
}
|
||||
|
||||
if decoded.Msg.Question[0].Name != "pool.test." {
|
||||
t.Errorf("iteration %d: corrupted data", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeReturnsIndependentSlice(t *testing.T) {
|
||||
// Verify that encoded bytes are independent of the pool and mutations don't affect subsequent encodes
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("independent.test.", dns.TypeA)
|
||||
msg.Response = true
|
||||
|
||||
item := &cacheAnswer{
|
||||
Timestamp: time.Unix(1234567890, 0),
|
||||
PrefetchEligible: true,
|
||||
Msg: msg,
|
||||
}
|
||||
|
||||
// First encode
|
||||
encoded1, err := encodeCacheAnswer(item)
|
||||
if err != nil {
|
||||
t.Fatalf("first encode failed: %v", err)
|
||||
}
|
||||
|
||||
// Save a copy of the original encoded data
|
||||
original := make([]byte, len(encoded1))
|
||||
copy(original, encoded1)
|
||||
|
||||
// Mutate the returned slice to verify it's independent of the pool
|
||||
for i := range encoded1 {
|
||||
encoded1[i] = 0xFF
|
||||
}
|
||||
|
||||
// Verify the mutated buffer is now garbage and fails to decode
|
||||
_, err = decodeCacheAnswer(encoded1)
|
||||
if err == nil {
|
||||
t.Error("expected decode of mutated buffer to fail, but it succeeded")
|
||||
}
|
||||
|
||||
// Second encode - should succeed and produce the same result as the first
|
||||
encoded2, err := encodeCacheAnswer(item)
|
||||
if err != nil {
|
||||
t.Fatalf("second encode failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify second encode matches the original (not corrupted by mutation)
|
||||
if len(encoded2) != len(original) {
|
||||
t.Fatalf("length mismatch: got %d, want %d", len(encoded2), len(original))
|
||||
}
|
||||
|
||||
for i := range original {
|
||||
if encoded2[i] != original[i] {
|
||||
t.Errorf("byte %d: got %02x, want %02x (mutation leaked into pool)", i, encoded2[i], original[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Verify we can still decode the second result
|
||||
decoded, err := decodeCacheAnswer(encoded2)
|
||||
if err != nil {
|
||||
t.Fatalf("decode after mutation failed: %v", err)
|
||||
}
|
||||
|
||||
if decoded.Msg.Question[0].Name != "independent.test." {
|
||||
t.Errorf("decoded name = %s, want independent.test.", decoded.Msg.Question[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeConcurrent(t *testing.T) {
|
||||
// Test concurrent encoding to catch pool-related race conditions
|
||||
// Each goroutine gets its own dns.Msg to avoid racing on shared message internals
|
||||
const numGoroutines = 50
|
||||
const numIterations = 100
|
||||
|
||||
errs := make(chan error, numGoroutines)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for g := 0; g < numGoroutines; g++ {
|
||||
go func(gid int) {
|
||||
defer wg.Done()
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("concurrent.test.", dns.TypeA)
|
||||
msg.Response = true
|
||||
rr, err := dns.NewRR("concurrent.test. 300 IN A 192.0.2.1")
|
||||
if err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
msg.Answer = append(msg.Answer, rr)
|
||||
|
||||
item := &cacheAnswer{
|
||||
Timestamp: time.Now(),
|
||||
PrefetchEligible: true,
|
||||
Msg: msg,
|
||||
}
|
||||
|
||||
for i := 0; i < numIterations; i++ {
|
||||
encoded, err := encodeCacheAnswer(item)
|
||||
if err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
if encoded[0] != binaryFormatVersion {
|
||||
errs <- fmt.Errorf("goroutine %d iteration %d: invalid version byte %d", gid, i, encoded[0])
|
||||
return
|
||||
}
|
||||
decoded, err := decodeCacheAnswer(encoded)
|
||||
if err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
if decoded.Msg.Question[0].Name != "concurrent.test." {
|
||||
errs <- fmt.Errorf("goroutine %d iteration %d: corrupted data, got name %s", gid, i, decoded.Msg.Question[0].Name)
|
||||
return
|
||||
}
|
||||
}
|
||||
}(g)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
for err := range errs {
|
||||
t.Fatalf("concurrent encode/decode failed: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user