mirror of
https://github.com/folbricht/routedns.git
synced 2026-01-06 09:40:03 -06:00
Preserve non-IP records in response from fastest-tcp (#192)
This commit is contained in:
@@ -3,6 +3,7 @@ package rdns
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
@@ -13,9 +14,9 @@ import (
|
||||
|
||||
// FastestTCP first resolves the query with the upstream resolver, then
|
||||
// performs TCP connection tests with the response IPs to determine which
|
||||
// IP responds the fastest. This IP is then returned in the response.
|
||||
// This should be used in combination with a Cache to avoid the TCP
|
||||
// connection overhead on every query.
|
||||
// IP responds the fastest. This IP is then returned in the response as first
|
||||
// A/AAAA record. This should be used in combination with a Cache to avoid
|
||||
// the TCP connection overhead on every query.
|
||||
type FastestTCP struct {
|
||||
id string
|
||||
resolver Resolver
|
||||
@@ -55,7 +56,8 @@ func NewFastestTCP(id string, resolver Resolver, opt FastestTCPOptions) *Fastest
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve a DNS query using a random resolver.
|
||||
// Resolve a DNS query and order the response based on which IP was able to establish
|
||||
// a TCP connection the fastest.
|
||||
func (r *FastestTCP) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) {
|
||||
log := logger(r.id, q, ci)
|
||||
a, err := r.resolver.Resolve(q, ci)
|
||||
@@ -68,6 +70,8 @@ func (r *FastestTCP) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) {
|
||||
if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA {
|
||||
return a, nil
|
||||
}
|
||||
fmt.Println("Responses")
|
||||
fmt.Println(a)
|
||||
|
||||
// Extract the IP responses
|
||||
var ipRRs []dns.RR
|
||||
@@ -85,25 +89,27 @@ func (r *FastestTCP) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) {
|
||||
// Send TCP probes to all, if anything returns an error, just return
|
||||
// the original response rather than trying to be clever and pick one.
|
||||
log = log.WithField("port", r.port)
|
||||
var sorted []dns.RR
|
||||
if r.opt.WaitAll {
|
||||
rrs, err := r.probeAll(log, ipRRs)
|
||||
if err != nil {
|
||||
log.WithError(err).Debug("tcp probe failed")
|
||||
return a, nil
|
||||
}
|
||||
r.setTTL(rrs...)
|
||||
a.Answer = rrs
|
||||
return a, nil
|
||||
sorted, err = r.probeAll(log, ipRRs)
|
||||
} else {
|
||||
first, err := r.probeFastest(log, ipRRs)
|
||||
if err != nil {
|
||||
log.WithError(err).Debug("tcp probe failed")
|
||||
return a, nil
|
||||
}
|
||||
r.setTTL(first)
|
||||
a.Answer = []dns.RR{first}
|
||||
sorted, err = r.probeFastest(log, ipRRs)
|
||||
}
|
||||
if err != nil {
|
||||
log.WithError(err).Debug("tcp probe failed")
|
||||
return a, nil
|
||||
}
|
||||
r.setTTL(sorted...)
|
||||
|
||||
// Merge the sorted list of RRs back into the original answer in the same
|
||||
// positions. The original answer could have CNAMEs and other types in it.
|
||||
for i, rr := range a.Answer {
|
||||
if rr.Header().Rrtype == question.Qtype {
|
||||
a.Answer[i] = sorted[0]
|
||||
sorted = sorted[1:]
|
||||
}
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// Sets the TTL of the given RRs if the option was provided
|
||||
@@ -123,13 +129,22 @@ func (r *FastestTCP) String() string {
|
||||
// Probes all IPs and returns only the RR with the fastest responding IP.
|
||||
// Waits for the first one that comes back. Returns an error if the fastest response
|
||||
// is an error.
|
||||
func (r *FastestTCP) probeFastest(log logrus.FieldLogger, rrs []dns.RR) (dns.RR, error) {
|
||||
func (r *FastestTCP) probeFastest(log logrus.FieldLogger, rrs []dns.RR) ([]dns.RR, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
resultCh := r.probe(ctx, log, rrs)
|
||||
select {
|
||||
case res := <-resultCh:
|
||||
return res.rr, res.err
|
||||
// Re-order the list in-place to put the fastest at the top
|
||||
rr := res.rr
|
||||
err := res.err
|
||||
for i := 0; i < len(rrs); i++ {
|
||||
if rrs[i] == rr {
|
||||
return rrs, err
|
||||
}
|
||||
rrs[i], rr = rr, rrs[i]
|
||||
}
|
||||
return rrs, err
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user