Bunch of DoH fixes bundled together (#419)

* removed unused hostname parameter from quicDial function

* removed outdated use of ioutil

* implemented proper quic 0rtt decision making
This commit is contained in:
Leonard Walter
2025-01-15 06:26:23 +01:00
committed by GitHub
parent e854967971
commit d0564a8d77
6 changed files with 53 additions and 30 deletions

View File

@@ -5,7 +5,6 @@ import (
"context"
"crypto/sha256"
"fmt"
"io/ioutil"
"net/http"
"os"
"path/filepath"
@@ -118,7 +117,7 @@ func (l *HTTPLoader) loadFromDisk() ([]string, error) {
}
func (l *HTTPLoader) writeToDisk(rules []string) (err error) {
f, err := ioutil.TempFile(l.opt.CacheDir, "routedns")
f, err := os.CreateTemp(l.opt.CacheDir, "routedns")
if err != nil {
return
}

View File

@@ -305,7 +305,7 @@ func dohQuicTransport(endpoint string, opt DoHClientOptions) (http.RoundTripper,
}
dialer := func(ctx context.Context, addr string, tlsConfig *tls.Config, config *quic.Config) (quic.EarlyConnection, error) {
return newQuicConnection(u.Hostname(), addr, lAddr, tlsConfig, config)
return newQuicConnection(u.Hostname(), addr, lAddr, tlsConfig, config, opt.Use0RTT)
}
if opt.BootstrapAddr != "" {
dialer = func(ctx context.Context, addr string, tlsConfig *tls.Config, config *quic.Config) (quic.EarlyConnection, error) {
@@ -314,7 +314,7 @@ func dohQuicTransport(endpoint string, opt DoHClientOptions) (http.RoundTripper,
return nil, err
}
addr = net.JoinHostPort(opt.BootstrapAddr, port)
return newQuicConnection(u.Hostname(), addr, lAddr, tlsConfig, config)
return newQuicConnection(u.Hostname(), addr, lAddr, tlsConfig, config, opt.Use0RTT)
}
}
@@ -342,10 +342,11 @@ type quicConnection struct {
config *quic.Config
mu sync.Mutex
udpConn *net.UDPConn
Use0RTT bool
}
func newQuicConnection(hostname, rAddr string, lAddr net.IP, tlsConfig *tls.Config, config *quic.Config) (quic.EarlyConnection, error) {
connection, udpConn, err := quicDial(context.TODO(), hostname, rAddr, lAddr, tlsConfig, config)
func newQuicConnection(hostname, rAddr string, lAddr net.IP, tlsConfig *tls.Config, config *quic.Config, use0RTT bool) (quic.EarlyConnection, error) {
connection, udpConn, err := quicDial(context.TODO(), rAddr, lAddr, tlsConfig, config, use0RTT)
if err != nil {
return nil, err
}
@@ -365,6 +366,7 @@ func newQuicConnection(hostname, rAddr string, lAddr net.IP, tlsConfig *tls.Conf
config: config,
udpConn: udpConn,
EarlyConnection: connection,
Use0RTT: use0RTT,
}, nil
}
@@ -419,7 +421,7 @@ func quicRestart(s *quicConnection) error {
)
var err error
var earlyConn quic.EarlyConnection
earlyConn, s.udpConn, err = quicDial(context.TODO(), s.hostname, s.rAddr, s.lAddr, s.tlsConfig, s.config)
earlyConn, s.udpConn, err = quicDial(context.TODO(), s.rAddr, s.lAddr, s.tlsConfig, s.config, s.Use0RTT)
if err != nil || s.udpConn == nil {
Log.Error("couldn't restart quic connection", slog.Group("details", slog.String("protocol", "quic"), slog.String("address", s.hostname), slog.String("local", s.lAddr.String())), "error", err)
return err
@@ -430,7 +432,8 @@ func quicRestart(s *quicConnection) error {
return nil
}
func quicDial(ctx context.Context, hostname, rAddr string, lAddr net.IP, tlsConfig *tls.Config, config *quic.Config) (quic.EarlyConnection, *net.UDPConn, error) {
func quicDial(ctx context.Context, rAddr string, lAddr net.IP, tlsConfig *tls.Config, config *quic.Config, use0RTT bool) (quic.EarlyConnection, *net.UDPConn, error) {
var earlyConn quic.EarlyConnection
udpAddr, err := net.ResolveUDPAddr("udp", rAddr)
if err != nil {
Log.Error("couldn't resolve remote addr for UDP quic client", "error", err, "rAddr", rAddr)
@@ -441,13 +444,36 @@ func quicDial(ctx context.Context, hostname, rAddr string, lAddr net.IP, tlsConf
Log.Error("couldn't listen on UDP socket on local address", "error", err, "local", lAddr.String())
return nil, nil, err
}
// use DialEarly so that we attempt to use 0-RTT DNS queries, it's lower latency (if the server supports it)
earlyConn, err := quic.DialEarly(ctx, udpConn, udpAddr, tlsConfig, config)
if err != nil {
// don't leak filehandles / sockets; if we got here udpConn must exist
_ = udpConn.Close()
Log.Error("couldn't dial quic early connection", "error", err)
return nil, nil, err
if use0RTT {
earlyConn, err = quic.DialEarly(ctx, udpConn, udpAddr, tlsConfig, config)
if err != nil {
_ = udpConn.Close()
Log.Error("couldn't dial quic early connection", "error", err)
return nil, nil, err
}
} else {
conn, err := quic.Dial(ctx, udpConn, udpAddr, tlsConfig, config)
if err != nil {
_ = udpConn.Close()
Log.Error("couldn't dial quic connection", "error", err)
return nil, nil, err
}
earlyConn = &earlyConnWrapper{Connection: conn}
}
return earlyConn, udpConn, nil
}
type earlyConnWrapper struct {
quic.Connection
}
func (e *earlyConnWrapper) HandshakeComplete() <-chan struct{} {
ch := make(chan struct{})
close(ch)
return ch
}
func (e *earlyConnWrapper) NextConnection(ctx context.Context) (quic.Connection, error) {
return nil, fmt.Errorf("NextConnection not supported for non-0RTT connections")
}

View File

@@ -6,7 +6,7 @@ import (
"encoding/base64"
"expvar"
"fmt"
"io/ioutil"
"io"
"net"
"net/http"
"strings"
@@ -188,7 +188,7 @@ func (s *DoHListener) getHandler(w http.ResponseWriter, r *http.Request) {
}
func (s *DoHListener) postHandler(w http.ResponseWriter, r *http.Request) {
b, err := ioutil.ReadAll(r.Body)
b, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return

View File

@@ -38,13 +38,10 @@ type DoQClientOptions struct {
BootstrapAddr string
// Local IP to use for outbound connections. If nil, a local address is chosen.
LocalAddr net.IP
TLSConfig *tls.Config
LocalAddr net.IP
TLSConfig *tls.Config
QueryTimeout time.Duration
Use0RTT bool
Use0RTT bool
}
var _ Resolver = &DoQClient{}
@@ -106,6 +103,7 @@ func NewDoQClient(id, endpoint string, opt DoQClientOptions) (*DoQClient, error)
TokenStore: quic.NewLRUTokenStore(10, 10),
HandshakeIdleTimeout: opt.QueryTimeout,
},
Use0RTT: opt.Use0RTT,
},
metrics: NewListenerMetrics("client", id),
}, nil
@@ -217,7 +215,7 @@ func (s *quicConnection) getStream(endpoint string, log *slog.Logger) (quic.Stre
// If we don't have a connection yet, make one
if s.EarlyConnection == nil {
var err error
s.EarlyConnection, s.udpConn, err = quicDial(context.TODO(), s.hostname, endpoint, s.lAddr, s.tlsConfig, s.config)
s.EarlyConnection, s.udpConn, err = quicDial(context.TODO(), endpoint, s.lAddr, s.tlsConfig, s.config, s.Use0RTT)
if err != nil {
log.Error("failed to open connection",
"hostname", s.hostname,

View File

@@ -4,7 +4,7 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"os"
"github.com/pion/dtls/v2"
)
@@ -18,7 +18,7 @@ func DTLSServerConfig(caFile, crtFile, keyFile string, mutualTLS bool) (*dtls.Co
}
if caFile != "" {
certPool := x509.NewCertPool()
b, err := ioutil.ReadFile(caFile)
b, err := os.ReadFile(caFile)
if err != nil {
return nil, err
}
@@ -57,7 +57,7 @@ func DTLSClientConfig(caFile, crtFile, keyFile string) (*dtls.Config, error) {
// Load custom CA set if provided
if caFile != "" {
certPool := x509.NewCertPool()
b, err := ioutil.ReadFile(caFile)
b, err := os.ReadFile(caFile)
if err != nil {
return nil, err
}

6
tls.go
View File

@@ -4,7 +4,7 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"os"
)
// TLSServerConfig is a convenience function that builds a tls.Config instance for TLS servers
@@ -18,7 +18,7 @@ func TLSServerConfig(caFile, crtFile, keyFile string, mutualTLS bool) (*tls.Conf
}
if caFile != "" {
certPool := x509.NewCertPool()
b, err := ioutil.ReadFile(caFile)
b, err := os.ReadFile(caFile)
if err != nil {
return nil, err
}
@@ -59,7 +59,7 @@ func TLSClientConfig(caFile, crtFile, keyFile, serverName string) (*tls.Config,
// Load custom CA set if provided
if caFile != "" {
certPool := x509.NewCertPool()
b, err := ioutil.ReadFile(caFile)
b, err := os.ReadFile(caFile)
if err != nil {
return nil, err
}