Redis cache: use binary wire format with pooled buffers (#473)

* Redis cache: use binary wire format with pooled buffers

* Changes based on PR comments

* Fix binary wire format message being corrupted
This commit is contained in:
emlimap
2025-11-28 16:22:19 +00:00
committed by GitHub
parent 0f3005ce11
commit ff279d5024
2 changed files with 411 additions and 8 deletions

View File

@@ -2,11 +2,13 @@ package rdns
import (
"context"
"encoding/binary"
"encoding/json"
"errors"
"expvar"
"fmt"
"strings"
"sync"
"time"
"github.com/miekg/dns"
@@ -33,6 +35,102 @@ type RedisBackendOptions struct {
var _ CacheBackend = (*redisBackend)(nil)
// Buffer pool for dns.Msg.PackBuffer to minimize allocations.
var packBufPool = sync.Pool{
New: func() any {
b := make([]byte, 0, 2048)
return &b
},
}
const (
binaryFormatVersion = 1
headerSize = 10
flagPrefetchBit = 1 << 0
)
// encodeCacheAnswer encodes a cacheAnswer into a compact binary format:
// - byte 0: version (1)
// - byte 1: flags (bit0: prefetchEligible)
// - bytes 2..9: timestamp (uint64 seconds from Unix epoch, big endian)
// - bytes 10..N: dns.Msg wire bytes
func encodeCacheAnswer(item *cacheAnswer) ([]byte, error) {
bufPtr := packBufPool.Get().(*[]byte)
buf := *bufPtr
defer func() {
*bufPtr = buf[:0]
packBufPool.Put(bufPtr)
}()
if cap(buf) == 0 {
buf = make([]byte, 0, 2048)
}
// Pack DNS message first into the scratch buffer
buf = buf[:cap(buf)]
dnsWire, err := item.Msg.PackBuffer(buf)
if err != nil {
return nil, fmt.Errorf("failed to pack DNS message: %w", err)
}
// Keep the (potentially grown) buffer for cleanup
buf = dnsWire
// Allocate result with header + DNS wire bytes
result := make([]byte, headerSize+len(dnsWire))
// Write header
result[0] = binaryFormatVersion
var flags byte
if item.PrefetchEligible {
flags |= flagPrefetchBit
}
result[1] = flags
timestamp := uint64(item.Timestamp.Unix())
binary.BigEndian.PutUint64(result[2:10], timestamp)
// Copy DNS wire bytes after header
copy(result[headerSize:], dnsWire)
return result, nil
}
// decodeCacheAnswer decodes a binary-encoded cacheAnswer.
// Returns an error if the format is invalid or unsupported.
func decodeCacheAnswer(b []byte) (*cacheAnswer, error) {
if len(b) < headerSize {
return nil, fmt.Errorf("binary data too short: %d bytes", len(b))
}
// Check version
version := b[0]
if version != binaryFormatVersion {
return nil, fmt.Errorf("unsupported binary format version: %d", version)
}
// Parse flags
flags := b[1]
prefetchEligible := (flags & flagPrefetchBit) != 0
// Parse timestamp
timestamp := int64(binary.BigEndian.Uint64(b[2:10]))
// Unpack DNS message
msg := new(dns.Msg)
if err := msg.Unpack(b[headerSize:]); err != nil {
return nil, fmt.Errorf("failed to unpack DNS message: %w", err)
}
return &cacheAnswer{
Timestamp: time.Unix(timestamp, 0),
PrefetchEligible: prefetchEligible,
Msg: msg,
}, nil
}
func NewRedisBackend(opt RedisBackendOptions) *redisBackend {
b := &redisBackend{
client: redis.NewClient(&opt.RedisOptions),
@@ -58,14 +156,15 @@ func (b *redisBackend) Store(query *dns.Msg, item *cacheAnswer) {
}
func (b *redisBackend) storeSync(query *dns.Msg, item *cacheAnswer, ttl time.Duration) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
key := b.keyFromQuery(query)
value, err := json.Marshal(item)
value, err := encodeCacheAnswer(item)
if err != nil {
Log.Error("failed to marshal cache record", "error", err)
Log.Error("failed to encode cache record", "error", err)
return
}
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
if err := b.client.Set(ctx, key, value, ttl).Err(); err != nil {
Log.Error("failed to write to redis", "error", err)
}
@@ -89,7 +188,9 @@ func (b *redisBackend) Lookup(q *dns.Msg) (*dns.Msg, bool, bool) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
key := b.keyFromQuery(q)
value, err := b.client.Get(ctx, key).Result()
// Fetch raw bytes to avoid string conversion overhead
valueBytes, err := b.client.Get(ctx, key).Bytes()
if err != nil {
if errors.Is(err, redis.Nil) { // Return a cache-miss if there's no such key
return nil, false, false
@@ -97,10 +198,16 @@ func (b *redisBackend) Lookup(q *dns.Msg) (*dns.Msg, bool, bool) {
Log.Error("failed to read from redis", "error", err)
return nil, false, false
}
// Try binary decode first, with JSON fallback for backward compatibility
var a *cacheAnswer
if err := json.Unmarshal([]byte(value), &a); err != nil {
Log.Error("failed to unmarshal cache record from redis", "error", err)
return nil, false, false
a, err = decodeCacheAnswer(valueBytes)
if err != nil {
// Fallback to JSON for backward compatibility with existing cached entries
if jsonErr := json.Unmarshal(valueBytes, &a); jsonErr != nil {
Log.Error("failed to decode cache record from redis", "binary_error", err, "json_error", jsonErr)
return nil, false, false
}
}
answer := a.Msg

296
cache-redis_test.go Normal file
View File

@@ -0,0 +1,296 @@
package rdns
import (
"fmt"
"sync"
"testing"
"time"
"github.com/miekg/dns"
)
func TestEncodeDecode(t *testing.T) {
// Create a test DNS message
msg := new(dns.Msg)
msg.SetQuestion("example.com.", dns.TypeA)
msg.Response = true
msg.Rcode = dns.RcodeSuccess
// Add an answer
rr, err := dns.NewRR("example.com. 300 IN A 192.0.2.1")
if err != nil {
t.Fatalf("failed to create RR: %v", err)
}
msg.Answer = append(msg.Answer, rr)
// Create a cacheAnswer
now := time.Now()
original := &cacheAnswer{
Timestamp: now,
PrefetchEligible: true,
Msg: msg,
}
// Encode
encoded, err := encodeCacheAnswer(original)
if err != nil {
t.Fatalf("encodeCacheAnswer failed: %v", err)
}
// Verify format
if len(encoded) < headerSize {
t.Fatalf("encoded data too short: %d bytes", len(encoded))
}
// Check version byte
if encoded[0] != binaryFormatVersion {
t.Errorf("version byte = %d, want %d", encoded[0], binaryFormatVersion)
}
// Check flags byte
expectedFlags := byte(flagPrefetchBit)
if encoded[1] != expectedFlags {
t.Errorf("flags byte = %d, want %d", encoded[1], expectedFlags)
}
// Decode
decoded, err := decodeCacheAnswer(encoded)
if err != nil {
t.Fatalf("decodeCacheAnswer failed: %v", err)
}
// Verify fields
if decoded.Timestamp.Unix() != original.Timestamp.Unix() {
t.Errorf("timestamp = %v, want %v", decoded.Timestamp, original.Timestamp)
}
if decoded.PrefetchEligible != original.PrefetchEligible {
t.Errorf("prefetchEligible = %v, want %v", decoded.PrefetchEligible, original.PrefetchEligible)
}
// Verify DNS message
if len(decoded.Msg.Answer) != len(original.Msg.Answer) {
t.Errorf("answer count = %d, want %d", len(decoded.Msg.Answer), len(original.Msg.Answer))
}
if decoded.Msg.Question[0].Name != original.Msg.Question[0].Name {
t.Errorf("question name = %s, want %s", decoded.Msg.Question[0].Name, original.Msg.Question[0].Name)
}
if decoded.Msg.Question[0].Qtype != original.Msg.Question[0].Qtype {
t.Errorf("question type = %d, want %d", decoded.Msg.Question[0].Qtype, original.Msg.Question[0].Qtype)
}
}
func TestEncodeDecodeNoPrefetch(t *testing.T) {
// Create a test DNS message
msg := new(dns.Msg)
msg.SetQuestion("test.example.", dns.TypeAAAA)
msg.Response = true
// Create a cacheAnswer with prefetch disabled
original := &cacheAnswer{
Timestamp: time.Unix(1234567890, 0),
PrefetchEligible: false,
Msg: msg,
}
// Encode
encoded, err := encodeCacheAnswer(original)
if err != nil {
t.Fatalf("encodeCacheAnswer failed: %v", err)
}
// Check flags byte (should be 0)
if encoded[1] != 0 {
t.Errorf("flags byte = %d, want 0", encoded[1])
}
// Decode
decoded, err := decodeCacheAnswer(encoded)
if err != nil {
t.Fatalf("decodeCacheAnswer failed: %v", err)
}
if decoded.PrefetchEligible != false {
t.Errorf("prefetchEligible = %v, want false", decoded.PrefetchEligible)
}
}
func TestDecodeInvalidData(t *testing.T) {
tests := []struct {
name string
data []byte
}{
{"too short", []byte{0x01, 0x00}},
{"wrong version", []byte{0x99, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
{"invalid DNS", []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := decodeCacheAnswer(tt.data)
if err == nil {
t.Error("expected error, got nil")
}
})
}
}
func TestEncodeDecodePooling(t *testing.T) {
// Test that pooling works correctly across multiple encode operations
msg := new(dns.Msg)
msg.SetQuestion("pool.test.", dns.TypeA)
msg.Response = true
item := &cacheAnswer{
Timestamp: time.Now(),
PrefetchEligible: true,
Msg: msg,
}
// Encode multiple times to test pool reuse
for i := 0; i < 100; i++ {
encoded, err := encodeCacheAnswer(item)
if err != nil {
t.Fatalf("iteration %d: encodeCacheAnswer failed: %v", i, err)
}
// Verify first byte is always version
if encoded[0] != binaryFormatVersion {
t.Errorf("iteration %d: version byte = %d, want %d", i, encoded[0], binaryFormatVersion)
}
// Decode to verify correctness
decoded, err := decodeCacheAnswer(encoded)
if err != nil {
t.Fatalf("iteration %d: decodeCacheAnswer failed: %v", i, err)
}
if decoded.Msg.Question[0].Name != "pool.test." {
t.Errorf("iteration %d: corrupted data", i)
}
}
}
func TestEncodeReturnsIndependentSlice(t *testing.T) {
// Verify that encoded bytes are independent of the pool and mutations don't affect subsequent encodes
msg := new(dns.Msg)
msg.SetQuestion("independent.test.", dns.TypeA)
msg.Response = true
item := &cacheAnswer{
Timestamp: time.Unix(1234567890, 0),
PrefetchEligible: true,
Msg: msg,
}
// First encode
encoded1, err := encodeCacheAnswer(item)
if err != nil {
t.Fatalf("first encode failed: %v", err)
}
// Save a copy of the original encoded data
original := make([]byte, len(encoded1))
copy(original, encoded1)
// Mutate the returned slice to verify it's independent of the pool
for i := range encoded1 {
encoded1[i] = 0xFF
}
// Verify the mutated buffer is now garbage and fails to decode
_, err = decodeCacheAnswer(encoded1)
if err == nil {
t.Error("expected decode of mutated buffer to fail, but it succeeded")
}
// Second encode - should succeed and produce the same result as the first
encoded2, err := encodeCacheAnswer(item)
if err != nil {
t.Fatalf("second encode failed: %v", err)
}
// Verify second encode matches the original (not corrupted by mutation)
if len(encoded2) != len(original) {
t.Fatalf("length mismatch: got %d, want %d", len(encoded2), len(original))
}
for i := range original {
if encoded2[i] != original[i] {
t.Errorf("byte %d: got %02x, want %02x (mutation leaked into pool)", i, encoded2[i], original[i])
}
}
// Verify we can still decode the second result
decoded, err := decodeCacheAnswer(encoded2)
if err != nil {
t.Fatalf("decode after mutation failed: %v", err)
}
if decoded.Msg.Question[0].Name != "independent.test." {
t.Errorf("decoded name = %s, want independent.test.", decoded.Msg.Question[0].Name)
}
}
func TestEncodeConcurrent(t *testing.T) {
// Test concurrent encoding to catch pool-related race conditions
// Each goroutine gets its own dns.Msg to avoid racing on shared message internals
const numGoroutines = 50
const numIterations = 100
errs := make(chan error, numGoroutines)
var wg sync.WaitGroup
wg.Add(numGoroutines)
for g := 0; g < numGoroutines; g++ {
go func(gid int) {
defer wg.Done()
msg := new(dns.Msg)
msg.SetQuestion("concurrent.test.", dns.TypeA)
msg.Response = true
rr, err := dns.NewRR("concurrent.test. 300 IN A 192.0.2.1")
if err != nil {
errs <- err
return
}
msg.Answer = append(msg.Answer, rr)
item := &cacheAnswer{
Timestamp: time.Now(),
PrefetchEligible: true,
Msg: msg,
}
for i := 0; i < numIterations; i++ {
encoded, err := encodeCacheAnswer(item)
if err != nil {
errs <- err
return
}
if encoded[0] != binaryFormatVersion {
errs <- fmt.Errorf("goroutine %d iteration %d: invalid version byte %d", gid, i, encoded[0])
return
}
decoded, err := decodeCacheAnswer(encoded)
if err != nil {
errs <- err
return
}
if decoded.Msg.Question[0].Name != "concurrent.test." {
errs <- fmt.Errorf("goroutine %d iteration %d: corrupted data, got name %s", gid, i, decoded.Msg.Question[0].Name)
return
}
}
}(g)
}
wg.Wait()
close(errs)
for err := range errs {
t.Fatalf("concurrent encode/decode failed: %v", err)
}
}