mirror of
https://github.com/folbricht/routedns.git
synced 2026-02-11 19:48:56 -06:00
Bunch of DoH fixes bundled together (#419)
* removed unused hostname parameter from quicDial function * removed outdated use of ioutil * implemented proper quic 0rtt decision making
This commit is contained in:
@@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -118,7 +117,7 @@ func (l *HTTPLoader) loadFromDisk() ([]string, error) {
|
||||
}
|
||||
|
||||
func (l *HTTPLoader) writeToDisk(rules []string) (err error) {
|
||||
f, err := ioutil.TempFile(l.opt.CacheDir, "routedns")
|
||||
f, err := os.CreateTemp(l.opt.CacheDir, "routedns")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
52
dohclient.go
52
dohclient.go
@@ -305,7 +305,7 @@ func dohQuicTransport(endpoint string, opt DoHClientOptions) (http.RoundTripper,
|
||||
}
|
||||
|
||||
dialer := func(ctx context.Context, addr string, tlsConfig *tls.Config, config *quic.Config) (quic.EarlyConnection, error) {
|
||||
return newQuicConnection(u.Hostname(), addr, lAddr, tlsConfig, config)
|
||||
return newQuicConnection(u.Hostname(), addr, lAddr, tlsConfig, config, opt.Use0RTT)
|
||||
}
|
||||
if opt.BootstrapAddr != "" {
|
||||
dialer = func(ctx context.Context, addr string, tlsConfig *tls.Config, config *quic.Config) (quic.EarlyConnection, error) {
|
||||
@@ -314,7 +314,7 @@ func dohQuicTransport(endpoint string, opt DoHClientOptions) (http.RoundTripper,
|
||||
return nil, err
|
||||
}
|
||||
addr = net.JoinHostPort(opt.BootstrapAddr, port)
|
||||
return newQuicConnection(u.Hostname(), addr, lAddr, tlsConfig, config)
|
||||
return newQuicConnection(u.Hostname(), addr, lAddr, tlsConfig, config, opt.Use0RTT)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -342,10 +342,11 @@ type quicConnection struct {
|
||||
config *quic.Config
|
||||
mu sync.Mutex
|
||||
udpConn *net.UDPConn
|
||||
Use0RTT bool
|
||||
}
|
||||
|
||||
func newQuicConnection(hostname, rAddr string, lAddr net.IP, tlsConfig *tls.Config, config *quic.Config) (quic.EarlyConnection, error) {
|
||||
connection, udpConn, err := quicDial(context.TODO(), hostname, rAddr, lAddr, tlsConfig, config)
|
||||
func newQuicConnection(hostname, rAddr string, lAddr net.IP, tlsConfig *tls.Config, config *quic.Config, use0RTT bool) (quic.EarlyConnection, error) {
|
||||
connection, udpConn, err := quicDial(context.TODO(), rAddr, lAddr, tlsConfig, config, use0RTT)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -365,6 +366,7 @@ func newQuicConnection(hostname, rAddr string, lAddr net.IP, tlsConfig *tls.Conf
|
||||
config: config,
|
||||
udpConn: udpConn,
|
||||
EarlyConnection: connection,
|
||||
Use0RTT: use0RTT,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -419,7 +421,7 @@ func quicRestart(s *quicConnection) error {
|
||||
)
|
||||
var err error
|
||||
var earlyConn quic.EarlyConnection
|
||||
earlyConn, s.udpConn, err = quicDial(context.TODO(), s.hostname, s.rAddr, s.lAddr, s.tlsConfig, s.config)
|
||||
earlyConn, s.udpConn, err = quicDial(context.TODO(), s.rAddr, s.lAddr, s.tlsConfig, s.config, s.Use0RTT)
|
||||
if err != nil || s.udpConn == nil {
|
||||
Log.Error("couldn't restart quic connection", slog.Group("details", slog.String("protocol", "quic"), slog.String("address", s.hostname), slog.String("local", s.lAddr.String())), "error", err)
|
||||
return err
|
||||
@@ -430,7 +432,8 @@ func quicRestart(s *quicConnection) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func quicDial(ctx context.Context, hostname, rAddr string, lAddr net.IP, tlsConfig *tls.Config, config *quic.Config) (quic.EarlyConnection, *net.UDPConn, error) {
|
||||
func quicDial(ctx context.Context, rAddr string, lAddr net.IP, tlsConfig *tls.Config, config *quic.Config, use0RTT bool) (quic.EarlyConnection, *net.UDPConn, error) {
|
||||
var earlyConn quic.EarlyConnection
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", rAddr)
|
||||
if err != nil {
|
||||
Log.Error("couldn't resolve remote addr for UDP quic client", "error", err, "rAddr", rAddr)
|
||||
@@ -441,13 +444,36 @@ func quicDial(ctx context.Context, hostname, rAddr string, lAddr net.IP, tlsConf
|
||||
Log.Error("couldn't listen on UDP socket on local address", "error", err, "local", lAddr.String())
|
||||
return nil, nil, err
|
||||
}
|
||||
// use DialEarly so that we attempt to use 0-RTT DNS queries, it's lower latency (if the server supports it)
|
||||
earlyConn, err := quic.DialEarly(ctx, udpConn, udpAddr, tlsConfig, config)
|
||||
if err != nil {
|
||||
// don't leak filehandles / sockets; if we got here udpConn must exist
|
||||
_ = udpConn.Close()
|
||||
Log.Error("couldn't dial quic early connection", "error", err)
|
||||
return nil, nil, err
|
||||
|
||||
if use0RTT {
|
||||
earlyConn, err = quic.DialEarly(ctx, udpConn, udpAddr, tlsConfig, config)
|
||||
if err != nil {
|
||||
_ = udpConn.Close()
|
||||
Log.Error("couldn't dial quic early connection", "error", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
} else {
|
||||
conn, err := quic.Dial(ctx, udpConn, udpAddr, tlsConfig, config)
|
||||
if err != nil {
|
||||
_ = udpConn.Close()
|
||||
Log.Error("couldn't dial quic connection", "error", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
earlyConn = &earlyConnWrapper{Connection: conn}
|
||||
}
|
||||
return earlyConn, udpConn, nil
|
||||
}
|
||||
|
||||
type earlyConnWrapper struct {
|
||||
quic.Connection
|
||||
}
|
||||
|
||||
func (e *earlyConnWrapper) HandshakeComplete() <-chan struct{} {
|
||||
ch := make(chan struct{})
|
||||
close(ch)
|
||||
return ch
|
||||
}
|
||||
|
||||
func (e *earlyConnWrapper) NextConnection(ctx context.Context) (quic.Connection, error) {
|
||||
return nil, fmt.Errorf("NextConnection not supported for non-0RTT connections")
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"encoding/base64"
|
||||
"expvar"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -188,7 +188,7 @@ func (s *DoHListener) getHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *DoHListener) postHandler(w http.ResponseWriter, r *http.Request) {
|
||||
b, err := ioutil.ReadAll(r.Body)
|
||||
b, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
|
||||
12
doqclient.go
12
doqclient.go
@@ -38,13 +38,10 @@ type DoQClientOptions struct {
|
||||
BootstrapAddr string
|
||||
|
||||
// Local IP to use for outbound connections. If nil, a local address is chosen.
|
||||
LocalAddr net.IP
|
||||
|
||||
TLSConfig *tls.Config
|
||||
|
||||
LocalAddr net.IP
|
||||
TLSConfig *tls.Config
|
||||
QueryTimeout time.Duration
|
||||
|
||||
Use0RTT bool
|
||||
Use0RTT bool
|
||||
}
|
||||
|
||||
var _ Resolver = &DoQClient{}
|
||||
@@ -106,6 +103,7 @@ func NewDoQClient(id, endpoint string, opt DoQClientOptions) (*DoQClient, error)
|
||||
TokenStore: quic.NewLRUTokenStore(10, 10),
|
||||
HandshakeIdleTimeout: opt.QueryTimeout,
|
||||
},
|
||||
Use0RTT: opt.Use0RTT,
|
||||
},
|
||||
metrics: NewListenerMetrics("client", id),
|
||||
}, nil
|
||||
@@ -217,7 +215,7 @@ func (s *quicConnection) getStream(endpoint string, log *slog.Logger) (quic.Stre
|
||||
// If we don't have a connection yet, make one
|
||||
if s.EarlyConnection == nil {
|
||||
var err error
|
||||
s.EarlyConnection, s.udpConn, err = quicDial(context.TODO(), s.hostname, endpoint, s.lAddr, s.tlsConfig, s.config)
|
||||
s.EarlyConnection, s.udpConn, err = quicDial(context.TODO(), endpoint, s.lAddr, s.tlsConfig, s.config, s.Use0RTT)
|
||||
if err != nil {
|
||||
log.Error("failed to open connection",
|
||||
"hostname", s.hostname,
|
||||
|
||||
6
dtls.go
6
dtls.go
@@ -4,7 +4,7 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
||||
"github.com/pion/dtls/v2"
|
||||
)
|
||||
@@ -18,7 +18,7 @@ func DTLSServerConfig(caFile, crtFile, keyFile string, mutualTLS bool) (*dtls.Co
|
||||
}
|
||||
if caFile != "" {
|
||||
certPool := x509.NewCertPool()
|
||||
b, err := ioutil.ReadFile(caFile)
|
||||
b, err := os.ReadFile(caFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -57,7 +57,7 @@ func DTLSClientConfig(caFile, crtFile, keyFile string) (*dtls.Config, error) {
|
||||
// Load custom CA set if provided
|
||||
if caFile != "" {
|
||||
certPool := x509.NewCertPool()
|
||||
b, err := ioutil.ReadFile(caFile)
|
||||
b, err := os.ReadFile(caFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
6
tls.go
6
tls.go
@@ -4,7 +4,7 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
)
|
||||
|
||||
// TLSServerConfig is a convenience function that builds a tls.Config instance for TLS servers
|
||||
@@ -18,7 +18,7 @@ func TLSServerConfig(caFile, crtFile, keyFile string, mutualTLS bool) (*tls.Conf
|
||||
}
|
||||
if caFile != "" {
|
||||
certPool := x509.NewCertPool()
|
||||
b, err := ioutil.ReadFile(caFile)
|
||||
b, err := os.ReadFile(caFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -59,7 +59,7 @@ func TLSClientConfig(caFile, crtFile, keyFile, serverName string) (*tls.Config,
|
||||
// Load custom CA set if provided
|
||||
if caFile != "" {
|
||||
certPool := x509.NewCertPool()
|
||||
b, err := ioutil.ReadFile(caFile)
|
||||
b, err := os.ReadFile(caFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user