mirror of
https://github.com/folbricht/routedns.git
synced 2026-01-01 07:01:55 -06:00
* Migrate from logrus to slog * fully removing logrus * should be working now * Update pipeline.go Co-authored-by: Frank Olbricht <frank.olbricht@gmail.com> * Update response-blocklist-name.go Co-authored-by: Frank Olbricht <frank.olbricht@gmail.com> * added null logger * Update pipeline.go --------- Co-authored-by: Frank Olbricht <frank.olbricht@gmail.com>
140 lines
3.6 KiB
Go
140 lines
3.6 KiB
Go
package rdns
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
type redisBackend struct {
|
|
client *redis.Client
|
|
opt RedisBackendOptions
|
|
}
|
|
|
|
type RedisBackendOptions struct {
|
|
RedisOptions redis.Options
|
|
KeyPrefix string
|
|
}
|
|
|
|
// Compile-time check that *redisBackend implements CacheBackend.
var _ CacheBackend = (*redisBackend)(nil)
|
|
|
|
func NewRedisBackend(opt RedisBackendOptions) *redisBackend {
|
|
b := &redisBackend{
|
|
client: redis.NewClient(&opt.RedisOptions),
|
|
opt: opt,
|
|
}
|
|
return b
|
|
}
|
|
|
|
func (b *redisBackend) Store(query *dns.Msg, item *cacheAnswer) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
key := b.keyFromQuery(query)
|
|
value, err := json.Marshal(item)
|
|
if err != nil {
|
|
Log.Error("failed to marshal cache record", "error", err)
|
|
return
|
|
}
|
|
if err := b.client.Set(ctx, key, value, time.Until(item.Expiry)).Err(); err != nil {
|
|
Log.Error("failed to write to redis", "error", err)
|
|
}
|
|
}
|
|
|
|
func (b *redisBackend) Lookup(q *dns.Msg) (*dns.Msg, bool, bool) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
key := b.keyFromQuery(q)
|
|
value, err := b.client.Get(ctx, key).Result()
|
|
if err != nil {
|
|
if errors.Is(err, redis.Nil) { // Return a cache-miss if there's no such key
|
|
return nil, false, false
|
|
}
|
|
Log.Error("failed to read from redis", "error", err)
|
|
return nil, false, false
|
|
}
|
|
var a *cacheAnswer
|
|
if err := json.Unmarshal([]byte(value), &a); err != nil {
|
|
Log.Error("failed to unmarshal cache record from redis", "error", err)
|
|
return nil, false, false
|
|
}
|
|
|
|
answer := a.Msg
|
|
prefetchEligible := a.PrefetchEligible
|
|
answer.Id = q.Id
|
|
|
|
// Calculate the time the record spent in the cache. We need to
|
|
// subtract that from the TTL of each answer record.
|
|
age := uint32(time.Since(a.Timestamp).Seconds())
|
|
|
|
// Go through all the answers, NS, and Extra and adjust the TTL (subtract the time
|
|
// it's spent in the cache). If the record is too old, evict it from the cache
|
|
// and return a cache-miss. OPT records have a TTL of 0 and are ignored.
|
|
for _, rr := range [][]dns.RR{answer.Answer, answer.Ns, answer.Extra} {
|
|
for _, a := range rr {
|
|
if _, ok := a.(*dns.OPT); ok {
|
|
continue
|
|
}
|
|
h := a.Header()
|
|
if age >= h.Ttl {
|
|
return nil, false, false
|
|
}
|
|
h.Ttl -= age
|
|
}
|
|
}
|
|
|
|
return answer, prefetchEligible, true
|
|
}
|
|
|
|
func (b *redisBackend) Flush() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
if _, err := b.client.Del(ctx, b.opt.KeyPrefix+"*").Result(); err != nil {
|
|
Log.Error("failed to delete keys in redis", "error", err)
|
|
}
|
|
}
|
|
|
|
func (b *redisBackend) Size() int {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
size, err := b.client.DBSize(ctx).Result()
|
|
if err != nil {
|
|
Log.Error("failed to run dbsize command on redis", "error", err)
|
|
}
|
|
return int(size)
|
|
}
|
|
|
|
func (b *redisBackend) Close() error {
|
|
return b.client.Close()
|
|
}
|
|
|
|
// Build a key string to be used in redis.
|
|
func (b *redisBackend) keyFromQuery(q *dns.Msg) string {
|
|
var key strings.Builder
|
|
key.WriteString(b.opt.KeyPrefix)
|
|
key.WriteString(q.Question[0].Name)
|
|
key.WriteByte(':')
|
|
key.WriteString(dns.Class(q.Question[0].Qclass).String())
|
|
key.WriteByte(':')
|
|
key.WriteString(dns.Type(q.Question[0].Qtype).String())
|
|
key.WriteByte(':')
|
|
|
|
edns0 := q.IsEdns0()
|
|
if edns0 != nil {
|
|
key.WriteString(fmt.Sprintf("%t", edns0.Do()))
|
|
key.WriteByte(':')
|
|
// See if we have a subnet option
|
|
for _, opt := range edns0.Option {
|
|
if subnet, ok := opt.(*dns.EDNS0_SUBNET); ok {
|
|
key.WriteString(subnet.Address.String())
|
|
}
|
|
}
|
|
}
|
|
return key.String()
|
|
}
|