mirror of
https://github.com/folbricht/routedns.git
synced 2025-12-17 15:45:21 -06:00
Implement query timeout for all resolvers
This commit is contained in:
@@ -39,6 +39,7 @@ func listenHandler(r Resolver) dns.HandlerFunc {
|
||||
Log.Printf("received query for '%s' forwarded to %s", qName(req), r.String())
|
||||
a, err := r.Resolve(req)
|
||||
if err != nil {
|
||||
Log.Println("failed to resolve '%s' : %s", qName(req), err)
|
||||
return
|
||||
}
|
||||
w.WriteMsg(a)
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
// DoTClient is a DNS-over-TLS resolver.
|
||||
type DoTClient struct {
|
||||
endpoint string
|
||||
conn *Pipeline
|
||||
pipeline *Pipeline
|
||||
}
|
||||
|
||||
var _ Resolver = &DoTClient{}
|
||||
@@ -23,13 +23,13 @@ func NewDoTClient(endpoint string) *DoTClient {
|
||||
}
|
||||
return &DoTClient{
|
||||
endpoint: endpoint,
|
||||
conn: NewPipeline(endpoint, client),
|
||||
pipeline: NewPipeline(endpoint, client),
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve a DNS query.
|
||||
func (d *DoTClient) Resolve(q *dns.Msg) (*dns.Msg, error) {
|
||||
return d.conn.Resolve(q)
|
||||
return d.pipeline.Resolve(q)
|
||||
}
|
||||
|
||||
func (d *DoTClient) String() string {
|
||||
|
||||
16
errors.go
Normal file
16
errors.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package rdns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// QueryTimeoutError is returned when a query times out.
|
||||
type QueryTimeoutError struct {
|
||||
query *dns.Msg
|
||||
}
|
||||
|
||||
func (e QueryTimeoutError) Error() string {
|
||||
return fmt.Sprintf("query for '%s' timed out", qName(e.query))
|
||||
}
|
||||
36
pipeline.go
36
pipeline.go
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@@ -32,7 +33,18 @@ func NewPipeline(addr string, client *dns.Client) *Pipeline {
|
||||
// Resolve a single query using this connection.
|
||||
func (c *Pipeline) Resolve(q *dns.Msg) (*dns.Msg, error) {
|
||||
r := newRequest(q)
|
||||
c.requests <- r
|
||||
c.requests <- r // Queue up the request
|
||||
|
||||
timeout := time.NewTimer(time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
// Wait for the request to complete or time out
|
||||
select {
|
||||
case <-r.done:
|
||||
case <-timeout.C:
|
||||
r.markDone(nil, QueryTimeoutError{q})
|
||||
}
|
||||
|
||||
return r.waitFor()
|
||||
}
|
||||
|
||||
@@ -47,7 +59,6 @@ func (c *Pipeline) start() {
|
||||
for req := range c.requests { // Lazy connection. Only open a real connection if there's a request
|
||||
done := make(chan struct{})
|
||||
Log.Println("opening dot connection to", c.addr)
|
||||
// conn, err := dns.DialWithTLS("tcp", c.addr, &tls.Config{})
|
||||
conn, err := c.client.Dial(c.addr)
|
||||
if err != nil {
|
||||
Log.Println("failed to open dot connection to", c.addr, ":", err)
|
||||
@@ -65,8 +76,9 @@ func (c *Pipeline) start() {
|
||||
query := inFlight.add(req)
|
||||
Log.Printf("sending query for '%s' to %s", qName(query), c.addr)
|
||||
if err := conn.WriteMsg(query); err != nil {
|
||||
req.markDone(nil, err)
|
||||
conn.Close() // throw away this connection, should wake up the reader as well
|
||||
req.markDone(nil, err) // fail the request
|
||||
inFlight.get(query) // clean up the in-flight queue to it doesn't keep growing
|
||||
conn.Close() // throw away this connection, should wake up the reader as well
|
||||
wg.Done()
|
||||
Log.Printf("failed to send query for '%s' to %s : %s", qName(query), c.addr, err.Error())
|
||||
return
|
||||
@@ -119,13 +131,15 @@ func newRequest(q *dns.Msg) *request {
|
||||
func (r *request) waitFor() (*dns.Msg, error) {
|
||||
<-r.done
|
||||
|
||||
// As per https://tools.ietf.org/html/rfc7858#section-3.3, we need to double check this
|
||||
// really is the correct response.
|
||||
if len(r.a.Question) > 0 && len(r.q.Question) > 0 {
|
||||
q := r.q.Question[0]
|
||||
a := r.a.Question[0]
|
||||
if a.Name != q.Name || a.Qclass != q.Qclass || a.Qtype != q.Qtype {
|
||||
return nil, fmt.Errorf("expected answer for %s, got %s", q.String(), a.String())
|
||||
if r.err == nil {
|
||||
// As per https://tools.ietf.org/html/rfc7858#section-3.3, we need to double check this
|
||||
// really is the correct response.
|
||||
if len(r.a.Question) > 0 && len(r.q.Question) > 0 {
|
||||
q := r.q.Question[0]
|
||||
a := r.a.Question[0]
|
||||
if a.Name != q.Name || a.Qclass != q.Qclass || a.Qtype != q.Qtype {
|
||||
return nil, fmt.Errorf("expected answer for %s, got %s", q.String(), a.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user