diff --git a/blocklistdb-domain.go b/blocklistdb-domain.go index 073f02a..6e5edbf 100644 --- a/blocklistdb-domain.go +++ b/blocklistdb-domain.go @@ -1,118 +1,118 @@ -package rdns - -import ( - "fmt" - "net" - "strings" - - "github.com/miekg/dns" -) - -// DomainDB holds a list of domain strings (potentially with wildcards). Matching -// logic: -// domain.com: matches just domain.com and not subdomains -// .domain.com: matches domain.com and all subdomains -// *.domain.com: matches all subdomains but not domain.com -type DomainDB struct { - name string - root node - loader BlocklistLoader -} - -type node map[string]node - -var _ BlocklistDB = &DomainDB{} - -// NewDomainDB returns a new instance of a matcher for a list of regular expressions. -func NewDomainDB(name string, loader BlocklistLoader) (*DomainDB, error) { - rules, err := loader.Load() - if err != nil { - return nil, err - } - root := make(node) - for _, r := range rules { - r = strings.TrimSpace(r) - - // Strip trailing . in case the list has FQDN names with . suffixes. - r = strings.TrimSuffix(r, ".") - - // Break up the domain into its parts and iterare backwards over them, building - // a graph of maps - parts := strings.Split(r, ".") - n := root - for i := len(parts) - 1; i >= 0; i-- { - part := parts[i] - - // Only allow wildcards as the first domain part, and not in a string - if strings.Contains(part, "*") && (i > 0 || len(part) != 1) { - return nil, fmt.Errorf("invalid blocklist item: '%s'", part) - } - - subNode, ok := n[part] - if !ok { - subNode = make(node) - n[part] = subNode - } - n = subNode - } - } - return &DomainDB{name, root, loader}, nil -} - -func (m *DomainDB) Reload() (BlocklistDB, error) { - return NewDomainDB(m.name, m.loader) -} - -func (m *DomainDB) Match(q dns.Question) (net.IP, string, *BlocklistMatch, bool) { - s := strings.TrimSuffix(q.Name, ".") - var matched []string - parts := strings.Split(s, ".") - n := m.root - for i := len(parts) - 1; i >= 0; i-- { - part := parts[i] - subNode, ok := n[part] - if !ok { - return nil, "", nil, false - } - matched = append(matched, part) - if _, ok := subNode[""]; ok { // exact and sub-domain match - return nil, - "", - &BlocklistMatch{ - List: m.name, - Rule: matchedDomainParts(".", matched), - }, - true - } - if _, ok := subNode["*"]; ok && i > 0 { // wildcard match on sub-domains - return nil, - "", - &BlocklistMatch{ - List: m.name, - Rule: matchedDomainParts("*.", matched), - }, - true - } - n = subNode - } - return nil, - "", - &BlocklistMatch{ - List: m.name, - Rule: matchedDomainParts("", matched), - }, - len(n) == 0 // exact match -} - -func (m *DomainDB) String() string { - return "Domain" -} - -// Turn a list of matched domain fragments into a domain (rule) -func matchedDomainParts(prefix string, p []string) string { - for i := len(p)/2 - 1; i >= 0; i-- { - opp := len(p) - 1 - i - p[i], p[opp] = p[opp], p[i] - } - return prefix + strings.Join(p, ".") -} +package rdns + +import ( + "fmt" + "net" + "strings" + + "github.com/miekg/dns" +) + +// DomainDB holds a list of domain strings (potentially with wildcards). Matching +// logic: +// domain.com: matches just domain.com and not subdomains +// .domain.com: matches domain.com and all subdomains +// *.domain.com: matches all subdomains but not domain.com +type DomainDB struct { + name string + root node + loader BlocklistLoader +} + +type node map[string]node + +var _ BlocklistDB = &DomainDB{} + +// NewDomainDB returns a new instance of a matcher for a list of regular expressions. +func NewDomainDB(name string, loader BlocklistLoader) *DomainDB { + return &DomainDB{name, nil, loader} +} + +func (m *DomainDB) Reload() (BlocklistDB, error) { + rules, err := m.loader.Load() + if err != nil { + return nil, err + } + root := make(node) + for _, r := range rules { + r = strings.TrimSpace(r) + + // Strip trailing . in case the list has FQDN names with . suffixes. + r = strings.TrimSuffix(r, ".") + + // Break up the domain into its parts and iterate backwards over them, building + // a graph of maps + parts := strings.Split(r, ".") + n := root + for i := len(parts) - 1; i >= 0; i-- { + part := parts[i] + + // Only allow wildcards as the first domain part, and not in a string + if strings.Contains(part, "*") && (i > 0 || len(part) != 1) { + return nil, fmt.Errorf("invalid blocklist item: '%s'", part) + } + + subNode, ok := n[part] + if !ok { + subNode = make(node) + n[part] = subNode + } + n = subNode + } + } + return &DomainDB{m.name, root, m.loader}, nil +} + +func (m *DomainDB) Match(q dns.Question) (net.IP, string, *BlocklistMatch, bool) { + s := strings.TrimSuffix(q.Name, ".") + var matched []string + parts := strings.Split(s, ".") + n := m.root + for i := len(parts) - 1; i >= 0; i-- { + part := parts[i] + subNode, ok := n[part] + if !ok { + return nil, "", nil, false + } + matched = append(matched, part) + if _, ok := subNode[""]; ok { // exact and sub-domain match + return nil, + "", + &BlocklistMatch{ + List: m.name, + Rule: matchedDomainParts(".", matched), + }, + true + } + if _, ok := subNode["*"]; ok && i > 0 { // wildcard match on sub-domains + return nil, + "", + &BlocklistMatch{ + List: m.name, + Rule: matchedDomainParts("*.", matched), + }, + true + } + n = subNode + } + return nil, + "", + &BlocklistMatch{ + List: m.name, + Rule: matchedDomainParts("", matched), + }, + len(n) == 0 // exact match +} + +func (m *DomainDB) String() string { + return "Domain" +} + +// Turn a list of matched domain fragments into a domain (rule) +func matchedDomainParts(prefix string, p []string) string { + for i := len(p)/2 - 1; i >= 0; i-- { + opp := len(p) - 1 - i + p[i], p[opp] = p[opp], p[i] + } + return prefix + strings.Join(p, ".") +} diff --git a/blocklistdb-hosts.go b/blocklistdb-hosts.go index 4479bf2..96aa561 100644 --- a/blocklistdb-hosts.go +++ b/blocklistdb-hosts.go @@ -25,8 +25,12 @@ type ipRecords struct { var _ BlocklistDB = &HostsDB{} // NewHostsDB returns a new instance of a matcher for a list of regular expressions. -func NewHostsDB(name string, loader BlocklistLoader) (*HostsDB, error) { - rules, err := loader.Load() +func NewHostsDB(name string, loader BlocklistLoader) *HostsDB { + return &HostsDB{name, nil,nil, loader} +} + +func (m *HostsDB) Reload() (BlocklistDB, error) { + rules, err := m.loader.Load() if err != nil { return nil, err } @@ -70,11 +74,8 @@ func NewHostsDB(name string, loader BlocklistLoader) (*HostsDB, error) { } ptrMap[reverseAddr] = names[0] } - return &HostsDB{name, filters, ptrMap, loader}, nil -} + return &HostsDB{m.name, filters, ptrMap, m.loader}, nil -func (m *HostsDB) Reload() (BlocklistDB, error) { - return NewHostsDB(m.name, m.loader) } func (m *HostsDB) Match(q dns.Question) (net.IP, string, *BlocklistMatch, bool) { diff --git a/blocklistdb-regexp.go b/blocklistdb-regexp.go index f119590..74857d6 100644 --- a/blocklistdb-regexp.go +++ b/blocklistdb-regexp.go @@ -18,12 +18,17 @@ type RegexpDB struct { var _ BlocklistDB = &RegexpDB{} // NewRegexpDB returns a new instance of a matcher for a list of regular expressions. -func NewRegexpDB(name string, loader BlocklistLoader) (*RegexpDB, error) { - rules, err := loader.Load() +func NewRegexpDB(name string, loader BlocklistLoader) *RegexpDB { + return &RegexpDB{name, nil, loader} +} + +func (m *RegexpDB) Reload() (BlocklistDB, error) { + rules, err := m.loader.Load() if err != nil { return nil, err } var filters []*regexp.Regexp + for _, r := range rules { r = strings.TrimSpace(r) if r == "" || strings.HasPrefix(r, "#") { @@ -36,11 +41,7 @@ func NewRegexpDB(name string, loader BlocklistLoader) (*RegexpDB, error) { filters = append(filters, re) } - return &RegexpDB{name, filters, loader}, nil -} - -func (m *RegexpDB) Reload() (BlocklistDB, error) { - return NewRegexpDB(m.name, m.loader) + return &RegexpDB{m.name, filters, m.loader}, nil } func (m *RegexpDB) Match(q dns.Question) (net.IP, string, *BlocklistMatch, bool) { diff --git a/cidr-db.go b/cidr-db.go index b63a973..7db09c0 100644 --- a/cidr-db.go +++ b/cidr-db.go @@ -17,16 +17,20 @@ type CidrDB struct { var _ IPBlocklistDB = &CidrDB{} // NewCidrDB returns a new instance of a matcher for a list of networks. -func NewCidrDB(name string, loader BlocklistLoader) (*CidrDB, error) { - rules, err := loader.Load() +func NewCidrDB(name string, loader BlocklistLoader) *CidrDB { + return &CidrDB{name, nil, nil, loader} +} + +func (m *CidrDB) Reload() (IPBlocklistDB, error) { + rules, err := m.loader.Load() if err != nil { return nil, err } db := &CidrDB{ - name: name, + name: m.name, ip4: new(ipBlocklistTrie), ip6: new(ipBlocklistTrie), - loader: loader, + loader: m.loader, } for _, r := range rules { r = strings.TrimSpace(r) @@ -51,11 +55,7 @@ func NewCidrDB(name string, loader BlocklistLoader) (*CidrDB, error) { db.ip4.add(n) } } - return db, nil -} - -func (m *CidrDB) Reload() (IPBlocklistDB, error) { - return NewCidrDB(m.name, m.loader) + return &CidrDB{m.name, db.ip4, db.ip6, m.loader}, nil } func (m *CidrDB) Match(ip net.IP) (*BlocklistMatch, bool) { diff --git a/cmd/routedns/config.go b/cmd/routedns/config.go index a3d33f9..0a13f0e 100644 --- a/cmd/routedns/config.go +++ b/cmd/routedns/config.go @@ -142,10 +142,11 @@ type group struct { // Block/Allowlist items for blocklist-v2 type list struct { - Name string - Format string - Source string - CacheDir string `toml:"cache-dir"` // Where to store copies of remote blocklists for faster startup + Name string + Format string + Source string + CacheDir string `toml:"cache-dir"` // Where to store copies of remote blocklists for faster startup + AllowFailOnStart bool `toml:"allow-fail-on-startup"` // Don't fail if the blocklist can't be loaded on startup, just print a warning } type router struct { diff --git a/cmd/routedns/main.go b/cmd/routedns/main.go index 5301a8c..927f52c 100644 --- a/cmd/routedns/main.go +++ b/cmd/routedns/main.go @@ -768,18 +768,28 @@ func newBlocklistDB(l list, rules []string) (rdns.BlocklistDB, error) { return nil, fmt.Errorf("unsupported scheme '%s' in '%s'", loc.Scheme, l.Source) } } + var db rdns.BlocklistDB switch l.Format { case "regexp", "": - return rdns.NewRegexpDB(name, loader) + db = rdns.NewRegexpDB(name, loader) case "domain": - return rdns.NewDomainDB(name, loader) + db = rdns.NewDomainDB(name, loader) case "hosts": - return rdns.NewHostsDB(name, loader) + db = rdns.NewHostsDB(name, loader) default: return nil, fmt.Errorf("unsupported format '%s'", l.Format) } + db, err = db.Reload() + if err != nil { + rdns.Log.WithError(err).Warn("failed to load list") + if !l.AllowFailOnStart { + return nil, fmt.Errorf("failed to load list on startup, set allow-fail-on-startup to skip: %w", err) + } + } + return db, nil } + func newIPBlocklistDB(l list, locationDB string, rules []string) (rdns.IPBlocklistDB, error) { loc, err := url.Parse(l.Source) if err != nil { @@ -805,15 +815,24 @@ func newIPBlocklistDB(l list, locationDB string, rules []string) (rdns.IPBlockli return nil, fmt.Errorf("unsupported scheme '%s' in '%s'", loc.Scheme, l.Source) } } - + var db rdns.IPBlocklistDB switch l.Format { - case "cidr", "": - return rdns.NewCidrDB(name, loader) - case "location": - return rdns.NewGeoIPDB(name, loader, locationDB) - default: - return nil, fmt.Errorf("unsupported format '%s'", l.Format) + case "cidr", "": + db = rdns.NewCidrDB(name, loader) + case "location": + return rdns.NewGeoIPDB(name, loader, locationDB) + default: + return nil, fmt.Errorf("unsupported format '%s'", l.Format) } + db, err = db.Reload() + + if err != nil { + rdns.Log.WithError(err).Warn("failed to load list") + if !l.AllowFailOnStart { + return nil, fmt.Errorf("failed to load list on startup, set allow-fail-on-startup to skip: %w", err) + } + } + return db, nil } func printVersion() {