From 654fd43c832c0060e12da53574db78e0f0a61a8b Mon Sep 17 00:00:00 2001 From: Frank Olbricht Date: Tue, 29 Dec 2020 18:19:26 -0700 Subject: [PATCH] Return PTR records for hosts file blocklists (#115) * Return PTR records for hosts file blocklists * Undo testing changes in example file --- blocklist.go | 12 +++++++++--- blocklistdb-domain.go | 10 +++++----- blocklistdb-domain_test.go | 2 +- blocklistdb-hosts.go | 22 ++++++++++++++++++---- blocklistdb-hosts_test.go | 2 +- blocklistdb-multi.go | 8 ++++---- blocklistdb-regexp.go | 6 +++--- blocklistdb.go | 2 +- message.go | 19 ++++++++++++++++++- response-blocklist-name.go | 2 +- 10 files changed, 61 insertions(+), 24 deletions(-) diff --git a/blocklist.go b/blocklist.go index b9c4640..8fbdda6 100644 --- a/blocklist.go +++ b/blocklist.go @@ -36,7 +36,7 @@ type BlocklistOptions struct { // alternative resolver rather than the default upstream one. AllowListResolver Resolver - // Rules that override the blocklist rules, effecively negate them. + // Rules that override the blocklist rules, effectively negate them. AllowlistDB BlocklistDB // Refresh period for the allowlist. Disabled if 0. @@ -92,7 +92,7 @@ func (r *Blocklist) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) { // Forward to upstream or the optional allowlist-resolver immediately if there's a match in the allowlist if allowlistDB != nil { - if _, rule, ok := allowlistDB.Match(question); ok { + if _, _, rule, ok := allowlistDB.Match(question); ok { log = log.WithField("rule", rule) r.metrics.allowed.Add(1) if r.AllowListResolver != nil { @@ -104,7 +104,7 @@ func (r *Blocklist) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) { } } - ip, rule, ok := blocklistDB.Match(question) + ip, name, rule, ok := blocklistDB.Match(question) if !ok { // Didn't match anything, pass it on to the next resolver log.WithField("resolver", r.resolver.String()).Debug("forwarding unmodified query to resolver") @@ -114,6 +114,12 @@ func (r *Blocklist) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) { log = log.WithField("rule", rule) r.metrics.blocked.Add(1) + // If we got a name for the PTR query, respond to it + if question.Qtype == dns.TypePTR && name != "" { + log.Debug("responding with ptr blocklist from blocklist") + return ptr(q, name), nil + } + // If an optional blocklist-resolver was given, send the query to that instead of returning NXDOMAIN. if r.BlocklistResolver != nil { log.WithField("resolver", r.resolver.String()).Debug("matched blocklist, forwarding") diff --git a/blocklistdb-domain.go b/blocklistdb-domain.go index 6f9c5b3..2c08ae9 100644 --- a/blocklistdb-domain.go +++ b/blocklistdb-domain.go @@ -62,7 +62,7 @@ func (m *DomainDB) Reload() (BlocklistDB, error) { return NewDomainDB(m.loader) } -func (m *DomainDB) Match(q dns.Question) (net.IP, string, bool) { +func (m *DomainDB) Match(q dns.Question) (net.IP, string, string, bool) { s := strings.TrimSuffix(q.Name, ".") var matched []string parts := strings.Split(s, ".") @@ -71,18 +71,18 @@ func (m *DomainDB) Match(q dns.Question) (net.IP, string, bool) { part := parts[i] subNode, ok := n[part] if !ok { - return nil, "", false + return nil, "", "", false } matched = append(matched, part) if _, ok := subNode[""]; ok { // exact and sub-domain match - return nil, matchedDomainParts(".", matched), true + return nil, "", matchedDomainParts(".", matched), true } if _, ok := subNode["*"]; ok && i > 0 { // wildcard match on sub-domains - return nil, matchedDomainParts("*.", matched), true + return nil, "", matchedDomainParts("*.", matched), true } n = subNode } - return nil, matchedDomainParts("", matched), len(n) == 0 // exact match + return nil, "", matchedDomainParts("", matched), len(n) == 0 // exact match } func (m *DomainDB) String() string { diff --git a/blocklistdb-domain_test.go b/blocklistdb-domain_test.go index 6d47a02..8ba7d3d 100644 --- a/blocklistdb-domain_test.go +++ b/blocklistdb-domain_test.go @@ -47,7 +47,7 @@ func TestDomainDB(t *testing.T) { } for _, test := range tests { q := dns.Question{Name: test.q, Qtype: dns.TypeA, Qclass: dns.ClassINET} - _, _, ok := m.Match(q) + _, _, _, ok := m.Match(q) require.Equal(t, test.match, ok, "query: %s", test.q) } } diff --git a/blocklistdb-hosts.go b/blocklistdb-hosts.go index 8dceee6..0f3bdbd 100644 --- a/blocklistdb-hosts.go +++ b/blocklistdb-hosts.go @@ -12,6 +12,7 @@ import ( // IP4 is given but no IP6, then a domain match will still result in an NXDOMAIN for the IP6 address. type HostsDB struct { filters map[string]ipRecords + ptrMap map[string]string // PTR lookup map loader BlocklistLoader } @@ -29,6 +30,7 @@ func NewHostsDB(loader BlocklistLoader) (*HostsDB, error) { return nil, err } filters := make(map[string]ipRecords) + ptrMap := make(map[string]string) for _, r := range rules { r = strings.TrimSpace(r) fields := strings.Fields(r) @@ -40,6 +42,9 @@ func NewHostsDB(loader BlocklistLoader) (*HostsDB, error) { if strings.HasPrefix(ipString, "#") { continue } + if len(names) == 0 { + continue + } ip := net.ParseIP(ipString) var isIP4 bool if ip4 := ip.To4(); len(ip4) == net.IPv4len { @@ -58,21 +63,30 @@ func NewHostsDB(loader BlocklistLoader) (*HostsDB, error) { } filters[name] = ips } + reverseAddr, err := dns.ReverseAddr(ipString) + if err != nil { + continue + } + ptrMap[reverseAddr] = names[0] } - return &HostsDB{filters, loader}, nil + return &HostsDB{filters, ptrMap, loader}, nil } func (m *HostsDB) Reload() (BlocklistDB, error) { return NewHostsDB(m.loader) } -func (m *HostsDB) Match(q dns.Question) (net.IP, string, bool) { +func (m *HostsDB) Match(q dns.Question) (net.IP, string, string, bool) { + if q.Qtype == dns.TypePTR { + name, ok := m.ptrMap[q.Name] + return nil, name, "", ok + } name := strings.TrimSuffix(q.Name, ".") ips, ok := m.filters[name] if q.Qtype == dns.TypeA { - return ips.ip4, ips.ip4.String() + " " + name, ok + return ips.ip4, "", ips.ip4.String() + " " + name, ok } - return ips.ip6, ips.ip6.String() + " " + name, ok + return ips.ip6, "", ips.ip6.String() + " " + name, ok } func (m *HostsDB) String() string { diff --git a/blocklistdb-hosts_test.go b/blocklistdb-hosts_test.go index 1bdfb02..63bbcf0 100644 --- a/blocklistdb-hosts_test.go +++ b/blocklistdb-hosts_test.go @@ -39,7 +39,7 @@ func TestHostsDB(t *testing.T) { } for _, test := range tests { q := dns.Question{Name: test.q, Qtype: test.typ, Qclass: dns.ClassINET} - ip, _, ok := m.Match(q) + ip, _, _, ok := m.Match(q) require.Equal(t, test.match, ok, "query: %s", test.q) require.Equal(t, test.ip, ip, "query: %s", test.q) } diff --git a/blocklistdb-multi.go b/blocklistdb-multi.go index e0d5a17..3fdf453 100644 --- a/blocklistdb-multi.go +++ b/blocklistdb-multi.go @@ -30,13 +30,13 @@ func (m MultiDB) Reload() (BlocklistDB, error) { return NewMultiDB(newDBs...) } -func (m MultiDB) Match(q dns.Question) (net.IP, string, bool) { +func (m MultiDB) Match(q dns.Question) (net.IP, string, string, bool) { for _, db := range m.dbs { - if ip, rule, ok := db.Match(q); ok { - return ip, rule, ok + if ip, name, rule, ok := db.Match(q); ok { + return ip, name, rule, ok } } - return nil, "", false + return nil, "", "", false } func (m MultiDB) String() string { diff --git a/blocklistdb-regexp.go b/blocklistdb-regexp.go index fc2e15a..f5f6a58 100644 --- a/blocklistdb-regexp.go +++ b/blocklistdb-regexp.go @@ -42,13 +42,13 @@ func (m *RegexpDB) Reload() (BlocklistDB, error) { return NewRegexpDB(m.loader) } -func (m *RegexpDB) Match(q dns.Question) (net.IP, string, bool) { +func (m *RegexpDB) Match(q dns.Question) (net.IP, string, string, bool) { for _, rule := range m.rules { if rule.MatchString(q.Name) { - return nil, rule.String(), true + return nil, "", rule.String(), true } } - return nil, "", false + return nil, "", "", false } func (m *RegexpDB) String() string { diff --git a/blocklistdb.go b/blocklistdb.go index 806559a..336bd62 100644 --- a/blocklistdb.go +++ b/blocklistdb.go @@ -14,7 +14,7 @@ type BlocklistDB interface { // Returns true if the question matches a rule. If the IP is not nil, // respond with the given IP. NXDOMAIN otherwise. - Match(q dns.Question) (net.IP, string, bool) + Match(q dns.Question) (net.IP, string, string, bool) fmt.Stringer } diff --git a/message.go b/message.go index 39171ce..63eec74 100644 --- a/message.go +++ b/message.go @@ -35,7 +35,24 @@ func refused(q *dns.Msg) *dns.Msg { // Build a response for a query with the given responce code. func responseWithCode(q *dns.Msg, rcode int) *dns.Msg { a := new(dns.Msg) - a.SetReply(q) a.SetRcode(q, rcode) return a } + +// Answers a PTR query with a name +func ptr(q *dns.Msg, name string) *dns.Msg { + a := new(dns.Msg) + a.SetReply(q) + a.Answer = []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{ + Name: q.Question[0].Name, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: 3600, + }, + Ptr: dns.Fqdn(name), + }, + } + return a +} diff --git a/response-blocklist-name.go b/response-blocklist-name.go index 3233020..8ff7b5f 100644 --- a/response-blocklist-name.go +++ b/response-blocklist-name.go @@ -87,7 +87,7 @@ func (r *ResponseBlocklistName) blockIfMatch(query, answer *dns.Msg, ci ClientIn default: continue } - if _, rule, ok := r.BlocklistDB.Match(dns.Question{Name: name}); ok { + if _, _, rule, ok := r.BlocklistDB.Match(dns.Question{Name: name}); ok { log := logger(r.id, query, ci).WithField("rule", rule) if r.BlocklistResolver != nil { log.WithField("resolver", r.BlocklistResolver).Debug("blocklist match, forwarding to blocklist-resolver")