diff --git a/blocklistloader-http.go b/blocklistloader-http.go index 144f6ce..b941979 100644 --- a/blocklistloader-http.go +++ b/blocklistloader-http.go @@ -5,7 +5,6 @@ import ( "context" "crypto/sha256" "fmt" - "io/ioutil" "net/http" "os" "path/filepath" @@ -118,7 +117,7 @@ func (l *HTTPLoader) loadFromDisk() ([]string, error) { } func (l *HTTPLoader) writeToDisk(rules []string) (err error) { - f, err := ioutil.TempFile(l.opt.CacheDir, "routedns") + f, err := os.CreateTemp(l.opt.CacheDir, "routedns") if err != nil { return } diff --git a/dohclient.go b/dohclient.go index 3859ec7..21019c0 100644 --- a/dohclient.go +++ b/dohclient.go @@ -305,7 +305,7 @@ func dohQuicTransport(endpoint string, opt DoHClientOptions) (http.RoundTripper, } dialer := func(ctx context.Context, addr string, tlsConfig *tls.Config, config *quic.Config) (quic.EarlyConnection, error) { - return newQuicConnection(u.Hostname(), addr, lAddr, tlsConfig, config) + return newQuicConnection(u.Hostname(), addr, lAddr, tlsConfig, config, opt.Use0RTT) } if opt.BootstrapAddr != "" { dialer = func(ctx context.Context, addr string, tlsConfig *tls.Config, config *quic.Config) (quic.EarlyConnection, error) { @@ -314,7 +314,7 @@ func dohQuicTransport(endpoint string, opt DoHClientOptions) (http.RoundTripper, return nil, err } addr = net.JoinHostPort(opt.BootstrapAddr, port) - return newQuicConnection(u.Hostname(), addr, lAddr, tlsConfig, config) + return newQuicConnection(u.Hostname(), addr, lAddr, tlsConfig, config, opt.Use0RTT) } } @@ -342,10 +342,11 @@ type quicConnection struct { config *quic.Config mu sync.Mutex udpConn *net.UDPConn + Use0RTT bool } -func newQuicConnection(hostname, rAddr string, lAddr net.IP, tlsConfig *tls.Config, config *quic.Config) (quic.EarlyConnection, error) { - connection, udpConn, err := quicDial(context.TODO(), hostname, rAddr, lAddr, tlsConfig, config) +func newQuicConnection(hostname, rAddr string, lAddr net.IP, tlsConfig *tls.Config, config *quic.Config, use0RTT bool) (quic.EarlyConnection, error) { + connection, udpConn, err := quicDial(context.TODO(), rAddr, lAddr, tlsConfig, config, use0RTT) if err != nil { return nil, err } @@ -365,6 +366,7 @@ func newQuicConnection(hostname, rAddr string, lAddr net.IP, tlsConfig *tls.Conf config: config, udpConn: udpConn, EarlyConnection: connection, + Use0RTT: use0RTT, }, nil } @@ -419,7 +421,7 @@ func quicRestart(s *quicConnection) error { ) var err error var earlyConn quic.EarlyConnection - earlyConn, s.udpConn, err = quicDial(context.TODO(), s.hostname, s.rAddr, s.lAddr, s.tlsConfig, s.config) + earlyConn, s.udpConn, err = quicDial(context.TODO(), s.rAddr, s.lAddr, s.tlsConfig, s.config, s.Use0RTT) if err != nil || s.udpConn == nil { Log.Error("couldn't restart quic connection", slog.Group("details", slog.String("protocol", "quic"), slog.String("address", s.hostname), slog.String("local", s.lAddr.String())), "error", err) return err @@ -430,7 +432,8 @@ func quicRestart(s *quicConnection) error { return nil } -func quicDial(ctx context.Context, hostname, rAddr string, lAddr net.IP, tlsConfig *tls.Config, config *quic.Config) (quic.EarlyConnection, *net.UDPConn, error) { +func quicDial(ctx context.Context, rAddr string, lAddr net.IP, tlsConfig *tls.Config, config *quic.Config, use0RTT bool) (quic.EarlyConnection, *net.UDPConn, error) { + var earlyConn quic.EarlyConnection udpAddr, err := net.ResolveUDPAddr("udp", rAddr) if err != nil { Log.Error("couldn't resolve remote addr for UDP quic client", "error", err, "rAddr", rAddr) @@ -441,13 +444,36 @@ func quicDial(ctx context.Context, hostname, rAddr string, lAddr net.IP, tlsConf Log.Error("couldn't listen on UDP socket on local address", "error", err, "local", lAddr.String()) return nil, nil, err } - // use DialEarly so that we attempt to use 0-RTT DNS queries, it's lower latency (if the server supports it) - earlyConn, err := quic.DialEarly(ctx, udpConn, udpAddr, tlsConfig, config) - if err != nil { - // don't leak filehandles / sockets; if we got here udpConn must exist - _ = udpConn.Close() - Log.Error("couldn't dial quic early connection", "error", err) - return nil, nil, err + + if use0RTT { + earlyConn, err = quic.DialEarly(ctx, udpConn, udpAddr, tlsConfig, config) + if err != nil { + _ = udpConn.Close() + Log.Error("couldn't dial quic early connection", "error", err) + return nil, nil, err + } + } else { + conn, err := quic.Dial(ctx, udpConn, udpAddr, tlsConfig, config) + if err != nil { + _ = udpConn.Close() + Log.Error("couldn't dial quic connection", "error", err) + return nil, nil, err + } + earlyConn = &earlyConnWrapper{Connection: conn} } return earlyConn, udpConn, nil } + +type earlyConnWrapper struct { + quic.Connection +} + +func (e *earlyConnWrapper) HandshakeComplete() <-chan struct{} { + ch := make(chan struct{}) + close(ch) + return ch +} + +func (e *earlyConnWrapper) NextConnection(ctx context.Context) (quic.Connection, error) { + return nil, fmt.Errorf("NextConnection not supported for non-0RTT connections") +} diff --git a/dohlistener.go b/dohlistener.go index d7cced5..3e1da0b 100644 --- a/dohlistener.go +++ b/dohlistener.go @@ -6,7 +6,7 @@ import ( "encoding/base64" "expvar" "fmt" - "io/ioutil" + "io" "net" "net/http" "strings" @@ -188,7 +188,7 @@ func (s *DoHListener) getHandler(w http.ResponseWriter, r *http.Request) { } func (s *DoHListener) postHandler(w http.ResponseWriter, r *http.Request) { - b, err := ioutil.ReadAll(r.Body) + b, err := io.ReadAll(r.Body) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return diff --git a/doqclient.go b/doqclient.go index 0d27471..1353310 100644 --- a/doqclient.go +++ b/doqclient.go @@ -38,13 +38,10 @@ type DoQClientOptions struct { BootstrapAddr string // Local IP to use for outbound connections. If nil, a local address is chosen. - LocalAddr net.IP - - TLSConfig *tls.Config - + LocalAddr net.IP + TLSConfig *tls.Config QueryTimeout time.Duration - - Use0RTT bool + Use0RTT bool } var _ Resolver = &DoQClient{} @@ -106,6 +103,7 @@ func NewDoQClient(id, endpoint string, opt DoQClientOptions) (*DoQClient, error) TokenStore: quic.NewLRUTokenStore(10, 10), HandshakeIdleTimeout: opt.QueryTimeout, }, + Use0RTT: opt.Use0RTT, }, metrics: NewListenerMetrics("client", id), }, nil @@ -217,7 +215,7 @@ func (s *quicConnection) getStream(endpoint string, log *slog.Logger) (quic.Stre // If we don't have a connection yet, make one if s.EarlyConnection == nil { var err error - s.EarlyConnection, s.udpConn, err = quicDial(context.TODO(), s.hostname, endpoint, s.lAddr, s.tlsConfig, s.config) + s.EarlyConnection, s.udpConn, err = quicDial(context.TODO(), endpoint, s.lAddr, s.tlsConfig, s.config, s.Use0RTT) if err != nil { log.Error("failed to open connection", "hostname", s.hostname, diff --git a/dtls.go b/dtls.go index 58953f5..f8db4bc 100644 --- a/dtls.go +++ b/dtls.go @@ -4,7 +4,7 @@ import ( "crypto/tls" "crypto/x509" "fmt" - "io/ioutil" + "os" "github.com/pion/dtls/v2" ) @@ -18,7 +18,7 @@ func DTLSServerConfig(caFile, crtFile, keyFile string, mutualTLS bool) (*dtls.Co } if caFile != "" { certPool := x509.NewCertPool() - b, err := ioutil.ReadFile(caFile) + b, err := os.ReadFile(caFile) if err != nil { return nil, err } @@ -57,7 +57,7 @@ func DTLSClientConfig(caFile, crtFile, keyFile string) (*dtls.Config, error) { // Load custom CA set if provided if caFile != "" { certPool := x509.NewCertPool() - b, err := ioutil.ReadFile(caFile) + b, err := os.ReadFile(caFile) if err != nil { return nil, err } diff --git a/tls.go b/tls.go index 45c4fb0..90846ed 100644 --- a/tls.go +++ b/tls.go @@ -4,7 +4,7 @@ import ( "crypto/tls" "crypto/x509" "fmt" - "io/ioutil" + "os" ) // TLSServerConfig is a convenience function that builds a tls.Config instance for TLS servers @@ -18,7 +18,7 @@ func TLSServerConfig(caFile, crtFile, keyFile string, mutualTLS bool) (*tls.Conf } if caFile != "" { certPool := x509.NewCertPool() - b, err := ioutil.ReadFile(caFile) + b, err := os.ReadFile(caFile) if err != nil { return nil, err } @@ -59,7 +59,7 @@ func TLSClientConfig(caFile, crtFile, keyFile, serverName string) (*tls.Config, // Load custom CA set if provided if caFile != "" { certPool := x509.NewCertPool() - b, err := ioutil.ReadFile(caFile) + b, err := os.ReadFile(caFile) if err != nil { return nil, err }