diff --git a/cmd/routedns/main.go b/cmd/routedns/main.go index e575676..aeea259 100644 --- a/cmd/routedns/main.go +++ b/cmd/routedns/main.go @@ -73,6 +73,7 @@ func start(opt options, args []string) error { if _, ok := resolvers[id]; ok { return fmt.Errorf("group resolver with duplicate id '%s", id) } + switch r.Protocol { case "doq": tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey) @@ -136,7 +137,10 @@ func start(opt options, args []string) error { opt := rdns.DNSClientOptions{ LocalAddr: net.ParseIP(r.LocalAddr), } - resolvers[id] = rdns.NewDNSClient(id, r.Address, r.Protocol, opt) + resolvers[id], err = rdns.NewDNSClient(id, r.Address, r.Protocol, opt) + if err != nil { + return fmt.Errorf("failed to parse resolver config for '%s' : %s", id, err) + } default: return fmt.Errorf("unsupported protocol '%s' for resolver '%s'", r.Protocol, id) } diff --git a/dnsclient.go b/dnsclient.go index d7ed41e..2d031d4 100644 --- a/dnsclient.go +++ b/dnsclient.go @@ -25,7 +25,10 @@ var _ Resolver = &DNSClient{} // NewDNSClient returns a new instance of DNSClient which is a plain DNS resolver // that supports pipelining over a single connection. -func NewDNSClient(id, endpoint, network string, opt DNSClientOptions) *DNSClient { +func NewDNSClient(id, endpoint, network string, opt DNSClientOptions) (*DNSClient, error) { + if err := validEndpoint(endpoint); err != nil { + return nil, err + } // Use a custom dialer if a local address was provided var dialer *net.Dialer if opt.LocalAddr != nil { @@ -47,7 +50,7 @@ func NewDNSClient(id, endpoint, network string, opt DNSClientOptions) *DNSClient net: network, endpoint: endpoint, pipeline: NewPipeline(endpoint, client), - } + }, nil } // Resolve a DNS query. diff --git a/dnsclient_test.go b/dnsclient_test.go index 3983443..bb09d79 100644 --- a/dnsclient_test.go +++ b/dnsclient_test.go @@ -8,7 +8,7 @@ import ( ) func TestDNSClientSimpleTCP(t *testing.T) { - d := NewDNSClient("test-dns", "8.8.8.8:53", "tcp", DNSClientOptions{}) + d, _ := NewDNSClient("test-dns", "8.8.8.8:53", "tcp", DNSClientOptions{}) q := new(dns.Msg) q.SetQuestion("google.com.", dns.TypeA) r, err := d.Resolve(q, ClientInfo{}) @@ -17,7 +17,7 @@ func TestDNSClientSimpleTCP(t *testing.T) { } func TestDNSClientSimpleUDP(t *testing.T) { - d := NewDNSClient("test-dns", "8.8.8.8:53", "udp", DNSClientOptions{}) + d, _ := NewDNSClient("test-dns", "8.8.8.8:53", "udp", DNSClientOptions{}) q := new(dns.Msg) q.SetQuestion("google.com.", dns.TypeA) r, err := d.Resolve(q, ClientInfo{}) diff --git a/doqclient.go b/doqclient.go index 4865af9..be67317 100644 --- a/doqclient.go +++ b/doqclient.go @@ -46,6 +46,9 @@ var _ Resolver = &DoQClient{} // NewDoQClient instantiates a new DNS-over-QUIC resolver. func NewDoQClient(id, endpoint string, opt DoQClientOptions) (*DoQClient, error) { + if err := validEndpoint(endpoint); err != nil { + return nil, err + } if opt.TLSConfig == nil { opt.TLSConfig = new(tls.Config) } diff --git a/dotclient.go b/dotclient.go index 2c292b3..be0bae8 100644 --- a/dotclient.go +++ b/dotclient.go @@ -32,6 +32,10 @@ var _ Resolver = &DoTClient{} // NewDoTClient instantiates a new DNS-over-TLS resolver. func NewDoTClient(id, endpoint string, opt DoTClientOptions) (*DoTClient, error) { + if err := validEndpoint(endpoint); err != nil { + return nil, err + } + // Use a custom dialer if a local address was provided var dialer *net.Dialer if opt.LocalAddr != nil { diff --git a/dotlistener_test.go b/dotlistener_test.go index 3037307..cf57349 100644 --- a/dotlistener_test.go +++ b/dotlistener_test.go @@ -80,7 +80,7 @@ func TestDoTListenerMutual(t *testing.T) { func TestDoTListenerPadding(t *testing.T) { // Define a listener that does not respond with padding - upstream := NewDNSClient("test-dns", "8.8.8.8:53", "udp", DNSClientOptions{}) + upstream, _ := NewDNSClient("test-dns", "8.8.8.8:53", "udp", DNSClientOptions{}) // Find a free port for the listener addr, err := getLnAddress() diff --git a/dtlsclient.go b/dtlsclient.go index 98bc83b..0b950a8 100644 --- a/dtlsclient.go +++ b/dtlsclient.go @@ -33,6 +33,9 @@ var _ Resolver = &DTLSClient{} // NewDTLSClient instantiates a new DNS-over-TLS resolver. func NewDTLSClient(id, endpoint string, opt DTLSClientOptions) (*DTLSClient, error) { + if err := validEndpoint(endpoint); err != nil { + return nil, err + } host, port, err := net.SplitHostPort(endpoint) if err != nil { return nil, err diff --git a/example_test.go b/example_test.go index d0be713..908673a 100644 --- a/example_test.go +++ b/example_test.go @@ -22,8 +22,8 @@ func Example_resolver() { func Example_group() { // Define resolvers - r1 := rdns.NewDNSClient("google1", "8.8.8.8:53", "udp", rdns.DNSClientOptions{}) - r2 := rdns.NewDNSClient("google2", "8.8.4.4:53", "udp", rdns.DNSClientOptions{}) + r1, _ := rdns.NewDNSClient("google1", "8.8.8.8:53", "udp", rdns.DNSClientOptions{}) + r2, _ := rdns.NewDNSClient("google2", "8.8.4.4:53", "udp", rdns.DNSClientOptions{}) // Combine them int a group that does round-robin over the two resolvers g := rdns.NewRoundRobin("test-rr", r1, r2) @@ -39,8 +39,8 @@ func Example_group() { func Example_router() { // Define resolvers - google := rdns.NewDNSClient("g-dns", "8.8.8.8:53", "udp", rdns.DNSClientOptions{}) - cloudflare := rdns.NewDNSClient("cf-dns", "1.1.1.1:53", "udp", rdns.DNSClientOptions{}) + google, _ := rdns.NewDNSClient("g-dns", "8.8.8.8:53", "udp", rdns.DNSClientOptions{}) + cloudflare, _ := rdns.NewDNSClient("cf-dns", "1.1.1.1:53", "udp", rdns.DNSClientOptions{}) // Build a router that will send all "*.cloudflare.com" to the cloudflare // resolvber while everything else goes to the google resolver (default) diff --git a/validate.go b/validate.go new file mode 100644 index 0000000..ea22ecf --- /dev/null +++ b/validate.go @@ -0,0 +1,60 @@ +package rdns + +import ( + "errors" + "fmt" + "net" + "strconv" + "strings" +) + +// Returns nil if the endpoint address in the form of : is a valid. +func validEndpoint(addr string) error { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return err + } + if _, err := strconv.ParseUint(port, 10, 16); err != nil { + return fmt.Errorf("invalid port: %w", err) + } + // See if we have a valid IP + if ip := net.ParseIP(host); ip != nil { + return nil + } + return validHostname(host) +} + +// Returns nil if the given name is a valid hostnam as per https://tools.ietf.org/html/rfc3696#section-2 +// and https://tools.ietf.org/html/rfc1123#page-13 +func validHostname(name string) error { + if name == "" { + return errors.New("hostname empty") + } + if len(name) > 255 { + return fmt.Errorf("invalid hostname %q: too long", name) + } + name = strings.TrimSuffix(name, ".") + labels := strings.Split(name, ".") + for _, label := range labels { + for _, c := range label { + if label == "" { + return fmt.Errorf("invalid hostname %q: empty label", name) + } + if strings.HasPrefix(label, "-") || strings.HasSuffix(label, "-") { + return fmt.Errorf("invalid hostname %q: label can not start or end with -", name) + } + switch { + case c >= '0' && c <= '9', c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z', c == '-': + default: + return fmt.Errorf("invalid hostname %q: invalid character %q", name, string(c)) + } + } + } + // The last label can not be all-numeric + for _, c := range labels[len(labels)-1] { + if c < '0' || c > '9' { + return nil + } + } + return fmt.Errorf("invalid hostname %q: last label can not be all numeric", name) +}