mirror of
https://github.com/folbricht/routedns.git
synced 2025-12-31 06:29:59 -06:00
216 lines
4.7 KiB
Go
216 lines
4.7 KiB
Go
package rdns
|
|
|
|
import (
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
|
|
"log/slog"
|
|
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
// memoryBackend is a CacheBackend that holds cached answers in an in-process
// lruCache. It can optionally be loaded from and persisted to a file, as
// configured through MemoryBackendOptions.
type memoryBackend struct {
	// Cached entries. Every access in this file happens while holding mu.
	lru *lruCache

	// Guards lru.
	mu sync.Mutex

	// Configuration passed to NewMemoryBackend.
	opt MemoryBackendOptions
}
|
|
|
|
// MemoryBackendOptions configures the in-memory cache backend created by
// NewMemoryBackend.
type MemoryBackendOptions struct {
	// Total capacity of the cache, default unlimited. The value is handed to
	// newLRUCache; presumably it counts entries (TODO confirm against lruCache).
	Capacity int

	// How often to run garbage collection, which removes entries whose expiry
	// time has passed. Zero selects the default of 1 minute.
	GCPeriod time.Duration

	// If set, the cache is loaded from this file on startup and written back
	// to it on Close.
	Filename string

	// If set (together with Filename), the cache file is also written at this
	// interval. If unset, the file is only written on Close.
	SaveInterval time.Duration
}
|
|
|
|
// Compile-time assertion that *memoryBackend implements CacheBackend; the
// build fails here if a required method is missing or has the wrong signature.
var (
	_ CacheBackend = (*memoryBackend)(nil)
)
|
|
|
|
func NewMemoryBackend(opt MemoryBackendOptions) *memoryBackend {
|
|
if opt.GCPeriod == 0 {
|
|
opt.GCPeriod = time.Minute
|
|
}
|
|
b := &memoryBackend{
|
|
lru: newLRUCache(opt.Capacity),
|
|
opt: opt,
|
|
}
|
|
if opt.Filename != "" {
|
|
b.loadFromFile(opt.Filename)
|
|
}
|
|
go b.startGC(opt.GCPeriod)
|
|
go b.intervalSave()
|
|
return b
|
|
}
|
|
|
|
func (b *memoryBackend) Store(query *dns.Msg, item *cacheAnswer) {
|
|
b.mu.Lock()
|
|
b.lru.add(query, item)
|
|
b.mu.Unlock()
|
|
}
|
|
|
|
func (b *memoryBackend) Lookup(q *dns.Msg) (*dns.Msg, bool, bool) {
|
|
var answer *dns.Msg
|
|
var timestamp time.Time
|
|
var prefetchEligible bool
|
|
var expiry time.Time
|
|
b.mu.Lock()
|
|
if a := b.lru.get(q); a != nil {
|
|
answer = a.Msg.Copy()
|
|
timestamp = a.Timestamp
|
|
prefetchEligible = a.PrefetchEligible
|
|
expiry = a.Expiry
|
|
}
|
|
b.mu.Unlock()
|
|
|
|
// Return a cache-miss if there's no answer record in the map
|
|
if answer == nil {
|
|
return nil, false, false
|
|
}
|
|
|
|
// Check if item has expired from the cache
|
|
if time.Now().After(expiry) {
|
|
b.Evict(q)
|
|
return nil, false, false
|
|
}
|
|
|
|
// Make a copy of the response before returning it. Some later
|
|
// elements might make changes.
|
|
answer = answer.Copy()
|
|
answer.Id = q.Id
|
|
answer.Question = q.Question // restore the case used in the question
|
|
|
|
// Calculate the time the record spent in the cache. We need to
|
|
// subtract that from the TTL of each answer record.
|
|
age := uint32(time.Since(timestamp).Seconds())
|
|
|
|
// Go through all the answers, NS, and Extra and adjust the TTL (subtract the time
|
|
// it's spent in the cache). If the record is too old, evict it from the cache
|
|
// and return a cache-miss. OPT records have a TTL of 0 and are ignored.
|
|
for _, rr := range [][]dns.RR{answer.Answer, answer.Ns, answer.Extra} {
|
|
for _, a := range rr {
|
|
if _, ok := a.(*dns.OPT); ok {
|
|
continue
|
|
}
|
|
h := a.Header()
|
|
if age >= h.Ttl {
|
|
b.Evict(q)
|
|
return nil, false, false
|
|
}
|
|
h.Ttl -= age
|
|
}
|
|
}
|
|
|
|
return answer, prefetchEligible, true
|
|
}
|
|
|
|
func (b *memoryBackend) Evict(queries ...*dns.Msg) {
|
|
b.mu.Lock()
|
|
for _, query := range queries {
|
|
b.lru.delete(query)
|
|
}
|
|
b.mu.Unlock()
|
|
}
|
|
|
|
func (b *memoryBackend) Flush() {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
b.lru.reset()
|
|
}
|
|
|
|
// Runs every period time and evicts all items from the cache that are
|
|
// older than max, regardless of TTL. Note that the cache can hold old
|
|
// records that are no longer valid. These will only be evicted once
|
|
// a new query for them is made (and TTL is too old) or when they are
|
|
// older than max.
|
|
func (b *memoryBackend) startGC(period time.Duration) {
|
|
for {
|
|
time.Sleep(period)
|
|
now := time.Now()
|
|
var total, removed int
|
|
b.mu.Lock()
|
|
b.lru.deleteFunc(func(a *cacheAnswer) bool {
|
|
if now.After(a.Expiry) {
|
|
removed++
|
|
return true
|
|
}
|
|
return false
|
|
})
|
|
total = b.lru.size()
|
|
b.mu.Unlock()
|
|
|
|
Log.Debug("cache garbage collection",
|
|
slog.Group("details",
|
|
slog.Int("total", total),
|
|
slog.Int("removed", removed),
|
|
),
|
|
)
|
|
}
|
|
}
|
|
|
|
func (b *memoryBackend) Size() int {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
return b.lru.size()
|
|
}
|
|
|
|
func (b *memoryBackend) Close() error {
|
|
if b.opt.Filename != "" {
|
|
return b.writeToFile(b.opt.Filename)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (b *memoryBackend) writeToFile(filename string) error {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
log := Log.With("filename", filename)
|
|
log.Info("writing cache file")
|
|
f, err := os.Create(filename)
|
|
if err != nil {
|
|
log.Warn("failed to create cache file", "error", err)
|
|
return err
|
|
}
|
|
defer f.Close()
|
|
|
|
if err := b.lru.serialize(f); err != nil {
|
|
log.Warn("failed to persist cache to disk", "error", err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (b *memoryBackend) loadFromFile(filename string) error {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
log := Log.With("filename", filename)
|
|
log.Info("reading cache file")
|
|
f, err := os.Open(filename)
|
|
if err != nil {
|
|
log.Warn("failed to open cache file", "error", err)
|
|
return err
|
|
}
|
|
defer f.Close()
|
|
|
|
if err := b.lru.deserialize(f); err != nil {
|
|
log.Warn("failed to read cache from disk", "error", err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (b *memoryBackend) intervalSave() {
|
|
if b.opt.Filename == "" || b.opt.SaveInterval == 0 {
|
|
return
|
|
}
|
|
for {
|
|
time.Sleep(b.opt.SaveInterval)
|
|
b.writeToFile(b.opt.Filename)
|
|
}
|
|
}
|