Validate resolver addresses during startup (#77)

* Validate resolver addresses during startup

* Validate the port number as well

* Move endpoint validation into the library
This commit is contained in:
Frank Olbricht
2020-08-02 09:57:58 -06:00
committed by GitHub
parent ce96fdc8b4
commit 896eb3f8c1
9 changed files with 87 additions and 10 deletions

View File

@@ -73,6 +73,7 @@ func start(opt options, args []string) error {
if _, ok := resolvers[id]; ok {
return fmt.Errorf("group resolver with duplicate id '%s", id)
}
switch r.Protocol {
case "doq":
tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey)
@@ -136,7 +137,10 @@ func start(opt options, args []string) error {
opt := rdns.DNSClientOptions{
LocalAddr: net.ParseIP(r.LocalAddr),
}
resolvers[id] = rdns.NewDNSClient(id, r.Address, r.Protocol, opt)
resolvers[id], err = rdns.NewDNSClient(id, r.Address, r.Protocol, opt)
if err != nil {
return fmt.Errorf("failed to parse resolver config for '%s' : %s", id, err)
}
default:
return fmt.Errorf("unsupported protocol '%s' for resolver '%s'", r.Protocol, id)
}

View File

@@ -25,7 +25,10 @@ var _ Resolver = &DNSClient{}
// NewDNSClient returns a new instance of DNSClient which is a plain DNS resolver
// that supports pipelining over a single connection.
func NewDNSClient(id, endpoint, network string, opt DNSClientOptions) *DNSClient {
func NewDNSClient(id, endpoint, network string, opt DNSClientOptions) (*DNSClient, error) {
if err := validEndpoint(endpoint); err != nil {
return nil, err
}
// Use a custom dialer if a local address was provided
var dialer *net.Dialer
if opt.LocalAddr != nil {
@@ -47,7 +50,7 @@ func NewDNSClient(id, endpoint, network string, opt DNSClientOptions) *DNSClient
net: network,
endpoint: endpoint,
pipeline: NewPipeline(endpoint, client),
}
}, nil
}
// Resolve a DNS query.

View File

@@ -8,7 +8,7 @@ import (
)
func TestDNSClientSimpleTCP(t *testing.T) {
d := NewDNSClient("test-dns", "8.8.8.8:53", "tcp", DNSClientOptions{})
d, _ := NewDNSClient("test-dns", "8.8.8.8:53", "tcp", DNSClientOptions{})
q := new(dns.Msg)
q.SetQuestion("google.com.", dns.TypeA)
r, err := d.Resolve(q, ClientInfo{})
@@ -17,7 +17,7 @@ func TestDNSClientSimpleTCP(t *testing.T) {
}
func TestDNSClientSimpleUDP(t *testing.T) {
d := NewDNSClient("test-dns", "8.8.8.8:53", "udp", DNSClientOptions{})
d, _ := NewDNSClient("test-dns", "8.8.8.8:53", "udp", DNSClientOptions{})
q := new(dns.Msg)
q.SetQuestion("google.com.", dns.TypeA)
r, err := d.Resolve(q, ClientInfo{})

View File

@@ -46,6 +46,9 @@ var _ Resolver = &DoQClient{}
// NewDoQClient instantiates a new DNS-over-QUIC resolver.
func NewDoQClient(id, endpoint string, opt DoQClientOptions) (*DoQClient, error) {
if err := validEndpoint(endpoint); err != nil {
return nil, err
}
if opt.TLSConfig == nil {
opt.TLSConfig = new(tls.Config)
}

View File

@@ -32,6 +32,10 @@ var _ Resolver = &DoTClient{}
// NewDoTClient instantiates a new DNS-over-TLS resolver.
func NewDoTClient(id, endpoint string, opt DoTClientOptions) (*DoTClient, error) {
if err := validEndpoint(endpoint); err != nil {
return nil, err
}
// Use a custom dialer if a local address was provided
var dialer *net.Dialer
if opt.LocalAddr != nil {

View File

@@ -80,7 +80,7 @@ func TestDoTListenerMutual(t *testing.T) {
func TestDoTListenerPadding(t *testing.T) {
// Define a listener that does not respond with padding
upstream := NewDNSClient("test-dns", "8.8.8.8:53", "udp", DNSClientOptions{})
upstream, _ := NewDNSClient("test-dns", "8.8.8.8:53", "udp", DNSClientOptions{})
// Find a free port for the listener
addr, err := getLnAddress()

View File

@@ -33,6 +33,9 @@ var _ Resolver = &DTLSClient{}
// NewDTLSClient instantiates a new DNS-over-TLS resolver.
func NewDTLSClient(id, endpoint string, opt DTLSClientOptions) (*DTLSClient, error) {
if err := validEndpoint(endpoint); err != nil {
return nil, err
}
host, port, err := net.SplitHostPort(endpoint)
if err != nil {
return nil, err

View File

@@ -22,8 +22,8 @@ func Example_resolver() {
func Example_group() {
// Define resolvers
r1 := rdns.NewDNSClient("google1", "8.8.8.8:53", "udp", rdns.DNSClientOptions{})
r2 := rdns.NewDNSClient("google2", "8.8.4.4:53", "udp", rdns.DNSClientOptions{})
r1, _ := rdns.NewDNSClient("google1", "8.8.8.8:53", "udp", rdns.DNSClientOptions{})
r2, _ := rdns.NewDNSClient("google2", "8.8.4.4:53", "udp", rdns.DNSClientOptions{})
// Combine them int a group that does round-robin over the two resolvers
g := rdns.NewRoundRobin("test-rr", r1, r2)
@@ -39,8 +39,8 @@ func Example_group() {
func Example_router() {
// Define resolvers
google := rdns.NewDNSClient("g-dns", "8.8.8.8:53", "udp", rdns.DNSClientOptions{})
cloudflare := rdns.NewDNSClient("cf-dns", "1.1.1.1:53", "udp", rdns.DNSClientOptions{})
google, _ := rdns.NewDNSClient("g-dns", "8.8.8.8:53", "udp", rdns.DNSClientOptions{})
cloudflare, _ := rdns.NewDNSClient("cf-dns", "1.1.1.1:53", "udp", rdns.DNSClientOptions{})
// Build a router that will send all "*.cloudflare.com" to the cloudflare
// resolvber while everything else goes to the google resolver (default)

60
validate.go Normal file
View File

@@ -0,0 +1,60 @@
package rdns
import (
"errors"
"fmt"
"net"
"strconv"
"strings"
)
// Returns nil if the endpoint address in the form of <host>:<port> is a valid.
func validEndpoint(addr string) error {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return err
}
if _, err := strconv.ParseUint(port, 10, 16); err != nil {
return fmt.Errorf("invalid port: %w", err)
}
// See if we have a valid IP
if ip := net.ParseIP(host); ip != nil {
return nil
}
return validHostname(host)
}
// Returns nil if the given name is a valid hostnam as per https://tools.ietf.org/html/rfc3696#section-2
// and https://tools.ietf.org/html/rfc1123#page-13
func validHostname(name string) error {
if name == "" {
return errors.New("hostname empty")
}
if len(name) > 255 {
return fmt.Errorf("invalid hostname %q: too long", name)
}
name = strings.TrimSuffix(name, ".")
labels := strings.Split(name, ".")
for _, label := range labels {
for _, c := range label {
if label == "" {
return fmt.Errorf("invalid hostname %q: empty label", name)
}
if strings.HasPrefix(label, "-") || strings.HasSuffix(label, "-") {
return fmt.Errorf("invalid hostname %q: label can not start or end with -", name)
}
switch {
case c >= '0' && c <= '9', c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z', c == '-':
default:
return fmt.Errorf("invalid hostname %q: invalid character %q", name, string(c))
}
}
}
// The last label can not be all-numeric
for _, c := range labels[len(labels)-1] {
if c < '0' || c > '9' {
return nil
}
}
return fmt.Errorf("invalid hostname %q: last label can not be all numeric", name)
}