diff --git a/backend/dns/server/handler.go b/backend/dns/server/handler.go index c357993..027f4fd 100644 --- a/backend/dns/server/handler.go +++ b/backend/dns/server/handler.go @@ -88,10 +88,9 @@ func (s *DNSServer) processQuery(request *Request) model.RequestLogEntry { func (s *DNSServer) reverseHostnameLookup(requestedHostname string) (string, bool) { trimmed := strings.TrimSuffix(requestedHostname, ".") - if value, ok := s.clientHostnameCache.Load(trimmed); ok { - if ip, ok := value.(string); ok { - return ip, true + if client, ok := value.(*model.Client); ok { + return client.IP, true } } @@ -99,43 +98,30 @@ func (s *DNSServer) reverseHostnameLookup(requestedHostname string) (string, boo } func (s *DNSServer) getClientInfo(ip net.IP) *model.Client { - clientIP := ip.String() + var ( + clientIP = ip.String() + isLoopback = ip.IsLoopback() + ) + + if isLoopback { + if localIP, err := getLocalIP(); err == nil { + clientIP = localIP + } else { + log.Warning("Failed to get local IP: %v", err) + clientIP = IPv4Loopback + } + } if loaded, ok := s.clientIPCache.Load(clientIP); ok { - client, ok := loaded.(model.Client) - if ok { - return &client + if client, ok := loaded.(*model.Client); ok { + return client } } macAddress := arp.GetMacAddress(clientIP) hostname := s.resolveHostname(clientIP) - resultIP := clientIP - - vendor, err := s.MACService.FindVendor(macAddress) - if macAddress != unknownHostname { - if err != nil || vendor == "" { - log.Debug("Lookup vendor for mac %s", macAddress) - vendor, err = arp.GetMacVendor(macAddress) - if err == nil { - s.MACService.SaveMac(clientIP, macAddress, vendor) - } else { - log.Warning( - "Was not able to find vendor for addr '%s' with MAC '%s'. %v", - clientIP, macAddress, err, - ) - } - } - } - - if ip.IsLoopback() { - localIP, err := getLocalIP() - if err != nil { - log.Warning("Failed to get local IP: %v", err) - localIP = IPv4Loopback - } - resultIP = localIP + if isLoopback { if h, err := os.Hostname(); err == nil { hostname = h } else { @@ -143,8 +129,9 @@ func (s *DNSServer) getClientInfo(ip net.IP) *model.Client { } } - client := model.Client{ - IP: resultIP, + vendor := s.lookupVendor(clientIP, macAddress) + client := &model.Client{ + IP: clientIP, LastSeen: time.Now(), Name: hostname, Mac: macAddress, @@ -152,13 +139,34 @@ func (s *DNSServer) getClientInfo(ip net.IP) *model.Client { Bypass: false, } - s.clientIPCache.Store(clientIP, client) + log.Debug("Saving new client: %s", client.IP) + _ = s.PopulateClientCaches() - if client.Name != unknownHostname { - s.clientHostnameCache.Store(client.Name, client) + return client +} + +func (s *DNSServer) lookupVendor(clientIP, macAddress string) string { + if macAddress == unknownHostname { + return "" } - return &client + vendor, err := s.MACService.FindVendor(macAddress) + if err == nil && vendor != "" { + return vendor + } + + log.Debug("Lookup vendor for mac %s", macAddress) + vendor, err = arp.GetMacVendor(macAddress) + if err != nil { + log.Warning( + "Was not able to find vendor for addr '%s' with MAC '%s'. %v", + clientIP, macAddress, err, + ) + return "" + } + + s.MACService.SaveMac(clientIP, macAddress, vendor) + return vendor } func (s *DNSServer) resolveHostname(clientIP string) string { diff --git a/backend/dns/server/server.go b/backend/dns/server/server.go index 0eeb9d0..8216e2f 100644 --- a/backend/dns/server/server.go +++ b/backend/dns/server/server.go @@ -190,8 +190,8 @@ func (s *DNSServer) PopulateClientCaches() error { } for _, client := range clients { - s.clientHostnameCache.Store(client.Name, client) - s.clientIPCache.Store(client.IP, client) + s.clientHostnameCache.Store(client.Name, &client) + s.clientIPCache.Store(client.IP, &client) } log.Debug("Populated client caches with %d client(s)", len(clients)) diff --git a/backend/request/repository.go b/backend/request/repository.go index d874333..823de30 100644 --- a/backend/request/repository.go +++ b/backend/request/repository.go @@ -489,15 +489,15 @@ func (r *repository) CountQueries(search string) (int, error) { } func (r *repository) UpdateClientBypass(ip string, bypass bool) error { - result := r.db.Model(&database.MacAddress{}). + err := r.db.Model(&database.MacAddress{}). Where("ip = ?", ip). Updates(map[string]any{ "bypass": bypass, "updated_at": time.Now(), - }) + }).Error - if result.Error != nil { - return fmt.Errorf("failed to update client bypass: %w", result.Error) + if err != nil { + return fmt.Errorf("failed to update client bypass: %w", err) } return nil