diff --git a/cmd/routedns/config.go b/cmd/routedns/config.go index a3d33f9..9e65e4d 100644 --- a/cmd/routedns/config.go +++ b/cmd/routedns/config.go @@ -50,6 +50,7 @@ type resolver struct { BootstrapAddr string `toml:"bootstrap-address"` LocalAddr string `toml:"local-address"` EDNS0UDPSize uint16 `toml:"edns0-udp-size"` // UDP resolver option + QueryTimeout int `toml:"query-timeout"` // Query timout in seconds } // DoH-specific resolver options diff --git a/cmd/routedns/resolver.go b/cmd/routedns/resolver.go index f4fe084..6b6fa4c 100644 --- a/cmd/routedns/resolver.go +++ b/cmd/routedns/resolver.go @@ -3,6 +3,7 @@ package main import ( "fmt" "net" + "time" rdns "github.com/folbricht/routedns" ) @@ -23,6 +24,7 @@ func instantiateResolver(id string, r resolver, resolvers map[string]rdns.Resolv BootstrapAddr: r.BootstrapAddr, LocalAddr: net.ParseIP(r.LocalAddr), TLSConfig: tlsConfig, + QueryTimeout: time.Duration(r.QueryTimeout) * time.Second, } resolvers[id], err = rdns.NewDoQClient(id, r.Address, opt) if err != nil { @@ -39,6 +41,7 @@ func instantiateResolver(id string, r resolver, resolvers map[string]rdns.Resolv BootstrapAddr: r.BootstrapAddr, LocalAddr: net.ParseIP(r.LocalAddr), TLSConfig: tlsConfig, + QueryTimeout: time.Duration(r.QueryTimeout) * time.Second, } resolvers[id], err = rdns.NewDoTClient(id, r.Address, opt) if err != nil { @@ -56,6 +59,7 @@ func instantiateResolver(id string, r resolver, resolvers map[string]rdns.Resolv LocalAddr: net.ParseIP(r.LocalAddr), DTLSConfig: dtlsConfig, UDPSize: r.EDNS0UDPSize, + QueryTimeout: time.Duration(r.QueryTimeout) * time.Second, } resolvers[id], err = rdns.NewDTLSClient(id, r.Address, opt) if err != nil { @@ -74,6 +78,7 @@ func instantiateResolver(id string, r resolver, resolvers map[string]rdns.Resolv BootstrapAddr: r.BootstrapAddr, Transport: r.Transport, LocalAddr: net.ParseIP(r.LocalAddr), + QueryTimeout: time.Duration(r.QueryTimeout) * time.Second, } resolvers[id], err = rdns.NewDoHClient(id, r.Address, opt) if err != nil { @@ -83,8 +88,9 @@ func instantiateResolver(id string, r resolver, resolvers map[string]rdns.Resolv r.Address = rdns.AddressWithDefault(r.Address, rdns.PlainDNSPort) opt := rdns.DNSClientOptions{ - LocalAddr: net.ParseIP(r.LocalAddr), - UDPSize: r.EDNS0UDPSize, + LocalAddr: net.ParseIP(r.LocalAddr), + UDPSize: r.EDNS0UDPSize, + QueryTimeout: time.Duration(r.QueryTimeout) * time.Second, } resolvers[id], err = rdns.NewDNSClient(id, r.Address, r.Protocol, opt) if err != nil { diff --git a/dnsclient.go b/dnsclient.go index 8b94511..dbba5fe 100644 --- a/dnsclient.go +++ b/dnsclient.go @@ -3,6 +3,7 @@ package rdns import ( "crypto/tls" "net" + "time" "github.com/miekg/dns" "github.com/sirupsen/logrus" @@ -24,6 +25,8 @@ type DNSClientOptions struct { // Sets the EDNS0 UDP size for all queries sent upstream. If set to 0, queries // are not changed. UDPSize uint16 + + QueryTimeout time.Duration } var _ Resolver = &DNSClient{} @@ -55,7 +58,7 @@ func NewDNSClient(id, endpoint, network string, opt DNSClientOptions) (*DNSClien id: id, net: network, endpoint: endpoint, - pipeline: NewPipeline(id, endpoint, client), + pipeline: NewPipeline(id, endpoint, client, opt.QueryTimeout), opt: opt, }, nil } diff --git a/doc/configuration.md b/doc/configuration.md index 1679dfa..78a9c6e 100644 --- a/doc/configuration.md +++ b/doc/configuration.md @@ -1346,6 +1346,7 @@ Resolvers are defined in the configuration like so `[resolvers.NAME]` and have t - `bootstrap-address` - Use this IP address if the name in `address` can't be resolved. Using the IP in `address` directly may not work when TLS/certificates are used by the server. - `local-address` - IP of the local interface to use for outgoing connections. The address is automatically chosen if this option is left blank. - `edns0-udp-size` - If set, modifies the EDNS0 UDP size option in all queries sent upstream. Only meaningful when using UDP or DTLS resolvers. Upstream resolvers may not respect this value and apply their own limits. +- `query-timeout` - Sets the query timeout to allow. In seconds. Secure resolvers such as DoT, DoH, or DoQ offer additional options to configure the TLS connections. diff --git a/dohclient.go b/dohclient.go index 87c80fa..f463349 100644 --- a/dohclient.go +++ b/dohclient.go @@ -39,6 +39,8 @@ type DoHClientOptions struct { LocalAddr net.IP TLSConfig *tls.Config + + QueryTimeout time.Duration } // DoHClient is a DNS-over-HTTP resolver with support fot HTTP/2. @@ -83,6 +85,9 @@ func NewDoHClient(id, endpoint string, opt DoHClientOptions) (*DoHClient, error) if opt.Method != "POST" && opt.Method != "GET" { return nil, fmt.Errorf("unsupported method '%s'", opt.Method) } + if opt.QueryTimeout == 0 { + opt.QueryTimeout = defaultQueryTimeout + } return &DoHClient{ id: id, @@ -129,7 +134,11 @@ func (d *DoHClient) ResolvePOST(q *dns.Msg) (*dns.Msg, error) { d.metrics.err.Add("template", 1) return nil, err } - req, err := http.NewRequest("POST", u, bytes.NewReader(b)) + + ctx, cancel := context.WithTimeout(context.Background(), d.opt.QueryTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", u, bytes.NewReader(b)) if err != nil { d.metrics.err.Add("http", 1) return nil, err @@ -162,7 +171,11 @@ func (d *DoHClient) ResolveGET(q *dns.Msg) (*dns.Msg, error) { d.metrics.err.Add("template", 1) return nil, err } - req, err := http.NewRequest("GET", u, nil) + + ctx, cancel := context.WithTimeout(context.Background(), d.opt.QueryTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", u, nil) if err != nil { d.metrics.err.Add("http", 1) return nil, err diff --git a/doqclient.go b/doqclient.go index 4913533..f9165c2 100644 --- a/doqclient.go +++ b/doqclient.go @@ -39,6 +39,8 @@ type DoQClientOptions struct { LocalAddr net.IP TLSConfig *tls.Config + + QueryTimeout time.Duration } var _ Resolver = &DoQClient{} @@ -71,6 +73,9 @@ func NewDoQClient(id, endpoint string, opt DoQClientOptions) (*DoQClient, error) tlsConfig.ServerName = host endpoint = net.JoinHostPort(opt.BootstrapAddr, port) } + if opt.QueryTimeout == 0 { + opt.QueryTimeout = defaultQueryTimeout + } log := Log.WithFields(logrus.Fields{"protocol": "doq", "endpoint": endpoint}) return &DoQClient{ id: id, @@ -139,7 +144,7 @@ func (d *DoQClient) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) { } // Write the query into the stream and close it. Only one stream per query/response - _ = stream.SetWriteDeadline(time.Now().Add(time.Second)) + _ = stream.SetWriteDeadline(time.Now().Add(d.DoQClientOptions.QueryTimeout)) if _, err = stream.Write(b); err != nil { d.metrics.err.Add("write", 1) return nil, err @@ -149,7 +154,7 @@ func (d *DoQClient) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) { return nil, err } - _ = stream.SetReadDeadline(time.Now().Add(time.Second)) + _ = stream.SetReadDeadline(time.Now().Add(d.DoQClientOptions.QueryTimeout)) // DoQ requires a length prefix, like TCP var length uint16 diff --git a/dotclient.go b/dotclient.go index be1a0f9..9420f07 100644 --- a/dotclient.go +++ b/dotclient.go @@ -3,6 +3,7 @@ package rdns import ( "crypto/tls" "net" + "time" "github.com/miekg/dns" "github.com/pkg/errors" @@ -27,6 +28,8 @@ type DoTClientOptions struct { LocalAddr net.IP TLSConfig *tls.Config + + QueryTimeout time.Duration } var _ Resolver = &DoTClient{} @@ -62,7 +65,7 @@ func NewDoTClient(id, endpoint string, opt DoTClientOptions) (*DoTClient, error) return &DoTClient{ id: id, endpoint: endpoint, - pipeline: NewPipeline(id, endpoint, client), + pipeline: NewPipeline(id, endpoint, client, opt.QueryTimeout), }, nil } diff --git a/dtlsclient.go b/dtlsclient.go index 77d5873..9c42078 100644 --- a/dtlsclient.go +++ b/dtlsclient.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "strconv" + "time" "github.com/miekg/dns" "github.com/pion/dtls/v2" @@ -32,6 +33,8 @@ type DTLSClientOptions struct { UDPSize uint16 DTLSConfig *dtls.Config + + QueryTimeout time.Duration } var _ Resolver = &DTLSClient{} @@ -83,7 +86,7 @@ func NewDTLSClient(id, endpoint string, opt DTLSClientOptions) (*DTLSClient, err return &DTLSClient{ id: id, endpoint: endpoint, - pipeline: NewPipeline(id, endpoint, client), + pipeline: NewPipeline(id, endpoint, client, opt.QueryTimeout), opt: opt, }, nil } diff --git a/pipeline.go b/pipeline.go index 91fb624..fe6e5c8 100644 --- a/pipeline.go +++ b/pipeline.go @@ -10,8 +10,8 @@ import ( "github.com/miekg/dns" ) -// Defines how long to wait for a response from the resolver. -const queryTimeout = time.Second +// Defines how long to wait for a response from the resolver if no other timeout is given. +const defaultQueryTimeout = 2 * time.Second // Tear down an upstream connection if nothing has been received for this long. const idleTimeout = 10 * time.Second @@ -25,6 +25,7 @@ type Pipeline struct { client DNSDialer requests chan *request metrics *ListenerMetrics + timeout time.Duration } // DNSDialer is an abstraction for a dns.Client that returns a *dns.Conn. @@ -33,12 +34,16 @@ type DNSDialer interface { } // NewPipeline returns an initialized (and running) DNS connection manager. -func NewPipeline(id string, addr string, client DNSDialer) *Pipeline { +func NewPipeline(id string, addr string, client DNSDialer, timeout time.Duration) *Pipeline { + if timeout == 0 { + timeout = defaultQueryTimeout + } c := &Pipeline{ addr: addr, client: client, requests: make(chan *request), metrics: NewListenerMetrics("client", id), + timeout: timeout, } go c.start() return c @@ -48,7 +53,7 @@ func NewPipeline(id string, addr string, client DNSDialer) *Pipeline { func (c *Pipeline) Resolve(q *dns.Msg) (*dns.Msg, error) { r := newRequest(q) - timeout := time.NewTimer(queryTimeout) + timeout := time.NewTimer(c.timeout) defer timeout.Stop() // Queue up the request or time out diff --git a/pipeline_test.go b/pipeline_test.go index 8a56b10..4e0f569 100644 --- a/pipeline_test.go +++ b/pipeline_test.go @@ -20,7 +20,7 @@ func TestPipelineQueryTimeout(t *testing.T) { time.Sleep(2 * time.Second) return nil, errors.New("failed") } - p := NewPipeline("test", "localhost:53", testDialer(df)) + p := NewPipeline("test", "localhost:53", testDialer(df), time.Second) q := new(dns.Msg) q.SetQuestion("example.com.", dns.TypeA) @@ -35,5 +35,5 @@ func TestPipelineQueryTimeout(t *testing.T) { // Make sure we get a timeout error and it took the right amount to come back require.ErrorAs(t, err, &QueryTimeoutError{}) - require.WithinDuration(t, start.Add(queryTimeout), time.Now(), 10*time.Millisecond) + require.WithinDuration(t, start.Add(time.Second), time.Now(), 10*time.Millisecond) }