Implement query timeout for all resolvers

This commit is contained in:
folbrich
2019-06-07 19:22:06 -06:00
parent 48980419a2
commit 6ecb8f3a96
4 changed files with 45 additions and 14 deletions

View File

@@ -39,6 +39,7 @@ func listenHandler(r Resolver) dns.HandlerFunc {
Log.Printf("received query for '%s' forwarded to %s", qName(req), r.String())
a, err := r.Resolve(req)
if err != nil {
Log.Println("failed to resolve '%s' : %s", qName(req), err)
return
}
w.WriteMsg(a)

View File

@@ -10,7 +10,7 @@ import (
// DoTClient is a DNS-over-TLS resolver.
type DoTClient struct {
endpoint string
conn *Pipeline
pipeline *Pipeline
}
var _ Resolver = &DoTClient{}
@@ -23,13 +23,13 @@ func NewDoTClient(endpoint string) *DoTClient {
}
return &DoTClient{
endpoint: endpoint,
conn: NewPipeline(endpoint, client),
pipeline: NewPipeline(endpoint, client),
}
}
// Resolve a DNS query.
func (d *DoTClient) Resolve(q *dns.Msg) (*dns.Msg, error) {
return d.conn.Resolve(q)
return d.pipeline.Resolve(q)
}
func (d *DoTClient) String() string {

16
errors.go Normal file
View File

@@ -0,0 +1,16 @@
package rdns
import (
"fmt"
"github.com/miekg/dns"
)
// QueryTimeoutError is returned when a query times out.
type QueryTimeoutError struct {
query *dns.Msg
}
func (e QueryTimeoutError) Error() string {
return fmt.Sprintf("query for '%s' timed out", qName(e.query))
}

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"os"
"sync"
"time"
"github.com/miekg/dns"
)
@@ -32,7 +33,18 @@ func NewPipeline(addr string, client *dns.Client) *Pipeline {
// Resolve a single query using this connection.
func (c *Pipeline) Resolve(q *dns.Msg) (*dns.Msg, error) {
r := newRequest(q)
c.requests <- r
c.requests <- r // Queue up the request
timeout := time.NewTimer(time.Second)
defer timeout.Stop()
// Wait for the request to complete or time out
select {
case <-r.done:
case <-timeout.C:
r.markDone(nil, QueryTimeoutError{q})
}
return r.waitFor()
}
@@ -47,7 +59,6 @@ func (c *Pipeline) start() {
for req := range c.requests { // Lazy connection. Only open a real connection if there's a request
done := make(chan struct{})
Log.Println("opening dot connection to", c.addr)
// conn, err := dns.DialWithTLS("tcp", c.addr, &tls.Config{})
conn, err := c.client.Dial(c.addr)
if err != nil {
Log.Println("failed to open dot connection to", c.addr, ":", err)
@@ -65,8 +76,9 @@ func (c *Pipeline) start() {
query := inFlight.add(req)
Log.Printf("sending query for '%s' to %s", qName(query), c.addr)
if err := conn.WriteMsg(query); err != nil {
req.markDone(nil, err)
conn.Close() // throw away this connection, should wake up the reader as well
req.markDone(nil, err) // fail the request
inFlight.get(query) // clean up the in-flight queue to it doesn't keep growing
conn.Close() // throw away this connection, should wake up the reader as well
wg.Done()
Log.Printf("failed to send query for '%s' to %s : %s", qName(query), c.addr, err.Error())
return
@@ -119,13 +131,15 @@ func newRequest(q *dns.Msg) *request {
func (r *request) waitFor() (*dns.Msg, error) {
<-r.done
// As per https://tools.ietf.org/html/rfc7858#section-3.3, we need to double check this
// really is the correct response.
if len(r.a.Question) > 0 && len(r.q.Question) > 0 {
q := r.q.Question[0]
a := r.a.Question[0]
if a.Name != q.Name || a.Qclass != q.Qclass || a.Qtype != q.Qtype {
return nil, fmt.Errorf("expected answer for %s, got %s", q.String(), a.String())
if r.err == nil {
// As per https://tools.ietf.org/html/rfc7858#section-3.3, we need to double check this
// really is the correct response.
if len(r.a.Question) > 0 && len(r.q.Question) > 0 {
q := r.q.Question[0]
a := r.a.Question[0]
if a.Name != q.Name || a.Qclass != q.Qclass || a.Qtype != q.Qtype {
return nil, fmt.Errorf("expected answer for %s, got %s", q.String(), a.String())
}
}
}