mirror of
https://github.com/pommee/goaway.git
synced 2026-05-14 13:28:36 -05:00
fix: fix bypass bug and cleanup client name/ip caches
This commit is contained in:
@@ -88,10 +88,9 @@ func (s *DNSServer) processQuery(request *Request) model.RequestLogEntry {
|
||||
|
||||
func (s *DNSServer) reverseHostnameLookup(requestedHostname string) (string, bool) {
|
||||
trimmed := strings.TrimSuffix(requestedHostname, ".")
|
||||
|
||||
if value, ok := s.clientHostnameCache.Load(trimmed); ok {
|
||||
if ip, ok := value.(string); ok {
|
||||
return ip, true
|
||||
if client, ok := value.(*model.Client); ok {
|
||||
return client.IP, true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,43 +98,30 @@ func (s *DNSServer) reverseHostnameLookup(requestedHostname string) (string, boo
|
||||
}
|
||||
|
||||
func (s *DNSServer) getClientInfo(ip net.IP) *model.Client {
|
||||
clientIP := ip.String()
|
||||
var (
|
||||
clientIP = ip.String()
|
||||
isLoopback = ip.IsLoopback()
|
||||
)
|
||||
|
||||
if isLoopback {
|
||||
if localIP, err := getLocalIP(); err == nil {
|
||||
clientIP = localIP
|
||||
} else {
|
||||
log.Warning("Failed to get local IP: %v", err)
|
||||
clientIP = IPv4Loopback
|
||||
}
|
||||
}
|
||||
|
||||
if loaded, ok := s.clientIPCache.Load(clientIP); ok {
|
||||
client, ok := loaded.(model.Client)
|
||||
if ok {
|
||||
return &client
|
||||
if client, ok := loaded.(*model.Client); ok {
|
||||
return client
|
||||
}
|
||||
}
|
||||
|
||||
macAddress := arp.GetMacAddress(clientIP)
|
||||
hostname := s.resolveHostname(clientIP)
|
||||
resultIP := clientIP
|
||||
|
||||
vendor, err := s.MACService.FindVendor(macAddress)
|
||||
if macAddress != unknownHostname {
|
||||
if err != nil || vendor == "" {
|
||||
log.Debug("Lookup vendor for mac %s", macAddress)
|
||||
vendor, err = arp.GetMacVendor(macAddress)
|
||||
if err == nil {
|
||||
s.MACService.SaveMac(clientIP, macAddress, vendor)
|
||||
} else {
|
||||
log.Warning(
|
||||
"Was not able to find vendor for addr '%s' with MAC '%s'. %v",
|
||||
clientIP, macAddress, err,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ip.IsLoopback() {
|
||||
localIP, err := getLocalIP()
|
||||
if err != nil {
|
||||
log.Warning("Failed to get local IP: %v", err)
|
||||
localIP = IPv4Loopback
|
||||
}
|
||||
resultIP = localIP
|
||||
|
||||
if isLoopback {
|
||||
if h, err := os.Hostname(); err == nil {
|
||||
hostname = h
|
||||
} else {
|
||||
@@ -143,8 +129,9 @@ func (s *DNSServer) getClientInfo(ip net.IP) *model.Client {
|
||||
}
|
||||
}
|
||||
|
||||
client := model.Client{
|
||||
IP: resultIP,
|
||||
vendor := s.lookupVendor(clientIP, macAddress)
|
||||
client := &model.Client{
|
||||
IP: clientIP,
|
||||
LastSeen: time.Now(),
|
||||
Name: hostname,
|
||||
Mac: macAddress,
|
||||
@@ -152,13 +139,34 @@ func (s *DNSServer) getClientInfo(ip net.IP) *model.Client {
|
||||
Bypass: false,
|
||||
}
|
||||
|
||||
s.clientIPCache.Store(clientIP, client)
|
||||
log.Debug("Saving new client: %s", client.IP)
|
||||
_ = s.PopulateClientCaches()
|
||||
|
||||
if client.Name != unknownHostname {
|
||||
s.clientHostnameCache.Store(client.Name, client)
|
||||
return client
|
||||
}
|
||||
|
||||
func (s *DNSServer) lookupVendor(clientIP, macAddress string) string {
|
||||
if macAddress == unknownHostname {
|
||||
return ""
|
||||
}
|
||||
|
||||
return &client
|
||||
vendor, err := s.MACService.FindVendor(macAddress)
|
||||
if err == nil && vendor != "" {
|
||||
return vendor
|
||||
}
|
||||
|
||||
log.Debug("Lookup vendor for mac %s", macAddress)
|
||||
vendor, err = arp.GetMacVendor(macAddress)
|
||||
if err != nil {
|
||||
log.Warning(
|
||||
"Was not able to find vendor for addr '%s' with MAC '%s'. %v",
|
||||
clientIP, macAddress, err,
|
||||
)
|
||||
return ""
|
||||
}
|
||||
|
||||
s.MACService.SaveMac(clientIP, macAddress, vendor)
|
||||
return vendor
|
||||
}
|
||||
|
||||
func (s *DNSServer) resolveHostname(clientIP string) string {
|
||||
|
||||
@@ -190,8 +190,8 @@ func (s *DNSServer) PopulateClientCaches() error {
|
||||
}
|
||||
|
||||
for _, client := range clients {
|
||||
s.clientHostnameCache.Store(client.Name, client)
|
||||
s.clientIPCache.Store(client.IP, client)
|
||||
s.clientHostnameCache.Store(client.Name, &client)
|
||||
s.clientIPCache.Store(client.IP, &client)
|
||||
}
|
||||
|
||||
log.Debug("Populated client caches with %d client(s)", len(clients))
|
||||
|
||||
@@ -489,15 +489,15 @@ func (r *repository) CountQueries(search string) (int, error) {
|
||||
}
|
||||
|
||||
func (r *repository) UpdateClientBypass(ip string, bypass bool) error {
|
||||
result := r.db.Model(&database.MacAddress{}).
|
||||
err := r.db.Model(&database.MacAddress{}).
|
||||
Where("ip = ?", ip).
|
||||
Updates(map[string]any{
|
||||
"bypass": bypass,
|
||||
"updated_at": time.Now(),
|
||||
})
|
||||
}).Error
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to update client bypass: %w", result.Error)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update client bypass: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user