mirror of
https://github.com/folbricht/routedns.git
synced 2025-12-31 22:50:08 -06:00
228 lines
4.5 KiB
Go
228 lines
4.5 KiB
Go
package rdns
|
|
|
|
import (
|
|
"encoding/json"
|
|
"io"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
type lruCache struct {
|
|
maxItems int
|
|
items map[lruKey]*cacheItem
|
|
head, tail *cacheItem
|
|
}
|
|
|
|
type cacheItem struct {
|
|
Key lruKey
|
|
Answer *cacheAnswer
|
|
prev, next *cacheItem
|
|
}
|
|
|
|
type lruKey struct {
|
|
Question dns.Question
|
|
Net string
|
|
Do bool
|
|
}
|
|
|
|
type cacheAnswer struct {
|
|
Timestamp time.Time // Time the record was cached. Needed to adjust TTL
|
|
Expiry time.Time // Time the record expires and should be removed
|
|
PrefetchEligible bool // The cache can prefetch this record
|
|
Msg *dns.Msg
|
|
}
|
|
|
|
// MarshalJSON encodes the answer as a JSON object. Every field is encoded
// normally except Msg, which is first packed into DNS wire format (a []byte,
// hence a base64 string in JSON) so that UnmarshalJSON can restore it.
func (c cacheAnswer) MarshalJSON() ([]byte, error) {
	wire, err := c.Msg.Pack()
	if err != nil {
		return nil, err
	}

	// plain has cacheAnswer's fields but none of its methods, so encoding the
	// embedded value does not recurse back into MarshalJSON.
	type plain cacheAnswer
	return json.Marshal(struct {
		plain
		Msg []byte // shallower than plain.Msg, so it wins in the JSON output
	}{plain(c), wire})
}
|
|
|
|
// UnmarshalJSON is the inverse of MarshalJSON. It decodes the plain fields
// directly into c, then unpacks the wire-format Msg into a newly allocated
// message. It returns an error if the JSON is malformed or Msg cannot be
// unpacked.
func (c *cacheAnswer) UnmarshalJSON(data []byte) error {
	// plain carries no methods, so decoding into it does not recurse.
	type plain cacheAnswer
	var wire struct {
		*plain
		Msg []byte // shadows plain.Msg; receives the packed message
	}
	wire.plain = (*plain)(c)
	if err := json.Unmarshal(data, &wire); err != nil {
		return err
	}
	msg := new(dns.Msg)
	c.Msg = msg
	return msg.Unpack(wire.Msg)
}
|
|
|
|
func newLRUCache(capacity int) *lruCache {
|
|
head := new(cacheItem)
|
|
tail := new(cacheItem)
|
|
head.next = tail
|
|
tail.prev = head
|
|
|
|
return &lruCache{
|
|
maxItems: capacity,
|
|
items: make(map[lruKey]*cacheItem),
|
|
head: head,
|
|
tail: tail,
|
|
}
|
|
}
|
|
|
|
func (c *lruCache) add(query *dns.Msg, answer *cacheAnswer) {
|
|
key := lruKeyFromQuery(query)
|
|
c.addKey(key, answer)
|
|
}
|
|
|
|
func (c *lruCache) addKey(key lruKey, answer *cacheAnswer) {
|
|
item := c.touch(key)
|
|
if item != nil {
|
|
// Update the item, it's already at the top of the list
|
|
// so we can just change the value
|
|
item.Answer = answer
|
|
return
|
|
}
|
|
// Add new item to the top of the linked list
|
|
item = &cacheItem{
|
|
Key: key,
|
|
Answer: answer,
|
|
next: c.head.next,
|
|
prev: c.head,
|
|
}
|
|
c.head.next.prev = item
|
|
c.head.next = item
|
|
c.items[key] = item
|
|
c.resize()
|
|
}
|
|
|
|
// Loads a cache item and puts it to the top of the queue (most recent).
|
|
func (c *lruCache) touch(key lruKey) *cacheItem {
|
|
item := c.items[key]
|
|
if item == nil {
|
|
return nil
|
|
}
|
|
// move the item to the top of the linked list
|
|
item.prev.next = item.next
|
|
item.next.prev = item.prev
|
|
item.next = c.head.next
|
|
item.prev = c.head
|
|
c.head.next.prev = item
|
|
c.head.next = item
|
|
return item
|
|
}
|
|
|
|
func (c *lruCache) delete(q *dns.Msg) {
|
|
key := lruKeyFromQuery(q)
|
|
item := c.items[key]
|
|
if item == nil {
|
|
return
|
|
}
|
|
item.prev.next = item.next
|
|
item.next.prev = item.prev
|
|
delete(c.items, key)
|
|
}
|
|
|
|
func (c *lruCache) get(query *dns.Msg) *cacheAnswer {
|
|
key := lruKeyFromQuery(query)
|
|
item := c.touch(key)
|
|
if item != nil {
|
|
return item.Answer
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Shrink the cache down to the maximum number of items.
|
|
func (c *lruCache) resize() {
|
|
if c.maxItems <= 0 { // no size limit
|
|
return
|
|
}
|
|
drop := len(c.items) - c.maxItems
|
|
for i := 0; i < drop; i++ {
|
|
item := c.tail.prev
|
|
item.prev.next = c.tail
|
|
c.tail.prev = item.prev
|
|
delete(c.items, item.Key)
|
|
}
|
|
}
|
|
|
|
// Clear the cache.
|
|
func (c *lruCache) reset() {
|
|
head := new(cacheItem)
|
|
tail := new(cacheItem)
|
|
head.next = tail
|
|
tail.prev = head
|
|
|
|
c.head = head
|
|
c.tail = tail
|
|
c.items = make(map[lruKey]*cacheItem)
|
|
}
|
|
|
|
// Iterate over the cached answers and call the provided function. If it
|
|
// returns true, the item is deleted from the cache.
|
|
func (c *lruCache) deleteFunc(f func(*cacheAnswer) bool) {
|
|
item := c.head.next
|
|
for item != c.tail {
|
|
if f(item.Answer) {
|
|
item.prev.next = item.next
|
|
item.next.prev = item.prev
|
|
delete(c.items, item.Key)
|
|
}
|
|
item = item.next
|
|
}
|
|
}
|
|
|
|
func (c *lruCache) size() int {
|
|
return len(c.items)
|
|
}
|
|
|
|
// serialize writes every entry to w as a stream of JSON values (one per
// Encode call), starting with the least recently used. deserialize re-adds
// entries in stream order, each at the front, so reloading the stream
// rebuilds the same recency order. The first encoding error stops the write
// and is returned.
func (c *lruCache) serialize(w io.Writer) error {
	enc := json.NewEncoder(w)
	node := c.tail.prev
	for node != c.head {
		if err := enc.Encode(node); err != nil {
			return err
		}
		node = node.prev
	}
	return nil
}
|
|
|
|
func (c *lruCache) deserialize(r io.Reader) error {
|
|
dec := json.NewDecoder(r)
|
|
for dec.More() {
|
|
item := new(cacheItem)
|
|
if err := dec.Decode(item); err != nil {
|
|
return err
|
|
}
|
|
// Skip bad (or incompatible) records
|
|
if item.Key.Question.Name == "" || item.Answer == nil {
|
|
continue
|
|
}
|
|
c.addKey(item.Key, item.Answer)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func lruKeyFromQuery(q *dns.Msg) lruKey {
|
|
key := lruKey{Question: q.Question[0]}
|
|
|
|
edns0 := q.IsEdns0()
|
|
if edns0 != nil {
|
|
key.Do = edns0.Do()
|
|
// See if we have a subnet option
|
|
for _, opt := range edns0.Option {
|
|
if subnet, ok := opt.(*dns.EDNS0_SUBNET); ok {
|
|
key.Net = subnet.Address.String()
|
|
}
|
|
}
|
|
}
|
|
return key
|
|
}
|