diff --git a/cache-redis.go b/cache-redis.go index c86a72d..f2caad9 100644 --- a/cache-redis.go +++ b/cache-redis.go @@ -2,11 +2,13 @@ package rdns import ( "context" + "encoding/binary" "encoding/json" "errors" "expvar" "fmt" "strings" + "sync" "time" "github.com/miekg/dns" @@ -33,6 +35,102 @@ type RedisBackendOptions struct { var _ CacheBackend = (*redisBackend)(nil) +// Buffer pool for dns.Msg.PackBuffer to minimize allocations. +var packBufPool = sync.Pool{ + New: func() any { + b := make([]byte, 0, 2048) + return &b + }, +} + +const ( + binaryFormatVersion = 1 + headerSize = 10 + flagPrefetchBit = 1 << 0 +) + +// encodeCacheAnswer encodes a cacheAnswer into a compact binary format: +// - byte 0: version (1) +// - byte 1: flags (bit0: prefetchEligible) +// - bytes 2..9: timestamp (uint64 seconds from Unix epoch, big endian) +// - bytes 10..N: dns.Msg wire bytes +func encodeCacheAnswer(item *cacheAnswer) ([]byte, error) { + bufPtr := packBufPool.Get().(*[]byte) + buf := *bufPtr + + defer func() { + *bufPtr = buf[:0] + packBufPool.Put(bufPtr) + }() + + if cap(buf) == 0 { + buf = make([]byte, 0, 2048) + } + + // Pack DNS message first into the scratch buffer + buf = buf[:cap(buf)] + dnsWire, err := item.Msg.PackBuffer(buf) + if err != nil { + return nil, fmt.Errorf("failed to pack DNS message: %w", err) + } + + // Keep the (potentially grown) buffer for cleanup + buf = dnsWire + + // Allocate result with header + DNS wire bytes + result := make([]byte, headerSize+len(dnsWire)) + + // Write header + result[0] = binaryFormatVersion + + var flags byte + if item.PrefetchEligible { + flags |= flagPrefetchBit + } + result[1] = flags + + timestamp := uint64(item.Timestamp.Unix()) + binary.BigEndian.PutUint64(result[2:10], timestamp) + + // Copy DNS wire bytes after header + copy(result[headerSize:], dnsWire) + + return result, nil +} + +// decodeCacheAnswer decodes a binary-encoded cacheAnswer. +// Returns an error if the format is invalid or unsupported. +func decodeCacheAnswer(b []byte) (*cacheAnswer, error) { + if len(b) < headerSize { + return nil, fmt.Errorf("binary data too short: %d bytes", len(b)) + } + + // Check version + version := b[0] + if version != binaryFormatVersion { + return nil, fmt.Errorf("unsupported binary format version: %d", version) + } + + // Parse flags + flags := b[1] + prefetchEligible := (flags & flagPrefetchBit) != 0 + + // Parse timestamp + timestamp := int64(binary.BigEndian.Uint64(b[2:10])) + + // Unpack DNS message + msg := new(dns.Msg) + if err := msg.Unpack(b[headerSize:]); err != nil { + return nil, fmt.Errorf("failed to unpack DNS message: %w", err) + } + + return &cacheAnswer{ + Timestamp: time.Unix(timestamp, 0), + PrefetchEligible: prefetchEligible, + Msg: msg, + }, nil +} + func NewRedisBackend(opt RedisBackendOptions) *redisBackend { b := &redisBackend{ client: redis.NewClient(&opt.RedisOptions), @@ -58,14 +156,15 @@ func (b *redisBackend) Store(query *dns.Msg, item *cacheAnswer) { } func (b *redisBackend) storeSync(query *dns.Msg, item *cacheAnswer, ttl time.Duration) { - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() key := b.keyFromQuery(query) - value, err := json.Marshal(item) + value, err := encodeCacheAnswer(item) if err != nil { - Log.Error("failed to marshal cache record", "error", err) + Log.Error("failed to encode cache record", "error", err) return } + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() if err := b.client.Set(ctx, key, value, ttl).Err(); err != nil { Log.Error("failed to write to redis", "error", err) } @@ -89,7 +188,9 @@ func (b *redisBackend) Lookup(q *dns.Msg) (*dns.Msg, bool, bool) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() key := b.keyFromQuery(q) - value, err := b.client.Get(ctx, key).Result() + + // Fetch raw bytes to avoid string conversion overhead + valueBytes, err := b.client.Get(ctx, key).Bytes() if err != nil { if errors.Is(err, redis.Nil) { // Return a cache-miss if there's no such key return nil, false, false @@ -97,10 +198,16 @@ func (b *redisBackend) Lookup(q *dns.Msg) (*dns.Msg, bool, bool) { Log.Error("failed to read from redis", "error", err) return nil, false, false } + + // Try binary decode first, with JSON fallback for backward compatibility var a *cacheAnswer - if err := json.Unmarshal([]byte(value), &a); err != nil { - Log.Error("failed to unmarshal cache record from redis", "error", err) - return nil, false, false + a, err = decodeCacheAnswer(valueBytes) + if err != nil { + // Fallback to JSON for backward compatibility with existing cached entries + if jsonErr := json.Unmarshal(valueBytes, &a); jsonErr != nil { + Log.Error("failed to decode cache record from redis", "binary_error", err, "json_error", jsonErr) + return nil, false, false + } } answer := a.Msg diff --git a/cache-redis_test.go b/cache-redis_test.go new file mode 100644 index 0000000..daf13bc --- /dev/null +++ b/cache-redis_test.go @@ -0,0 +1,296 @@ +package rdns + +import ( + "fmt" + "sync" + "testing" + "time" + + "github.com/miekg/dns" +) + +func TestEncodeDecode(t *testing.T) { + // Create a test DNS message + msg := new(dns.Msg) + msg.SetQuestion("example.com.", dns.TypeA) + msg.Response = true + msg.Rcode = dns.RcodeSuccess + + // Add an answer + rr, err := dns.NewRR("example.com. 300 IN A 192.0.2.1") + if err != nil { + t.Fatalf("failed to create RR: %v", err) + } + msg.Answer = append(msg.Answer, rr) + + // Create a cacheAnswer + now := time.Now() + original := &cacheAnswer{ + Timestamp: now, + PrefetchEligible: true, + Msg: msg, + } + + // Encode + encoded, err := encodeCacheAnswer(original) + if err != nil { + t.Fatalf("encodeCacheAnswer failed: %v", err) + } + + // Verify format + if len(encoded) < headerSize { + t.Fatalf("encoded data too short: %d bytes", len(encoded)) + } + + // Check version byte + if encoded[0] != binaryFormatVersion { + t.Errorf("version byte = %d, want %d", encoded[0], binaryFormatVersion) + } + + // Check flags byte + expectedFlags := byte(flagPrefetchBit) + if encoded[1] != expectedFlags { + t.Errorf("flags byte = %d, want %d", encoded[1], expectedFlags) + } + + // Decode + decoded, err := decodeCacheAnswer(encoded) + if err != nil { + t.Fatalf("decodeCacheAnswer failed: %v", err) + } + + // Verify fields + if decoded.Timestamp.Unix() != original.Timestamp.Unix() { + t.Errorf("timestamp = %v, want %v", decoded.Timestamp, original.Timestamp) + } + + if decoded.PrefetchEligible != original.PrefetchEligible { + t.Errorf("prefetchEligible = %v, want %v", decoded.PrefetchEligible, original.PrefetchEligible) + } + + // Verify DNS message + if len(decoded.Msg.Answer) != len(original.Msg.Answer) { + t.Errorf("answer count = %d, want %d", len(decoded.Msg.Answer), len(original.Msg.Answer)) + } + + if decoded.Msg.Question[0].Name != original.Msg.Question[0].Name { + t.Errorf("question name = %s, want %s", decoded.Msg.Question[0].Name, original.Msg.Question[0].Name) + } + + if decoded.Msg.Question[0].Qtype != original.Msg.Question[0].Qtype { + t.Errorf("question type = %d, want %d", decoded.Msg.Question[0].Qtype, original.Msg.Question[0].Qtype) + } +} + +func TestEncodeDecodeNoPrefetch(t *testing.T) { + // Create a test DNS message + msg := new(dns.Msg) + msg.SetQuestion("test.example.", dns.TypeAAAA) + msg.Response = true + + // Create a cacheAnswer with prefetch disabled + original := &cacheAnswer{ + Timestamp: time.Unix(1234567890, 0), + PrefetchEligible: false, + Msg: msg, + } + + // Encode + encoded, err := encodeCacheAnswer(original) + if err != nil { + t.Fatalf("encodeCacheAnswer failed: %v", err) + } + + // Check flags byte (should be 0) + if encoded[1] != 0 { + t.Errorf("flags byte = %d, want 0", encoded[1]) + } + + // Decode + decoded, err := decodeCacheAnswer(encoded) + if err != nil { + t.Fatalf("decodeCacheAnswer failed: %v", err) + } + + if decoded.PrefetchEligible != false { + t.Errorf("prefetchEligible = %v, want false", decoded.PrefetchEligible) + } +} + +func TestDecodeInvalidData(t *testing.T) { + tests := []struct { + name string + data []byte + }{ + {"too short", []byte{0x01, 0x00}}, + {"wrong version", []byte{0x99, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, + {"invalid DNS", []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := decodeCacheAnswer(tt.data) + if err == nil { + t.Error("expected error, got nil") + } + }) + } +} + +func TestEncodeDecodePooling(t *testing.T) { + // Test that pooling works correctly across multiple encode operations + msg := new(dns.Msg) + msg.SetQuestion("pool.test.", dns.TypeA) + msg.Response = true + + item := &cacheAnswer{ + Timestamp: time.Now(), + PrefetchEligible: true, + Msg: msg, + } + + // Encode multiple times to test pool reuse + for i := 0; i < 100; i++ { + encoded, err := encodeCacheAnswer(item) + if err != nil { + t.Fatalf("iteration %d: encodeCacheAnswer failed: %v", i, err) + } + + // Verify first byte is always version + if encoded[0] != binaryFormatVersion { + t.Errorf("iteration %d: version byte = %d, want %d", i, encoded[0], binaryFormatVersion) + } + + // Decode to verify correctness + decoded, err := decodeCacheAnswer(encoded) + if err != nil { + t.Fatalf("iteration %d: decodeCacheAnswer failed: %v", i, err) + } + + if decoded.Msg.Question[0].Name != "pool.test." { + t.Errorf("iteration %d: corrupted data", i) + } + } +} + +func TestEncodeReturnsIndependentSlice(t *testing.T) { + // Verify that encoded bytes are independent of the pool and mutations don't affect subsequent encodes + msg := new(dns.Msg) + msg.SetQuestion("independent.test.", dns.TypeA) + msg.Response = true + + item := &cacheAnswer{ + Timestamp: time.Unix(1234567890, 0), + PrefetchEligible: true, + Msg: msg, + } + + // First encode + encoded1, err := encodeCacheAnswer(item) + if err != nil { + t.Fatalf("first encode failed: %v", err) + } + + // Save a copy of the original encoded data + original := make([]byte, len(encoded1)) + copy(original, encoded1) + + // Mutate the returned slice to verify it's independent of the pool + for i := range encoded1 { + encoded1[i] = 0xFF + } + + // Verify the mutated buffer is now garbage and fails to decode + _, err = decodeCacheAnswer(encoded1) + if err == nil { + t.Error("expected decode of mutated buffer to fail, but it succeeded") + } + + // Second encode - should succeed and produce the same result as the first + encoded2, err := encodeCacheAnswer(item) + if err != nil { + t.Fatalf("second encode failed: %v", err) + } + + // Verify second encode matches the original (not corrupted by mutation) + if len(encoded2) != len(original) { + t.Fatalf("length mismatch: got %d, want %d", len(encoded2), len(original)) + } + + for i := range original { + if encoded2[i] != original[i] { + t.Errorf("byte %d: got %02x, want %02x (mutation leaked into pool)", i, encoded2[i], original[i]) + } + } + + // Verify we can still decode the second result + decoded, err := decodeCacheAnswer(encoded2) + if err != nil { + t.Fatalf("decode after mutation failed: %v", err) + } + + if decoded.Msg.Question[0].Name != "independent.test." { + t.Errorf("decoded name = %s, want independent.test.", decoded.Msg.Question[0].Name) + } +} + +func TestEncodeConcurrent(t *testing.T) { + // Test concurrent encoding to catch pool-related race conditions + // Each goroutine gets its own dns.Msg to avoid racing on shared message internals + const numGoroutines = 50 + const numIterations = 100 + + errs := make(chan error, numGoroutines) + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for g := 0; g < numGoroutines; g++ { + go func(gid int) { + defer wg.Done() + + msg := new(dns.Msg) + msg.SetQuestion("concurrent.test.", dns.TypeA) + msg.Response = true + rr, err := dns.NewRR("concurrent.test. 300 IN A 192.0.2.1") + if err != nil { + errs <- err + return + } + msg.Answer = append(msg.Answer, rr) + + item := &cacheAnswer{ + Timestamp: time.Now(), + PrefetchEligible: true, + Msg: msg, + } + + for i := 0; i < numIterations; i++ { + encoded, err := encodeCacheAnswer(item) + if err != nil { + errs <- err + return + } + if encoded[0] != binaryFormatVersion { + errs <- fmt.Errorf("goroutine %d iteration %d: invalid version byte %d", gid, i, encoded[0]) + return + } + decoded, err := decodeCacheAnswer(encoded) + if err != nil { + errs <- err + return + } + if decoded.Msg.Question[0].Name != "concurrent.test." { + errs <- fmt.Errorf("goroutine %d iteration %d: corrupted data, got name %s", gid, i, decoded.Msg.Question[0].Name) + return + } + } + }(g) + } + + wg.Wait() + close(errs) + + for err := range errs { + t.Fatalf("concurrent encode/decode failed: %v", err) + } +}