Return PTR records for hosts file blocklists (#115)

* Return PTR records for hosts file blocklists

* Undo testing changes in example file
This commit is contained in:
Frank Olbricht
2020-12-29 18:19:26 -07:00
committed by GitHub
parent 30e563b7ca
commit 654fd43c83
10 changed files with 61 additions and 24 deletions

View File

@@ -36,7 +36,7 @@ type BlocklistOptions struct {
// alternative resolver rather than the default upstream one.
AllowListResolver Resolver
// Rules that override the blocklist rules, effecively negate them.
// Rules that override the blocklist rules, effectively negate them.
AllowlistDB BlocklistDB
// Refresh period for the allowlist. Disabled if 0.
@@ -92,7 +92,7 @@ func (r *Blocklist) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) {
// Forward to upstream or the optional allowlist-resolver immediately if there's a match in the allowlist
if allowlistDB != nil {
if _, rule, ok := allowlistDB.Match(question); ok {
if _, _, rule, ok := allowlistDB.Match(question); ok {
log = log.WithField("rule", rule)
r.metrics.allowed.Add(1)
if r.AllowListResolver != nil {
@@ -104,7 +104,7 @@ func (r *Blocklist) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) {
}
}
ip, rule, ok := blocklistDB.Match(question)
ip, name, rule, ok := blocklistDB.Match(question)
if !ok {
// Didn't match anything, pass it on to the next resolver
log.WithField("resolver", r.resolver.String()).Debug("forwarding unmodified query to resolver")
@@ -114,6 +114,12 @@ func (r *Blocklist) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) {
log = log.WithField("rule", rule)
r.metrics.blocked.Add(1)
// If we got a name for the PTR query, respond to it
if question.Qtype == dns.TypePTR && name != "" {
log.Debug("responding with ptr blocklist from blocklist")
return ptr(q, name), nil
}
// If an optional blocklist-resolver was given, send the query to that instead of returning NXDOMAIN.
if r.BlocklistResolver != nil {
log.WithField("resolver", r.resolver.String()).Debug("matched blocklist, forwarding")

View File

@@ -62,7 +62,7 @@ func (m *DomainDB) Reload() (BlocklistDB, error) {
return NewDomainDB(m.loader)
}
func (m *DomainDB) Match(q dns.Question) (net.IP, string, bool) {
func (m *DomainDB) Match(q dns.Question) (net.IP, string, string, bool) {
s := strings.TrimSuffix(q.Name, ".")
var matched []string
parts := strings.Split(s, ".")
@@ -71,18 +71,18 @@ func (m *DomainDB) Match(q dns.Question) (net.IP, string, bool) {
part := parts[i]
subNode, ok := n[part]
if !ok {
return nil, "", false
return nil, "", "", false
}
matched = append(matched, part)
if _, ok := subNode[""]; ok { // exact and sub-domain match
return nil, matchedDomainParts(".", matched), true
return nil, "", matchedDomainParts(".", matched), true
}
if _, ok := subNode["*"]; ok && i > 0 { // wildcard match on sub-domains
return nil, matchedDomainParts("*.", matched), true
return nil, "", matchedDomainParts("*.", matched), true
}
n = subNode
}
return nil, matchedDomainParts("", matched), len(n) == 0 // exact match
return nil, "", matchedDomainParts("", matched), len(n) == 0 // exact match
}
func (m *DomainDB) String() string {

View File

@@ -47,7 +47,7 @@ func TestDomainDB(t *testing.T) {
}
for _, test := range tests {
q := dns.Question{Name: test.q, Qtype: dns.TypeA, Qclass: dns.ClassINET}
_, _, ok := m.Match(q)
_, _, _, ok := m.Match(q)
require.Equal(t, test.match, ok, "query: %s", test.q)
}
}

View File

@@ -12,6 +12,7 @@ import (
// IP4 is given but no IP6, then a domain match will still result in an NXDOMAIN for the IP6 address.
type HostsDB struct {
filters map[string]ipRecords
ptrMap map[string]string // PTR lookup map
loader BlocklistLoader
}
@@ -29,6 +30,7 @@ func NewHostsDB(loader BlocklistLoader) (*HostsDB, error) {
return nil, err
}
filters := make(map[string]ipRecords)
ptrMap := make(map[string]string)
for _, r := range rules {
r = strings.TrimSpace(r)
fields := strings.Fields(r)
@@ -40,6 +42,9 @@ func NewHostsDB(loader BlocklistLoader) (*HostsDB, error) {
if strings.HasPrefix(ipString, "#") {
continue
}
if len(names) == 0 {
continue
}
ip := net.ParseIP(ipString)
var isIP4 bool
if ip4 := ip.To4(); len(ip4) == net.IPv4len {
@@ -58,21 +63,30 @@ func NewHostsDB(loader BlocklistLoader) (*HostsDB, error) {
}
filters[name] = ips
}
reverseAddr, err := dns.ReverseAddr(ipString)
if err != nil {
continue
}
ptrMap[reverseAddr] = names[0]
}
return &HostsDB{filters, loader}, nil
return &HostsDB{filters, ptrMap, loader}, nil
}
func (m *HostsDB) Reload() (BlocklistDB, error) {
return NewHostsDB(m.loader)
}
func (m *HostsDB) Match(q dns.Question) (net.IP, string, bool) {
func (m *HostsDB) Match(q dns.Question) (net.IP, string, string, bool) {
if q.Qtype == dns.TypePTR {
name, ok := m.ptrMap[q.Name]
return nil, name, "", ok
}
name := strings.TrimSuffix(q.Name, ".")
ips, ok := m.filters[name]
if q.Qtype == dns.TypeA {
return ips.ip4, ips.ip4.String() + " " + name, ok
return ips.ip4, "", ips.ip4.String() + " " + name, ok
}
return ips.ip6, ips.ip6.String() + " " + name, ok
return ips.ip6, "", ips.ip6.String() + " " + name, ok
}
func (m *HostsDB) String() string {

View File

@@ -39,7 +39,7 @@ func TestHostsDB(t *testing.T) {
}
for _, test := range tests {
q := dns.Question{Name: test.q, Qtype: test.typ, Qclass: dns.ClassINET}
ip, _, ok := m.Match(q)
ip, _, _, ok := m.Match(q)
require.Equal(t, test.match, ok, "query: %s", test.q)
require.Equal(t, test.ip, ip, "query: %s", test.q)
}

View File

@@ -30,13 +30,13 @@ func (m MultiDB) Reload() (BlocklistDB, error) {
return NewMultiDB(newDBs...)
}
func (m MultiDB) Match(q dns.Question) (net.IP, string, bool) {
func (m MultiDB) Match(q dns.Question) (net.IP, string, string, bool) {
for _, db := range m.dbs {
if ip, rule, ok := db.Match(q); ok {
return ip, rule, ok
if ip, name, rule, ok := db.Match(q); ok {
return ip, name, rule, ok
}
}
return nil, "", false
return nil, "", "", false
}
func (m MultiDB) String() string {

View File

@@ -42,13 +42,13 @@ func (m *RegexpDB) Reload() (BlocklistDB, error) {
return NewRegexpDB(m.loader)
}
func (m *RegexpDB) Match(q dns.Question) (net.IP, string, bool) {
func (m *RegexpDB) Match(q dns.Question) (net.IP, string, string, bool) {
for _, rule := range m.rules {
if rule.MatchString(q.Name) {
return nil, rule.String(), true
return nil, "", rule.String(), true
}
}
return nil, "", false
return nil, "", "", false
}
func (m *RegexpDB) String() string {

View File

@@ -14,7 +14,7 @@ type BlocklistDB interface {
// Returns true if the question matches a rule. If the IP is not nil,
// respond with the given IP. NXDOMAIN otherwise.
Match(q dns.Question) (net.IP, string, bool)
Match(q dns.Question) (net.IP, string, string, bool)
fmt.Stringer
}

View File

@@ -35,7 +35,24 @@ func refused(q *dns.Msg) *dns.Msg {
// Build a response for a query with the given responce code.
func responseWithCode(q *dns.Msg, rcode int) *dns.Msg {
a := new(dns.Msg)
a.SetReply(q)
a.SetRcode(q, rcode)
return a
}
// Answers a PTR query with a name
func ptr(q *dns.Msg, name string) *dns.Msg {
a := new(dns.Msg)
a.SetReply(q)
a.Answer = []dns.RR{
&dns.PTR{
Hdr: dns.RR_Header{
Name: q.Question[0].Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 3600,
},
Ptr: dns.Fqdn(name),
},
}
return a
}

View File

@@ -87,7 +87,7 @@ func (r *ResponseBlocklistName) blockIfMatch(query, answer *dns.Msg, ci ClientIn
default:
continue
}
if _, rule, ok := r.BlocklistDB.Match(dns.Question{Name: name}); ok {
if _, _, rule, ok := r.BlocklistDB.Match(dns.Question{Name: name}); ok {
log := logger(r.id, query, ci).WithField("rule", rule)
if r.BlocklistResolver != nil {
log.WithField("resolver", r.BlocklistResolver).Debug("blocklist match, forwarding to blocklist-resolver")