mirror of
https://github.com/folbricht/routedns.git
synced 2025-12-31 14:40:24 -06:00
277 lines
7.8 KiB
Go
277 lines
7.8 KiB
Go
package rdns
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
// Timeouts used by the pipeline.
const (
	// defaultQueryTimeout is how long to wait for a response from the
	// resolver when the caller did not supply a timeout of its own.
	defaultQueryTimeout = 2 * time.Second

	// idleTimeout is how long an upstream connection may go without
	// receiving anything before it is torn down.
	idleTimeout = 10 * time.Second
)
|
|
|
|
// Pipeline is a DNS client that is able to use pipelining for multiple requests over
|
|
// one connection, handle out-of-order responses and deals with disconnects
|
|
// gracefully. It opens a single connection on demand and uses it for all queries.
|
|
// It can manage UDP, TCP, DNS-over-TLS, and DNS-over-DTLS connections.
|
|
type Pipeline struct {
|
|
addr string
|
|
client DNSDialer
|
|
requests chan *request
|
|
metrics *ListenerMetrics
|
|
timeout time.Duration
|
|
}
|
|
|
|
// DNSDialer is an abstraction for a dns.Client that returns a *dns.Conn.
|
|
type DNSDialer interface {
|
|
Dial(address string) (*dns.Conn, error)
|
|
}
|
|
|
|
// NewPipeline returns an initialized (and running) DNS connection manager.
|
|
func NewPipeline(id string, addr string, client DNSDialer, timeout time.Duration) *Pipeline {
|
|
if timeout == 0 {
|
|
timeout = defaultQueryTimeout
|
|
}
|
|
c := &Pipeline{
|
|
addr: addr,
|
|
client: client,
|
|
requests: make(chan *request),
|
|
metrics: NewListenerMetrics("client", id),
|
|
timeout: timeout,
|
|
}
|
|
go c.start()
|
|
return c
|
|
}
|
|
|
|
// Resolve a single query using this connection.
|
|
func (c *Pipeline) Resolve(q *dns.Msg) (*dns.Msg, error) {
|
|
r := newRequest(q)
|
|
|
|
timeout := time.NewTimer(c.timeout)
|
|
defer timeout.Stop()
|
|
|
|
// Queue up the request or time out
|
|
select {
|
|
case c.requests <- r:
|
|
case <-timeout.C:
|
|
c.metrics.err.Add("querytimeout", 1)
|
|
return nil, QueryTimeoutError{q}
|
|
}
|
|
|
|
// Wait for the request to complete or time out
|
|
select {
|
|
case <-r.done:
|
|
case <-timeout.C:
|
|
c.metrics.err.Add("querytimeout", 1)
|
|
return nil, QueryTimeoutError{q}
|
|
}
|
|
|
|
return r.waitFor()
|
|
}
|
|
|
|
// Starts a loop that will wait for queries and open an upstream connection on-demand, writing queries
|
|
// and reading answers concurrently using the same connection. It also handles errors like idle
|
|
// close from upstream.
|
|
func (c *Pipeline) start() {
|
|
var (
|
|
wg sync.WaitGroup
|
|
inFlight inFlightQueue
|
|
)
|
|
log := Log.With("addr", c.addr)
|
|
for req := range c.requests { // Lazy connection. Only open a real connection if there's a request
|
|
done := make(chan struct{})
|
|
log.Debug("opening connection")
|
|
conn, err := c.client.Dial(c.addr)
|
|
if err != nil {
|
|
c.metrics.err.Add("open", 1)
|
|
log.Warn("failed to open connection", "error", err)
|
|
req.markDone(nil, err)
|
|
continue
|
|
}
|
|
wg.Add(2)
|
|
|
|
go func(r *request) { c.requests <- r }(req) // re-queue the request that triggered the upstream connection
|
|
|
|
go func() { // writer
|
|
for {
|
|
select {
|
|
case req := <-c.requests:
|
|
query := inFlight.add(req)
|
|
log.With("qname", qName(query)).Debug("sending query")
|
|
c.metrics.query.Add(1)
|
|
if err := conn.WriteMsg(query); err != nil {
|
|
req.markDone(nil, err) // fail the request
|
|
inFlight.get(query) // clean up the in-flight queue so it doesn't keep growing
|
|
conn.Close() // throw away this connection, should wake up the reader as well
|
|
wg.Done()
|
|
c.metrics.err.Add("send_query", 1)
|
|
log.With("qname", qName(query)).Debug("failed sending query",
|
|
"error", err)
|
|
return
|
|
}
|
|
case <-done: // the reader ran into an error and we want to stop using this connection
|
|
wg.Done()
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
go func() { // reader
|
|
for {
|
|
// Set the idle deadline on the reader, not the writer since when using UDP "connections",
|
|
// a network topology change wouldn't be noticed. Putting the idle timeout here ensures
|
|
// a reconnect in that case as well. This does create a very slight race however if the
|
|
// sender is using the connection right at the time of the timeout in the receiver.
|
|
_ = conn.SetReadDeadline(time.Now().Add(idleTimeout))
|
|
a, err := conn.ReadMsg()
|
|
if err != nil {
|
|
switch e := err.(type) {
|
|
case net.Error:
|
|
if e.Timeout() {
|
|
log.Debug("connection terminated by idle timeout")
|
|
} else {
|
|
c.metrics.err.Add("server_term", 1)
|
|
log.Debug("connection terminated by server")
|
|
}
|
|
close(done) // tell the writer to not use this connection anymore
|
|
wg.Done()
|
|
return
|
|
default:
|
|
if err == io.EOF {
|
|
c.metrics.err.Add("server_eof", 1)
|
|
log.Debug("connection terminated by server")
|
|
close(done) // tell the writer to not use this connection anymore
|
|
wg.Done()
|
|
return
|
|
}
|
|
// It's possible the response can't be correctly parsed, but we do have a response.
|
|
// In this case, return it and carry on, don't terminate the connection because we
|
|
// got a bad packet (like a truncated one for example).
|
|
if a == nil {
|
|
c.metrics.err.Add("read", 1)
|
|
log.Warn("read failed", "error", err)
|
|
close(done) // tell the writer to not use this connection anymore
|
|
wg.Done()
|
|
return
|
|
}
|
|
log.Warn("failed to read response", "error", err, "qname", qName(a))
|
|
}
|
|
}
|
|
req := inFlight.get(a) // match the answer to an in-flight query
|
|
if req == nil {
|
|
c.metrics.err.Add("unexpected_a", 1)
|
|
log.With("qname", qName(a)).Warn("unexpected answer received, ignoring")
|
|
continue
|
|
}
|
|
c.metrics.response.Add(rCode(a), 1)
|
|
req.markDone(a, nil)
|
|
ql := inFlight.maxQueueLen()
|
|
if ql > c.metrics.maxQueueLen.Value() {
|
|
c.metrics.maxQueueLen.Set(ql)
|
|
}
|
|
}
|
|
}()
|
|
|
|
// wait for both, sender and receiver to terminate before trying to reconnect
|
|
wg.Wait()
|
|
}
|
|
}
|
|
|
|
// Request received from a client. It also contains the response and a channel that is
|
|
// closed when the request is done.
|
|
type request struct {
|
|
q, a *dns.Msg
|
|
err error
|
|
done chan struct{}
|
|
}
|
|
|
|
func newRequest(q *dns.Msg) *request {
|
|
return &request{
|
|
q: q,
|
|
done: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
// Wait for the request to be completed and return the answer.
|
|
func (r *request) waitFor() (*dns.Msg, error) {
|
|
<-r.done
|
|
|
|
if r.err == nil {
|
|
// As per https://tools.ietf.org/html/rfc7858#section-3.3, we need to double check this
|
|
// really is the correct response.
|
|
if len(r.a.Question) > 0 && len(r.q.Question) > 0 {
|
|
q := r.q.Question[0]
|
|
a := r.a.Question[0]
|
|
if a.Name != q.Name || a.Qclass != q.Qclass || a.Qtype != q.Qtype {
|
|
return nil, fmt.Errorf("expected answer for %s, got %s", q.String(), a.String())
|
|
}
|
|
}
|
|
}
|
|
|
|
return r.a, r.err
|
|
}
|
|
|
|
// Mark the request as complete.
|
|
func (r *request) markDone(a *dns.Msg, err error) {
|
|
if a != nil {
|
|
a.Id = r.q.Id // Fix the query ID in the answer to match the query
|
|
}
|
|
r.a = a
|
|
r.err = err
|
|
close(r.done)
|
|
}
|
|
|
|
// Queue to manage requests that are in flight. Used to asynchronously match received
|
|
// responses with their requests.
|
|
type inFlightQueue struct {
|
|
requests map[uint16]*request
|
|
mu sync.Mutex
|
|
idCounter uint16
|
|
maxLen int
|
|
}
|
|
|
|
// Add a request to the queue and return an updated DNS query with a new ID. The ID needs
|
|
// to be unique per connection, and we could be receiving multiple queries with the same
|
|
// ID. So make up a new ID, used that in the query upstream, then map it back to the
|
|
// request and replace the ID with the original one.
|
|
func (q *inFlightQueue) add(r *request) *dns.Msg {
|
|
q.mu.Lock()
|
|
defer q.mu.Unlock()
|
|
if q.requests == nil {
|
|
q.requests = make(map[uint16]*request)
|
|
}
|
|
q.idCounter++
|
|
q.requests[q.idCounter] = r
|
|
query := r.q.Copy()
|
|
query.Id = q.idCounter
|
|
if len(q.requests) > q.maxLen {
|
|
q.maxLen = len(q.requests)
|
|
}
|
|
return query
|
|
}
|
|
|
|
// Returns the request for a given query ID, or nil if the request isn't in the queue. The
|
|
// request is removed from the queue.
|
|
func (q *inFlightQueue) get(a *dns.Msg) *request {
|
|
q.mu.Lock()
|
|
defer q.mu.Unlock()
|
|
id := a.Id
|
|
r, ok := q.requests[id]
|
|
if !ok {
|
|
return nil
|
|
}
|
|
delete(q.requests, id)
|
|
return r
|
|
}
|
|
|
|
func (q *inFlightQueue) maxQueueLen() int64 {
|
|
q.mu.Lock()
|
|
defer q.mu.Unlock()
|
|
return int64(q.maxLen)
|
|
}
|