diff --git a/dnslistener.go b/dnslistener.go index 80811da..0309c0f 100644 --- a/dnslistener.go +++ b/dnslistener.go @@ -39,6 +39,7 @@ func listenHandler(r Resolver) dns.HandlerFunc { Log.Printf("received query for '%s' forwarded to %s", qName(req), r.String()) a, err := r.Resolve(req) if err != nil { + Log.Println("failed to resolve '%s' : %s", qName(req), err) return } w.WriteMsg(a) diff --git a/dotclient.go b/dotclient.go index 93669a3..a52f102 100644 --- a/dotclient.go +++ b/dotclient.go @@ -10,7 +10,7 @@ import ( // DoTClient is a DNS-over-TLS resolver. type DoTClient struct { endpoint string - conn *Pipeline + pipeline *Pipeline } var _ Resolver = &DoTClient{} @@ -23,13 +23,13 @@ func NewDoTClient(endpoint string) *DoTClient { } return &DoTClient{ endpoint: endpoint, - conn: NewPipeline(endpoint, client), + pipeline: NewPipeline(endpoint, client), } } // Resolve a DNS query. func (d *DoTClient) Resolve(q *dns.Msg) (*dns.Msg, error) { - return d.conn.Resolve(q) + return d.pipeline.Resolve(q) } func (d *DoTClient) String() string { diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..cd64e38 --- /dev/null +++ b/errors.go @@ -0,0 +1,16 @@ +package rdns + +import ( + "fmt" + + "github.com/miekg/dns" +) + +// QueryTimeoutError is returned when a query times out. +type QueryTimeoutError struct { + query *dns.Msg +} + +func (e QueryTimeoutError) Error() string { + return fmt.Sprintf("query for '%s' timed out", qName(e.query)) +} diff --git a/pipeline.go b/pipeline.go index 09b3234..b758561 100644 --- a/pipeline.go +++ b/pipeline.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "sync" + "time" "github.com/miekg/dns" ) @@ -32,7 +33,18 @@ func NewPipeline(addr string, client *dns.Client) *Pipeline { // Resolve a single query using this connection. func (c *Pipeline) Resolve(q *dns.Msg) (*dns.Msg, error) { r := newRequest(q) - c.requests <- r + c.requests <- r // Queue up the request + + timeout := time.NewTimer(time.Second) + defer timeout.Stop() + + // Wait for the request to complete or time out + select { + case <-r.done: + case <-timeout.C: + r.markDone(nil, QueryTimeoutError{q}) + } + return r.waitFor() } @@ -47,7 +59,6 @@ func (c *Pipeline) start() { for req := range c.requests { // Lazy connection. Only open a real connection if there's a request done := make(chan struct{}) Log.Println("opening dot connection to", c.addr) - // conn, err := dns.DialWithTLS("tcp", c.addr, &tls.Config{}) conn, err := c.client.Dial(c.addr) if err != nil { Log.Println("failed to open dot connection to", c.addr, ":", err) @@ -65,8 +76,9 @@ func (c *Pipeline) start() { query := inFlight.add(req) Log.Printf("sending query for '%s' to %s", qName(query), c.addr) if err := conn.WriteMsg(query); err != nil { - req.markDone(nil, err) - conn.Close() // throw away this connection, should wake up the reader as well + req.markDone(nil, err) // fail the request + inFlight.get(query) // clean up the in-flight queue to it doesn't keep growing + conn.Close() // throw away this connection, should wake up the reader as well wg.Done() Log.Printf("failed to send query for '%s' to %s : %s", qName(query), c.addr, err.Error()) return @@ -119,13 +131,15 @@ func newRequest(q *dns.Msg) *request { func (r *request) waitFor() (*dns.Msg, error) { <-r.done - // As per https://tools.ietf.org/html/rfc7858#section-3.3, we need to double check this - // really is the correct response. - if len(r.a.Question) > 0 && len(r.q.Question) > 0 { - q := r.q.Question[0] - a := r.a.Question[0] - if a.Name != q.Name || a.Qclass != q.Qclass || a.Qtype != q.Qtype { - return nil, fmt.Errorf("expected answer for %s, got %s", q.String(), a.String()) + if r.err == nil { + // As per https://tools.ietf.org/html/rfc7858#section-3.3, we need to double check this + // really is the correct response. + if len(r.a.Question) > 0 && len(r.q.Question) > 0 { + q := r.q.Question[0] + a := r.a.Question[0] + if a.Name != q.Name || a.Qclass != q.Qclass || a.Qtype != q.Qtype { + return nil, fmt.Errorf("expected answer for %s, got %s", q.String(), a.String()) + } } }