diff --git a/cache-memory.go b/cache-memory.go index fc6a587..800e75d 100644 --- a/cache-memory.go +++ b/cache-memory.go @@ -1,6 +1,7 @@ package rdns import ( + "os" "sync" "time" @@ -9,19 +10,40 @@ import ( ) type memoryBackend struct { - lru *lruCache - mu sync.Mutex - metrics *CacheMetrics + lru *lruCache + mu sync.Mutex + opt MemoryBackendOptions } -var _ cacheBackend = (*memoryBackend)(nil) +type MemoryBackendOptions struct { + // Total capacity of the cache, default unlimited + Capacity int -func newMemoryBackend(capacity int, gcperiod time.Duration, metrics *CacheMetrics) *memoryBackend { - b := &memoryBackend{ - lru: newLRUCache(capacity), - metrics: metrics, + // How often to run garbage collection, default 1 minute + GCPeriod time.Duration + + // Load the cache from file on startup and write it on close + Filename string + + // Write the file in an interval. Only write on shutdown if not set + SaveInterval time.Duration +} + +var _ CacheBackend = (*memoryBackend)(nil) + +func NewMemoryBackend(opt MemoryBackendOptions) *memoryBackend { + if opt.GCPeriod == 0 { + opt.GCPeriod = time.Minute } - go b.startGC(gcperiod) + b := &memoryBackend{ + lru: newLRUCache(opt.Capacity), + opt: opt, + } + if opt.Filename != "" { + b.loadFromFile(opt.Filename) + } + go b.startGC(opt.GCPeriod) + go b.intervalSave() return b } @@ -38,10 +60,10 @@ func (b *memoryBackend) Lookup(q *dns.Msg) (*dns.Msg, bool, bool) { var expiry time.Time b.mu.Lock() if a := b.lru.get(q); a != nil { - answer = a.Copy() - timestamp = a.timestamp - prefetchEligible = a.prefetchEligible - expiry = a.expiry + answer = a.Msg.Copy() + timestamp = a.Timestamp + prefetchEligible = a.PrefetchEligible + expiry = a.Expiry } b.mu.Unlock() @@ -111,7 +133,7 @@ func (b *memoryBackend) startGC(period time.Duration) { var total, removed int b.mu.Lock() b.lru.deleteFunc(func(a *cacheAnswer) bool { - if now.After(a.expiry) { + if now.After(a.Expiry) { removed++ return true } @@ -120,7 +142,67 @@ func (b *memoryBackend) startGC(period time.Duration) { total = b.lru.size() b.mu.Unlock() - b.metrics.entries.Set(int64(total)) Log.WithFields(logrus.Fields{"total": total, "removed": removed}).Trace("cache garbage collection") } } + +func (b *memoryBackend) Size() int { + b.mu.Lock() + defer b.mu.Unlock() + return b.lru.size() +} + +func (b *memoryBackend) Close() error { + if b.opt.Filename != "" { + return b.writeToFile(b.opt.Filename) + } + return nil +} + +func (b *memoryBackend) writeToFile(filename string) error { + b.mu.Lock() + defer b.mu.Unlock() + log := Log.WithField("filename", filename) + log.Info("writing cache file") + f, err := os.Create(filename) + if err != nil { + log.WithError(err).Warn("failed to create cache file") + return err + } + defer f.Close() + + if err := b.lru.serialize(f); err != nil { + log.WithError(err).Warn("failed to persist cache to disk") + return err + } + return nil +} + +func (b *memoryBackend) loadFromFile(filename string) error { + b.mu.Lock() + defer b.mu.Unlock() + log := Log.WithField("filename", filename) + log.Info("reading cache file") + f, err := os.Open(filename) + if err != nil { + log.WithError(err).Warn("failed to open cache file") + return err + } + defer f.Close() + + if err := b.lru.deserialize(f); err != nil { + log.WithError(err).Warn("failed to read cache from disk") + return err + } + return nil +} + +func (b *memoryBackend) intervalSave() { + if b.opt.Filename == "" || b.opt.SaveInterval == 0 { + return + } + for { + time.Sleep(b.opt.SaveInterval) + b.writeToFile(b.opt.Filename) + } +} diff --git a/cache.go b/cache.go index 6687d1b..197c56a 100644 --- a/cache.go +++ b/cache.go @@ -18,7 +18,7 @@ type Cache struct { id string resolver Resolver metrics *CacheMetrics - backend cacheBackend + backend CacheBackend } type CacheMetrics struct { @@ -34,10 +34,14 @@ var _ Resolver = &Cache{} type CacheOptions struct { // Time period the cache garbage collection runs. Defaults to one minute if set to 0. + // + // Deprecated: Pass a configured cache backend instead. GCPeriod time.Duration // Max number of responses to keep in the cache. Defaults to 0 which means no limit. If // the limit is reached, the least-recently used entry is removed from the cache. + // + // Deprecated: Pass a configured cache backend instead. Capacity int // TTL to use for negative responses that do not have an SOA record, default 60 @@ -67,19 +71,24 @@ type CacheOptions struct { // Only records with at least PrefetchEligible seconds TTL are eligible to be prefetched. PrefetchEligible uint32 + + // Cache backend used to store records. + Backend CacheBackend } -type cacheBackend interface { +type CacheBackend interface { Store(query *dns.Msg, item *cacheAnswer) // Lookup a cached response Lookup(q *dns.Msg) (answer *dns.Msg, prefetchEligible bool, ok bool) - // Remove one or more cached responses - Evict(queries ...*dns.Msg) + // Return the number of items in the cache + Size() int // Flush all records in the store Flush() + + Close() error } // NewCache returns a new instance of a Cache resolver. @@ -94,13 +103,25 @@ func NewCache(id string, resolver Resolver, opt CacheOptions) *Cache { entries: getVarInt("cache", id, "entries"), }, } - if c.GCPeriod == 0 { - c.GCPeriod = time.Minute - } if c.NegativeTTL == 0 { c.NegativeTTL = 60 } - c.backend = newMemoryBackend(opt.Capacity, c.GCPeriod, c.metrics) + if opt.Backend == nil { + opt.Backend = NewMemoryBackend(MemoryBackendOptions{ + Capacity: opt.Capacity, + GCPeriod: opt.GCPeriod, + }) + } + c.backend = opt.Backend + + // Regularly query the cache size and emit metrics + go func() { + for { + time.Sleep(time.Minute) + total := c.backend.Size() + c.metrics.entries.Set(int64(total)) + } + }() return c } @@ -227,7 +248,7 @@ func (r *Cache) storeInCache(query, answer *dns.Msg) { now := time.Now() // Prepare an item for the cache, without expiry for now - item := &cacheAnswer{Msg: answer, timestamp: now} + item := &cacheAnswer{Msg: answer, Timestamp: now} // Find the lowest TTL in the response, this determines the expiry for the whole answer in the cache. min, ok := minTTL(answer) @@ -236,17 +257,17 @@ func (r *Cache) storeInCache(query, answer *dns.Msg) { switch answer.Rcode { case dns.RcodeSuccess, dns.RcodeNameError, dns.RcodeRefused, dns.RcodeNotImplemented, dns.RcodeFormatError: if ok { - item.expiry = now.Add(time.Duration(min) * time.Second) - item.prefetchEligible = min > r.CacheOptions.PrefetchEligible + item.Expiry = now.Add(time.Duration(min) * time.Second) + item.PrefetchEligible = min > r.CacheOptions.PrefetchEligible } else { - item.expiry = now.Add(time.Duration(r.NegativeTTL) * time.Second) + item.Expiry = now.Add(time.Duration(r.NegativeTTL) * time.Second) } case dns.RcodeServerFailure: // According to RFC2308, a SERVFAIL response must not be cached for longer than 5 minutes. if r.NegativeTTL < 300 { - item.expiry = now.Add(time.Duration(r.NegativeTTL) * time.Second) + item.Expiry = now.Add(time.Duration(r.NegativeTTL) * time.Second) } else { - item.expiry = now.Add(300 * time.Second) + item.Expiry = now.Add(300 * time.Second) } default: return @@ -255,8 +276,8 @@ func (r *Cache) storeInCache(query, answer *dns.Msg) { // Set the RCODE-based limit if one was configured if rcodeLimit, ok := r.CacheOptions.CacheRcodeMaxTTL[answer.Rcode]; ok { limit := now.Add(time.Duration(rcodeLimit) * time.Second) - if item.expiry.After(limit) { - item.expiry = limit + if item.Expiry.After(limit) { + item.Expiry = limit } } diff --git a/cmd/routedns/config.go b/cmd/routedns/config.go index 54fe8c6..1b33974 100644 --- a/cmd/routedns/config.go +++ b/cmd/routedns/config.go @@ -58,11 +58,19 @@ type doh struct { Method string } +// Cache backend options +type cacheBackend struct { + Type string // Cache backend type.Defaults to "memory" + Size int // Max number of items to keep in the cache. Default 0 == unlimited. Deprecated, use backend + GCPeriod int `toml:"gc-period"` // Time-period (seconds) used to expire cached items + Filename string // File to load/store cache content, optional, for "memory" type cache + SaveInterval int `toml:"save-interval"` // Seconds to write the cache to file +} + type group struct { Resolvers []string Type string Replace []rdns.ReplaceOperation // only used by "replace" type - GCPeriod int `toml:"gc-period"` // Time-period (seconds) used to expire cached items in the "cache" type ECSOp string `toml:"ecs-op"` // ECS modifier operation, "add", "delete", "privacy" ECSAddress net.IP `toml:"ecs-address"` // ECS address. If empty for "add", uses the client IP. Ignored for "privacy" and "delete" ECSPrefix4 uint8 `toml:"ecs-prefix4"` // ECS IPv4 address prefix, 0-32. Used for "add" and "privacy" @@ -79,7 +87,9 @@ type group struct { ServfailError bool `toml:"servfail-error"` // If true, SERVFAIL responses are considered errors and cause failover etc. // Cache options - CacheSize int `toml:"cache-size"` // Max number of items to keep in the cache. Default 0 == unlimited + Backend *cacheBackend + GCPeriod int `toml:"gc-period"` // Time-period (seconds) used to expire cached items in the "cache" type. Deprecated, use backend + CacheSize int `toml:"cache-size"` // Max number of items to keep in the cache. Default 0 == unlimited. Deprecated, use backend CacheNegativeTTL uint32 `toml:"cache-negative-ttl"` // TTL to apply to negative responses, default 60. CacheAnswerShuffle string `toml:"cache-answer-shuffle"` // Algorithm to use for modifying the response order of cached items CacheHardenBelowNXDOMAIN bool `toml:"cache-harden-below-nxdomain"` // Return NXDOMAIN if an NXDOMAIN is cached for a parent domain diff --git a/cmd/routedns/example-config/block-split-cache.toml b/cmd/routedns/example-config/block-split-cache.toml index 16b0d61..c3af526 100644 --- a/cmd/routedns/example-config/block-split-cache.toml +++ b/cmd/routedns/example-config/block-split-cache.toml @@ -11,6 +11,7 @@ protocol = "dot" [groups.cloudflare-cached] type = "cache" resolvers = ["cloudflare-dot"] +backend = {type = "memory"} [routers.router1] routes = [ diff --git a/cmd/routedns/example-config/cache-flush.toml b/cmd/routedns/example-config/cache-flush.toml index 99be247..8dde810 100644 --- a/cmd/routedns/example-config/cache-flush.toml +++ b/cmd/routedns/example-config/cache-flush.toml @@ -9,6 +9,7 @@ resolver = "cloudflare-cached" type = "cache" resolvers = ["cloudflare-dot"] cache-flush-query = "flush.cache." # When a query for this name is received, the cache is reset. +backend = {type = "memory"} [resolvers.cloudflare-dot] address = "1.1.1.1:853" diff --git a/cmd/routedns/example-config/cache-with-prefetch.toml b/cmd/routedns/example-config/cache-with-prefetch.toml index aaa2a6d..1dc5bb9 100644 --- a/cmd/routedns/example-config/cache-with-prefetch.toml +++ b/cmd/routedns/example-config/cache-with-prefetch.toml @@ -9,6 +9,7 @@ type = "cache" resolvers = ["cloudflare-dot"] cache-prefetch-trigger = 10 # Prefetch when the TTL has fallen below this value cache-prefetch-eligible = 20 # Only prefetch records if their original TTL is above this +backend = {type = "memory", filename = "/var/tmp/cache.json"} [listeners.local-udp] address = "127.0.0.1:53" diff --git a/cmd/routedns/example-config/cache.toml b/cmd/routedns/example-config/cache.toml index c6c37f5..af4857e 100644 --- a/cmd/routedns/example-config/cache.toml +++ b/cmd/routedns/example-config/cache.toml @@ -10,6 +10,7 @@ resolvers = ["cloudflare-dot"] cache-size = 1000 # Optional, max number of responses to cache. Default unlimited cache-negative-ttl = 10 # Optional, TTL to apply to responses without a SOA cache-answer-shuffle = "round-robin" # Optional, rotate the order of cached responses +backend = {type = "memory", size = 1000, filename = "/tmp/cache.json", save-interval = 60} [listeners.local-udp] address = "127.0.0.1:53" diff --git a/cmd/routedns/example-config/request-dedup.toml b/cmd/routedns/example-config/request-dedup.toml index 1aeafb7..ea405fa 100644 --- a/cmd/routedns/example-config/request-dedup.toml +++ b/cmd/routedns/example-config/request-dedup.toml @@ -9,6 +9,7 @@ resolver = "cache" [groups.cache] type = "cache" resolvers = ["dedup"] +backend = {type = "memory"} [groups.dedup] type = "request-dedup" diff --git a/cmd/routedns/example-config/simple-dot-cache.toml b/cmd/routedns/example-config/simple-dot-cache.toml index 1850fb9..ec29ac8 100644 --- a/cmd/routedns/example-config/simple-dot-cache.toml +++ b/cmd/routedns/example-config/simple-dot-cache.toml @@ -15,6 +15,7 @@ title = "RouteDNS configuration" type = "cache" resolvers = ["cloudflare-dot"] # Anything that passes the filter is sent on to this resolver #gc-period = 60 # Number of seconds between cache cleanups. Defaults to 1min + backend = {type = "memory"} [listeners] diff --git a/cmd/routedns/example-config/split-config/cache.toml b/cmd/routedns/example-config/split-config/cache.toml index 5817346..ccbaf20 100644 --- a/cmd/routedns/example-config/split-config/cache.toml +++ b/cmd/routedns/example-config/split-config/cache.toml @@ -1,4 +1,4 @@ [groups.cloudflare-cached] type = "cache" resolvers = ["cloudflare-dot"] # Anything that passes the filter is sent on to this resolver - +backend = {type = "memory"} diff --git a/cmd/routedns/example-config/truncate-retry.toml b/cmd/routedns/example-config/truncate-retry.toml index 38ddf7b..3859fd8 100644 --- a/cmd/routedns/example-config/truncate-retry.toml +++ b/cmd/routedns/example-config/truncate-retry.toml @@ -24,6 +24,7 @@ retry-resolver = "cloudflare-tcp" [groups.cache] type = "cache" resolvers = ["retry"] +backend = {type = "memory"} [listeners.local-udp] address = "127.0.0.1:53" diff --git a/cmd/routedns/example-config/ttl-modifier-average.toml b/cmd/routedns/example-config/ttl-modifier-average.toml index 9c6785b..b1a5b03 100644 --- a/cmd/routedns/example-config/ttl-modifier-average.toml +++ b/cmd/routedns/example-config/ttl-modifier-average.toml @@ -13,6 +13,7 @@ ttl-max = 86400 [groups.cloudflare-cached] type = "cache" resolvers = ["cloudflare-updated-ttl"] +backend = {type = "memory"} [listeners.local-udp] address = "127.0.0.1:53" diff --git a/cmd/routedns/example-config/ttl-modifier.toml b/cmd/routedns/example-config/ttl-modifier.toml index 1dca9fd..5978719 100644 --- a/cmd/routedns/example-config/ttl-modifier.toml +++ b/cmd/routedns/example-config/ttl-modifier.toml @@ -13,6 +13,7 @@ ttl-max = 86400 [groups.cloudflare-cached] type = "cache" resolvers = ["cloudflare-updated-ttl"] +backend = {type = "memory"} [listeners.local-udp] address = "127.0.0.1:53" diff --git a/cmd/routedns/example-config/use-case-1.toml b/cmd/routedns/example-config/use-case-1.toml index 3f9d468..086147e 100644 --- a/cmd/routedns/example-config/use-case-1.toml +++ b/cmd/routedns/example-config/use-case-1.toml @@ -8,6 +8,7 @@ protocol = "dot" [groups.cloudflare-cached] type = "cache" resolvers = ["cloudflare-dot"] +backend = {type = "memory"} [listeners.local-udp] address = "127.0.0.1:53" diff --git a/cmd/routedns/example-config/use-case-6.toml b/cmd/routedns/example-config/use-case-6.toml index 8c0f0fe..1ca25fa 100644 --- a/cmd/routedns/example-config/use-case-6.toml +++ b/cmd/routedns/example-config/use-case-6.toml @@ -28,6 +28,7 @@ type = "cache" resolvers = ["ttl-update"] cache-size = 8192 cache-negative-ttl = 120 +backend = {type = "memory"} # Update TTL to avoid noise using values that are too low [groups.ttl-update] diff --git a/cmd/routedns/main.go b/cmd/routedns/main.go index 4db0425..1ffedbb 100644 --- a/cmd/routedns/main.go +++ b/cmd/routedns/main.go @@ -7,7 +7,9 @@ import ( "net" "net/url" "os" + "os/signal" "strconv" + "syscall" "time" syslog "github.com/RackSec/srslog" @@ -70,6 +72,9 @@ func (n Node) ID() string { return n.id } +// Functions to call on shutdown +var onClose []func() + func start(opt options, args []string) error { // Set the log level in the library package if opt.logLevel > 6 { @@ -292,7 +297,16 @@ func start(opt options, args []string) error { }(l) } - select {} + // Graceful shutdown + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) + <-sig + rdns.Log.Info("stopping") + for _, f := range onClose { + f() + } + + return nil } // Instantiate a group object based on configuration and add to the map of resolvers by ID. @@ -575,6 +589,16 @@ func instantiateGroup(id string, g group, resolvers map[string]rdns.Resolver) er PrefetchTrigger: g.PrefetchTrigger, PrefetchEligible: g.PrefetchEligible, } + if g.Backend != nil { + backend := rdns.NewMemoryBackend(rdns.MemoryBackendOptions{ + Capacity: g.Backend.Size, + GCPeriod: time.Duration(g.Backend.GCPeriod) * time.Second, + Filename: g.Backend.Filename, + SaveInterval: time.Duration(g.Backend.SaveInterval) * time.Second, + }) + onClose = append(onClose, func() { backend.Close() }) + opt.Backend = backend + } resolvers[id] = rdns.NewCache(id, gr[0], opt) case "response-blocklist-ip", "response-blocklist-cidr": // "response-blocklist-cidr" has been retired/renamed to "response-blocklist-ip" if len(gr) != 1 { diff --git a/doc/configuration.md b/doc/configuration.md index 664242a..5757ff4 100644 --- a/doc/configuration.md +++ b/doc/configuration.md @@ -291,6 +291,8 @@ Caches can be combined with a [TTL Modifier](#TTL-Modifier) to avoid too many ca It is possible to pre-define a query name that will flush the cache if received from a client. +The content of memory caches can be persisted to and loaded from disk. + #### Configuration Caches are instantiated with `type = "cache"` in the groups section of the configuration. @@ -298,7 +300,7 @@ Caches are instantiated with `type = "cache"` in the groups section of the confi Options: - `resolvers` - Array of upstream resolvers, only one is supported. -- `cache-size` - Max number of responses to cache. Defaults to 0 which means no limit. Optional +- `cache-size` - Max number of responses to cache. Defaults to 0 which means no limit. Deprecated, set limit in the backend instead. - `cache-negative-ttl` - TTL (in seconds) to apply to responses without a SOA. Default: 60. Optional - `cache-rcode-max-ttl` - Map of RCODE to max TTL (in seconds) to use for records based on the status code regardless of SOA. Response codes are given in their numerical form: 0 = NOERROR, 1 = FORMERR, 2 = SERVFAIL, 3 = NXDOMAIN, ... See [rfc2929#section-2.3](https://tools.ietf.org/html/rfc2929#section-2.3) for a more complete list. For example `{1 = 60, 3 = 60}` would set a limit on how long FORMERR or NXDOMAIN responses can be cached. - `cache-answer-shuffle` - Specifies a method for changing the order of cached A/AAAA answer records. Possible values `random` or `round-robin`. Defaults to static responses if not set. @@ -306,6 +308,18 @@ Options: - `cache-flush-query` - A query name (FQDN with trailing `.`) that if received from a client will trigger a cache flush (reset). Inactive if not set. Simple way to support flushing the cache by sending a pre-defined query name of any type. If successful, the response will be empty. The query will not be forwarded upstream by the cache. - `cache-prefetch-trigger`- If a query is received for a record with less that `cache-prefetch-trigger` TTL left, the cache will send another, independent query to upstream with the goal of automatically refreshing the record in the cache with the response. - `cache-prefetch-eligible` - Only records with at least `prefetch-eligible` seconds TTL are eligible to be prefetched. +- `backend` - Define what kind of storage is used for the cache. Contains multiple keys depending on type that can configure the behavior. Defaults to memory backend if not configued. + +Backends: + +**Memory backend** + +The memory backend will keep all cache items in memory. It can be configured to write the content of the cache to disk on shutdown. Memmory backend config has the following options: + +- `type="memory"` +- `size` - Max number of responses to cache. Defaults to 0 which means no limit. +- `filename` - File to use for persistent storage to disk. The cache will be initialized with the content from the file and it'll write the content to the same file on shutdown. Defaults to no persistence +- `save-interval` - Interval (in seconds) to save the cache to file. Optional. If not set, the file is written only on shutdown. #### Examples @@ -315,6 +329,7 @@ Simple cache without size-limit: [groups.cloudflare-cached] type = "cache" resolvers = ["cloudflare-dot"] +backend = {type = "memory"} ``` Cache that only stores up to 1000 records in memory and keeps negative responses for 1h. Responses are randomized for cached responses. @@ -323,18 +338,19 @@ Cache that only stores up to 1000 records in memory and keeps negative responses [groups.cloudflare-cached] type = "cache" resolvers = ["cloudflare-dot"] -cache-size = 1000 cache-negative-ttl = 3600 cache-answer-shuffle = "random" +backend = {type = "memory", size = 1000} ``` -Cache that is flushed if a query for `flush.cache.` is received. +Cache that is flushed if a query for `flush.cache.` is received. Also persists the cache to disk. ```toml [groups.cloudflare-cached] type = "cache" resolvers = ["cloudflare-dot"] cache-flush-query = "flush.cache." +backend = {type = "memory", filename = "/var/tmp/cache.json"} ``` Example config files: [cache.toml](../cmd/routedns/example-config/cache.toml), [block-split-cache.toml](../cmd/routedns/example-config/block-split-cache.toml), [cache-flush.toml](../cmd/routedns/example-config/cache-flush.toml), [cache-with-prefetch.toml](../cmd/routedns/example-config/cache-with-prefetch.toml), [cache-rcode.toml](../cmd/routedns/example-config/cache-rcode.toml) diff --git a/lru-cache.go b/lru-cache.go index 6c11f4a..d7d8c21 100644 --- a/lru-cache.go +++ b/lru-cache.go @@ -1,6 +1,8 @@ package rdns import ( + "encoding/json" + "io" "time" "github.com/miekg/dns" @@ -13,21 +15,52 @@ type lruCache struct { } type cacheItem struct { - key lruKey - *cacheAnswer + Key lruKey + Answer *cacheAnswer prev, next *cacheItem } type lruKey struct { - question dns.Question - net string + Question dns.Question + Net string } type cacheAnswer struct { - timestamp time.Time // Time the record was cached. Needed to adjust TTL - expiry time.Time // Time the record expires and should be removed - prefetchEligible bool // The cache can prefetch this record - *dns.Msg + Timestamp time.Time // Time the record was cached. Needed to adjust TTL + Expiry time.Time // Time the record expires and should be removed + PrefetchEligible bool // The cache can prefetch this record + Msg *dns.Msg +} + +func (c cacheAnswer) MarshalJSON() ([]byte, error) { + msg, err := c.Msg.Pack() + if err != nil { + return nil, err + } + type alias cacheAnswer + record := struct { + alias + Msg []byte + }{ + alias: alias(c), + Msg: msg, + } + return json.Marshal(record) +} + +func (c *cacheAnswer) UnmarshalJSON(data []byte) error { + type alias cacheAnswer + aux := struct { + *alias + Msg []byte + }{ + alias: (*alias)(c), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + c.Msg = new(dns.Msg) + return c.Msg.Unpack(aux.Msg) } func newLRUCache(capacity int) *lruCache { @@ -46,19 +79,23 @@ func newLRUCache(capacity int) *lruCache { func (c *lruCache) add(query *dns.Msg, answer *cacheAnswer) { key := lruKeyFromQuery(query) + c.addKey(key, answer) +} + +func (c *lruCache) addKey(key lruKey, answer *cacheAnswer) { item := c.touch(key) if item != nil { // Update the item, it's already at the top of the list // so we can just change the value - item.cacheAnswer = answer + item.Answer = answer return } // Add new item to the top of the linked list item = &cacheItem{ - key: key, - cacheAnswer: answer, - next: c.head.next, - prev: c.head, + Key: key, + Answer: answer, + next: c.head.next, + prev: c.head, } c.head.next.prev = item c.head.next = item @@ -97,7 +134,7 @@ func (c *lruCache) get(query *dns.Msg) *cacheAnswer { key := lruKeyFromQuery(query) item := c.touch(key) if item != nil { - return item.cacheAnswer + return item.Answer } return nil } @@ -112,7 +149,7 @@ func (c *lruCache) resize() { item := c.tail.prev item.prev.next = c.tail c.tail.prev = item.prev - delete(c.items, item.key) + delete(c.items, item.Key) } } @@ -133,10 +170,10 @@ func (c *lruCache) reset() { func (c *lruCache) deleteFunc(f func(*cacheAnswer) bool) { item := c.head.next for item != c.tail { - if f(item.cacheAnswer) { + if f(item.Answer) { item.prev.next = item.next item.next.prev = item.prev - delete(c.items, item.key) + delete(c.items, item.Key) } item = item.next } @@ -146,15 +183,37 @@ func (c *lruCache) size() int { return len(c.items) } +func (c *lruCache) serialize(w io.Writer) error { + enc := json.NewEncoder(w) + for item := c.tail.prev; item != c.head; item = item.prev { + if err := enc.Encode(item); err != nil { + return err + } + } + return nil +} + +func (c *lruCache) deserialize(r io.Reader) error { + dec := json.NewDecoder(r) + for dec.More() { + item := new(cacheItem) + if err := dec.Decode(item); err != nil { + return err + } + c.addKey(item.Key, item.Answer) + } + return nil +} + func lruKeyFromQuery(q *dns.Msg) lruKey { - key := lruKey{question: q.Question[0]} + key := lruKey{Question: q.Question[0]} edns0 := q.IsEdns0() if edns0 != nil { // See if we have a subnet option for _, opt := range edns0.Option { if subnet, ok := opt.(*dns.EDNS0_SUBNET); ok { - key.net = subnet.Address.String() + key.Net = subnet.Address.String() } } } diff --git a/lru-cache_test.go b/lru-cache_test.go index 14aec6c..e501a2c 100644 --- a/lru-cache_test.go +++ b/lru-cache_test.go @@ -61,7 +61,7 @@ func TestLRUAddGet(t *testing.T) { // Use an iterator to delete two more c.deleteFunc(func(a *cacheAnswer) bool { - question := a.Question[0] + question := a.Msg.Question[0] return question.Name == "test8.com." || question.Name == "test9.com." }) require.Equal(t, 2, c.size())