Preserve non-IP records in response from fastest-tcp (#192)

This commit is contained in:
Frank Olbricht
2021-11-29 06:37:12 -07:00
committed by GitHub
parent 1b86db1416
commit 195ca70171

View File

@@ -3,6 +3,7 @@ package rdns
import (
"context"
"errors"
"fmt"
"net"
"strconv"
"time"
@@ -13,9 +14,9 @@ import (
// FastestTCP first resolves the query with the upstream resolver, then
// performs TCP connection tests with the response IPs to determine which
// IP responds the fastest. This IP is then returned in the response.
// This should be used in combination with a Cache to avoid the TCP
// connection overhead on every query.
// IP responds the fastest. This IP is then returned in the response as first
// A/AAAA record. This should be used in combination with a Cache to avoid
// the TCP connection overhead on every query.
type FastestTCP struct {
id string
resolver Resolver
@@ -55,7 +56,8 @@ func NewFastestTCP(id string, resolver Resolver, opt FastestTCPOptions) *Fastest
}
}
// Resolve a DNS query using a random resolver.
// Resolve a DNS query and order the response based on which IP was able to establish
// a TCP connection the fastest.
func (r *FastestTCP) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) {
log := logger(r.id, q, ci)
a, err := r.resolver.Resolve(q, ci)
@@ -68,6 +70,8 @@ func (r *FastestTCP) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) {
if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA {
return a, nil
}
fmt.Println("Responses")
fmt.Println(a)
// Extract the IP responses
var ipRRs []dns.RR
@@ -85,25 +89,27 @@ func (r *FastestTCP) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) {
// Send TCP probes to all, if anything returns an error, just return
// the original response rather than trying to be clever and pick one.
log = log.WithField("port", r.port)
var sorted []dns.RR
if r.opt.WaitAll {
rrs, err := r.probeAll(log, ipRRs)
if err != nil {
log.WithError(err).Debug("tcp probe failed")
return a, nil
}
r.setTTL(rrs...)
a.Answer = rrs
return a, nil
sorted, err = r.probeAll(log, ipRRs)
} else {
first, err := r.probeFastest(log, ipRRs)
if err != nil {
log.WithError(err).Debug("tcp probe failed")
return a, nil
}
r.setTTL(first)
a.Answer = []dns.RR{first}
sorted, err = r.probeFastest(log, ipRRs)
}
if err != nil {
log.WithError(err).Debug("tcp probe failed")
return a, nil
}
r.setTTL(sorted...)
// Merge the sorted list of RRs back into the original answer in the same
// positions. The original answer could have CNAMEs and other types in it.
for i, rr := range a.Answer {
if rr.Header().Rrtype == question.Qtype {
a.Answer[i] = sorted[0]
sorted = sorted[1:]
}
}
return a, nil
}
// Sets the TTL of the given RRs if the option was provided
@@ -123,13 +129,22 @@ func (r *FastestTCP) String() string {
// Probes all IPs and returns only the RR with the fastest responding IP.
// Waits for the first one that comes back. Returns an error if the fastest response
// is an error.
func (r *FastestTCP) probeFastest(log logrus.FieldLogger, rrs []dns.RR) (dns.RR, error) {
func (r *FastestTCP) probeFastest(log logrus.FieldLogger, rrs []dns.RR) ([]dns.RR, error) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
resultCh := r.probe(ctx, log, rrs)
select {
case res := <-resultCh:
return res.rr, res.err
// Re-order the list in-place to put the fastest at the top
rr := res.rr
err := res.err
for i := 0; i < len(rrs); i++ {
if rrs[i] == rr {
return rrs, err
}
rrs[i], rr = rr, rrs[i]
}
return rrs, err
case <-ctx.Done():
return nil, ctx.Err()
}