mirror of
https://github.com/folbricht/routedns.git
synced 2025-12-30 14:10:03 -06:00
Validate resolver addresses during startup (#77)
* Validate resolver addresses during startup * Validate the port number as well * Move endpoint validation into the library
This commit is contained in:
@@ -73,6 +73,7 @@ func start(opt options, args []string) error {
|
||||
if _, ok := resolvers[id]; ok {
|
||||
return fmt.Errorf("group resolver with duplicate id '%s", id)
|
||||
}
|
||||
|
||||
switch r.Protocol {
|
||||
case "doq":
|
||||
tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey)
|
||||
@@ -136,7 +137,10 @@ func start(opt options, args []string) error {
|
||||
opt := rdns.DNSClientOptions{
|
||||
LocalAddr: net.ParseIP(r.LocalAddr),
|
||||
}
|
||||
resolvers[id] = rdns.NewDNSClient(id, r.Address, r.Protocol, opt)
|
||||
resolvers[id], err = rdns.NewDNSClient(id, r.Address, r.Protocol, opt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse resolver config for '%s' : %s", id, err)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol '%s' for resolver '%s'", r.Protocol, id)
|
||||
}
|
||||
|
||||
@@ -25,7 +25,10 @@ var _ Resolver = &DNSClient{}
|
||||
|
||||
// NewDNSClient returns a new instance of DNSClient which is a plain DNS resolver
|
||||
// that supports pipelining over a single connection.
|
||||
func NewDNSClient(id, endpoint, network string, opt DNSClientOptions) *DNSClient {
|
||||
func NewDNSClient(id, endpoint, network string, opt DNSClientOptions) (*DNSClient, error) {
|
||||
if err := validEndpoint(endpoint); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Use a custom dialer if a local address was provided
|
||||
var dialer *net.Dialer
|
||||
if opt.LocalAddr != nil {
|
||||
@@ -47,7 +50,7 @@ func NewDNSClient(id, endpoint, network string, opt DNSClientOptions) *DNSClient
|
||||
net: network,
|
||||
endpoint: endpoint,
|
||||
pipeline: NewPipeline(endpoint, client),
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Resolve a DNS query.
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func TestDNSClientSimpleTCP(t *testing.T) {
|
||||
d := NewDNSClient("test-dns", "8.8.8.8:53", "tcp", DNSClientOptions{})
|
||||
d, _ := NewDNSClient("test-dns", "8.8.8.8:53", "tcp", DNSClientOptions{})
|
||||
q := new(dns.Msg)
|
||||
q.SetQuestion("google.com.", dns.TypeA)
|
||||
r, err := d.Resolve(q, ClientInfo{})
|
||||
@@ -17,7 +17,7 @@ func TestDNSClientSimpleTCP(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDNSClientSimpleUDP(t *testing.T) {
|
||||
d := NewDNSClient("test-dns", "8.8.8.8:53", "udp", DNSClientOptions{})
|
||||
d, _ := NewDNSClient("test-dns", "8.8.8.8:53", "udp", DNSClientOptions{})
|
||||
q := new(dns.Msg)
|
||||
q.SetQuestion("google.com.", dns.TypeA)
|
||||
r, err := d.Resolve(q, ClientInfo{})
|
||||
|
||||
@@ -46,6 +46,9 @@ var _ Resolver = &DoQClient{}
|
||||
|
||||
// NewDoQClient instantiates a new DNS-over-QUIC resolver.
|
||||
func NewDoQClient(id, endpoint string, opt DoQClientOptions) (*DoQClient, error) {
|
||||
if err := validEndpoint(endpoint); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if opt.TLSConfig == nil {
|
||||
opt.TLSConfig = new(tls.Config)
|
||||
}
|
||||
|
||||
@@ -32,6 +32,10 @@ var _ Resolver = &DoTClient{}
|
||||
|
||||
// NewDoTClient instantiates a new DNS-over-TLS resolver.
|
||||
func NewDoTClient(id, endpoint string, opt DoTClientOptions) (*DoTClient, error) {
|
||||
if err := validEndpoint(endpoint); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use a custom dialer if a local address was provided
|
||||
var dialer *net.Dialer
|
||||
if opt.LocalAddr != nil {
|
||||
|
||||
@@ -80,7 +80,7 @@ func TestDoTListenerMutual(t *testing.T) {
|
||||
|
||||
func TestDoTListenerPadding(t *testing.T) {
|
||||
// Define a listener that does not respond with padding
|
||||
upstream := NewDNSClient("test-dns", "8.8.8.8:53", "udp", DNSClientOptions{})
|
||||
upstream, _ := NewDNSClient("test-dns", "8.8.8.8:53", "udp", DNSClientOptions{})
|
||||
|
||||
// Find a free port for the listener
|
||||
addr, err := getLnAddress()
|
||||
|
||||
@@ -33,6 +33,9 @@ var _ Resolver = &DTLSClient{}
|
||||
|
||||
// NewDTLSClient instantiates a new DNS-over-TLS resolver.
|
||||
func NewDTLSClient(id, endpoint string, opt DTLSClientOptions) (*DTLSClient, error) {
|
||||
if err := validEndpoint(endpoint); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
host, port, err := net.SplitHostPort(endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -22,8 +22,8 @@ func Example_resolver() {
|
||||
|
||||
func Example_group() {
|
||||
// Define resolvers
|
||||
r1 := rdns.NewDNSClient("google1", "8.8.8.8:53", "udp", rdns.DNSClientOptions{})
|
||||
r2 := rdns.NewDNSClient("google2", "8.8.4.4:53", "udp", rdns.DNSClientOptions{})
|
||||
r1, _ := rdns.NewDNSClient("google1", "8.8.8.8:53", "udp", rdns.DNSClientOptions{})
|
||||
r2, _ := rdns.NewDNSClient("google2", "8.8.4.4:53", "udp", rdns.DNSClientOptions{})
|
||||
|
||||
// Combine them int a group that does round-robin over the two resolvers
|
||||
g := rdns.NewRoundRobin("test-rr", r1, r2)
|
||||
@@ -39,8 +39,8 @@ func Example_group() {
|
||||
|
||||
func Example_router() {
|
||||
// Define resolvers
|
||||
google := rdns.NewDNSClient("g-dns", "8.8.8.8:53", "udp", rdns.DNSClientOptions{})
|
||||
cloudflare := rdns.NewDNSClient("cf-dns", "1.1.1.1:53", "udp", rdns.DNSClientOptions{})
|
||||
google, _ := rdns.NewDNSClient("g-dns", "8.8.8.8:53", "udp", rdns.DNSClientOptions{})
|
||||
cloudflare, _ := rdns.NewDNSClient("cf-dns", "1.1.1.1:53", "udp", rdns.DNSClientOptions{})
|
||||
|
||||
// Build a router that will send all "*.cloudflare.com" to the cloudflare
|
||||
// resolvber while everything else goes to the google resolver (default)
|
||||
|
||||
60
validate.go
Normal file
60
validate.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package rdns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Returns nil if the endpoint address in the form of <host>:<port> is a valid.
|
||||
func validEndpoint(addr string) error {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := strconv.ParseUint(port, 10, 16); err != nil {
|
||||
return fmt.Errorf("invalid port: %w", err)
|
||||
}
|
||||
// See if we have a valid IP
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
return nil
|
||||
}
|
||||
return validHostname(host)
|
||||
}
|
||||
|
||||
// Returns nil if the given name is a valid hostnam as per https://tools.ietf.org/html/rfc3696#section-2
|
||||
// and https://tools.ietf.org/html/rfc1123#page-13
|
||||
func validHostname(name string) error {
|
||||
if name == "" {
|
||||
return errors.New("hostname empty")
|
||||
}
|
||||
if len(name) > 255 {
|
||||
return fmt.Errorf("invalid hostname %q: too long", name)
|
||||
}
|
||||
name = strings.TrimSuffix(name, ".")
|
||||
labels := strings.Split(name, ".")
|
||||
for _, label := range labels {
|
||||
for _, c := range label {
|
||||
if label == "" {
|
||||
return fmt.Errorf("invalid hostname %q: empty label", name)
|
||||
}
|
||||
if strings.HasPrefix(label, "-") || strings.HasSuffix(label, "-") {
|
||||
return fmt.Errorf("invalid hostname %q: label can not start or end with -", name)
|
||||
}
|
||||
switch {
|
||||
case c >= '0' && c <= '9', c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z', c == '-':
|
||||
default:
|
||||
return fmt.Errorf("invalid hostname %q: invalid character %q", name, string(c))
|
||||
}
|
||||
}
|
||||
}
|
||||
// The last label can not be all-numeric
|
||||
for _, c := range labels[len(labels)-1] {
|
||||
if c < '0' || c > '9' {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("invalid hostname %q: last label can not be all numeric", name)
|
||||
}
|
||||
Reference in New Issue
Block a user