fix: use net package instead of parsing raw strings for ip

This commit is contained in:
pommee
2025-12-20 10:08:44 +01:00
parent 7868f8949f
commit fb0b3bb9a8
3 changed files with 32 additions and 13 deletions

View File

@@ -75,11 +75,11 @@ func (s *DNSServer) handleDoHRequest(w http.ResponseWriter, r *http.Request) {
var (
clientIP, _, _ = net.SplitHostPort(r.RemoteAddr)
xRealIP = r.Header.Get("X-Real-IP")
xRealIP = net.ParseIP(r.Header.Get("X-Real-IP"))
client model.Client
)
if xRealIP != "" {
go s.WSCom(communicationMessage{IP: xRealIP, Client: true, Upstream: false, DNS: false})
if xRealIP != nil {
go s.WSCom(communicationMessage{IP: xRealIP.String(), Client: true, Upstream: false, DNS: false})
} else {
go s.WSCom(communicationMessage{IP: clientIP, Client: true, Upstream: false, DNS: false})
}
@@ -127,12 +127,11 @@ func (s *DNSServer) handleDoHRequest(w http.ResponseWriter, r *http.Request) {
DoHPort: s.Config.DNS.Ports.DoH,
}
if xRealIP != "" {
// TODO: Remove mock port once 'getClientInfo' handles real net.IP as input
client = *s.getClientInfo(xRealIP + ":")
if xRealIP != nil {
client = *s.getClientInfo(xRealIP)
clientIP = client.IP
} else {
client = *s.getClientInfo(r.RemoteAddr)
client = *s.getClientInfo(net.ParseIP(clientIP))
}
req := &Request{

View File

@@ -93,8 +93,8 @@ func (s *DNSServer) reverseHostnameLookup(requestedHostname string) (string, boo
return "", false
}
func (s *DNSServer) getClientInfo(remoteAddr string) *model.Client {
clientIP, _, _ := net.SplitHostPort(remoteAddr)
func (s *DNSServer) getClientInfo(ip net.IP) *model.Client {
clientIP := ip.String()
if cachedClient, ok := s.clientCache.Load(clientIP); ok {
return cachedClient.(*model.Client)
@@ -112,12 +112,15 @@ func (s *DNSServer) getClientInfo(remoteAddr string) *model.Client {
if err == nil {
s.MACService.SaveMac(clientIP, macAddress, vendor)
} else {
log.Warning("Was not able to find vendor for addr '%s' with MAC '%s'. %v", remoteAddr, macAddress, err)
log.Warning(
"Was not able to find vendor for addr '%s' with MAC '%s'. %v",
clientIP, macAddress, err,
)
}
}
}
if clientIP == IPv4Loopback || clientIP == "::1" || clientIP == "[" {
if ip.IsLoopback() {
localIP, err := getLocalIP()
if err != nil {
log.Warning("Failed to get local IP: %v", err)
@@ -132,7 +135,12 @@ func (s *DNSServer) getClientInfo(remoteAddr string) *model.Client {
}
}
client := model.Client{IP: resultIP, Name: hostname, MAC: macAddress}
client := model.Client{
IP: resultIP,
Name: hostname,
MAC: macAddress,
}
s.clientCache.Store(clientIP, &client)
if client.Name != unknownHostname {

View File

@@ -99,7 +99,18 @@ func (s *DNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
client := s.getClientInfo(w.RemoteAddr().String())
var clientIP net.IP
switch addr := w.RemoteAddr().(type) {
case *net.UDPAddr:
clientIP = addr.IP
case *net.TCPAddr:
clientIP = addr.IP
default:
return
}
client := s.getClientInfo(clientIP)
protocol := s.detectProtocol(w)
go s.WSCom(communicationMessage{
@@ -125,6 +136,7 @@ func (s *DNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
DNS: true,
IP: client.IP,
})
s.logEntryChannel <- entry
}