mirror of
https://github.com/pommee/goaway.git
synced 2026-01-08 14:59:42 -06:00
feat: added ability to set 'bypass' for each client to bypass any rules
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
model "goaway/backend/dns/server/models"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -8,9 +9,12 @@ import (
|
||||
|
||||
func (api *API) registerClientRoutes() {
|
||||
api.routes.GET("/clients", api.getClients)
|
||||
api.routes.GET("/clientDetails", api.getClientDetails)
|
||||
api.routes.GET("/clientHistory", api.getClientHistory)
|
||||
api.routes.GET("/topClients", api.getTopClients)
|
||||
|
||||
api.routes.GET("/client/:ip/details", api.getClientDetails)
|
||||
api.routes.GET("/client/:ip/history", api.getClientHistory)
|
||||
|
||||
api.routes.PUT("/client/:ip/bypass/:bypass", api.updateClientBypass)
|
||||
}
|
||||
|
||||
func (api *API) getClients(c *gin.Context) {
|
||||
@@ -20,51 +24,44 @@ func (api *API) getClients(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
clients := make([]map[string]interface{}, 0, len(uniqueClients))
|
||||
for ip, entry := range uniqueClients {
|
||||
clients = append(clients, map[string]interface{}{
|
||||
"ip": ip,
|
||||
"name": entry.Name,
|
||||
"lastSeen": entry.LastSeen,
|
||||
"mac": entry.Mac,
|
||||
"vendor": entry.Vendor,
|
||||
})
|
||||
clients := make([]model.Client, 0, len(uniqueClients))
|
||||
for _, entry := range uniqueClients {
|
||||
clients = append(clients, entry)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, clients)
|
||||
}
|
||||
|
||||
func (api *API) getClientDetails(c *gin.Context) {
|
||||
clientIP := c.DefaultQuery("clientIP", "")
|
||||
ip := c.Param("ip")
|
||||
|
||||
clientRequestDetails, mostQueriedDomain, domainQueryCounts, err := api.RequestService.GetClientDetailsWithDomains(clientIP)
|
||||
requestDetails, mostQueriedDomain, domainQueryCount, err := api.RequestService.GetClientDetailsWithDomains(ip)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
clientDetails, err := api.RequestService.FetchClient(ip)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, map[string]any{
|
||||
"ip": clientIP,
|
||||
"totalRequests": clientRequestDetails.TotalRequests,
|
||||
"uniqueDomains": clientRequestDetails.UniqueDomains,
|
||||
"blockedRequests": clientRequestDetails.BlockedRequests,
|
||||
"cachedRequests": clientRequestDetails.CachedRequests,
|
||||
"avgResponseTimeMs": clientRequestDetails.AvgResponseTimeMs,
|
||||
"totalRequests": requestDetails.TotalRequests,
|
||||
"uniqueDomains": requestDetails.UniqueDomains,
|
||||
"blockedRequests": requestDetails.BlockedRequests,
|
||||
"cachedRequests": requestDetails.CachedRequests,
|
||||
"avgResponseTimeMs": requestDetails.AvgResponseTimeMs,
|
||||
"mostQueriedDomain": mostQueriedDomain,
|
||||
"lastSeen": clientRequestDetails.LastSeen,
|
||||
"allDomains": domainQueryCounts,
|
||||
"allDomains": domainQueryCount,
|
||||
"clientInfo": clientDetails,
|
||||
})
|
||||
}
|
||||
|
||||
func (api *API) getClientHistory(c *gin.Context) {
|
||||
clientIP := c.Query("ip")
|
||||
|
||||
if clientIP == "" {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "No client ip was provided"})
|
||||
return
|
||||
}
|
||||
|
||||
history, err := api.RequestService.GetClientHistory(clientIP)
|
||||
ip := c.Param("ip")
|
||||
history, err := api.RequestService.GetClientHistory(ip)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -85,3 +82,28 @@ func (api *API) getTopClients(c *gin.Context) {
|
||||
|
||||
c.JSON(http.StatusOK, topClients)
|
||||
}
|
||||
|
||||
func (api *API) updateClientBypass(c *gin.Context) {
|
||||
ip := c.Param("ip")
|
||||
bypass := c.Param("bypass")
|
||||
|
||||
if bypass != "true" && bypass != "false" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Bypass value must be true or false"})
|
||||
return
|
||||
}
|
||||
|
||||
err := api.RequestService.UpdateClientBypass(ip, bypass == "true")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Refresh DNS server client caches to reflect the updated bypass status
|
||||
err = api.DNS.PopulateClientCaches()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to refresh DNS server client caches"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
@@ -65,6 +65,7 @@ type MacAddress struct {
|
||||
MAC string `gorm:"primaryKey" json:"mac" validate:"required,mac"`
|
||||
IP string `gorm:"index" json:"ip" validate:"required,ip"`
|
||||
Vendor string `json:"vendor"`
|
||||
Bypass bool `gorm:"default:false" json:"bypass"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
|
||||
}
|
||||
|
||||
@@ -46,7 +46,12 @@ func (s *DNSServer) checkAndUpdatePauseStatus() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DNSServer) shouldBlockQuery(domainName, fullName string) bool {
|
||||
func (s *DNSServer) shouldBlockQuery(client *model.Client, domainName, fullName string) bool {
|
||||
if client.Bypass {
|
||||
log.Debug("Allowing client '%s' to bypass %s", client.IP, fullName)
|
||||
return false
|
||||
}
|
||||
|
||||
return !s.Config.DNS.Status.Paused &&
|
||||
s.BlacklistService.IsBlacklisted(domainName) &&
|
||||
!s.WhitelistService.IsWhitelisted(fullName)
|
||||
@@ -65,7 +70,7 @@ func (s *DNSServer) processQuery(request *Request) model.RequestLogEntry {
|
||||
|
||||
s.checkAndUpdatePauseStatus()
|
||||
|
||||
if s.shouldBlockQuery(domainName, request.Question.Name) {
|
||||
if s.shouldBlockQuery(request.Client, domainName, request.Question.Name) {
|
||||
return s.handleBlacklisted(request)
|
||||
}
|
||||
|
||||
@@ -84,7 +89,7 @@ func (s *DNSServer) processQuery(request *Request) model.RequestLogEntry {
|
||||
func (s *DNSServer) reverseHostnameLookup(requestedHostname string) (string, bool) {
|
||||
trimmed := strings.TrimSuffix(requestedHostname, ".")
|
||||
|
||||
if value, ok := s.hostnameCache.Load(trimmed); ok {
|
||||
if value, ok := s.clientHostnameCache.Load(trimmed); ok {
|
||||
if ip, ok := value.(string); ok {
|
||||
return ip, true
|
||||
}
|
||||
@@ -96,8 +101,11 @@ func (s *DNSServer) reverseHostnameLookup(requestedHostname string) (string, boo
|
||||
func (s *DNSServer) getClientInfo(ip net.IP) *model.Client {
|
||||
clientIP := ip.String()
|
||||
|
||||
if cachedClient, ok := s.clientCache.Load(clientIP); ok {
|
||||
return cachedClient.(*model.Client)
|
||||
if loaded, ok := s.clientIPCache.Load(clientIP); ok {
|
||||
client, ok := loaded.(model.Client)
|
||||
if ok {
|
||||
return &client
|
||||
}
|
||||
}
|
||||
|
||||
macAddress := arp.GetMacAddress(clientIP)
|
||||
@@ -136,15 +144,18 @@ func (s *DNSServer) getClientInfo(ip net.IP) *model.Client {
|
||||
}
|
||||
|
||||
client := model.Client{
|
||||
IP: resultIP,
|
||||
Name: hostname,
|
||||
MAC: macAddress,
|
||||
IP: resultIP,
|
||||
LastSeen: time.Now(),
|
||||
Name: hostname,
|
||||
Mac: macAddress,
|
||||
Vendor: vendor,
|
||||
Bypass: false,
|
||||
}
|
||||
|
||||
s.clientCache.Store(clientIP, &client)
|
||||
s.clientIPCache.Store(clientIP, client)
|
||||
|
||||
if client.Name != unknownHostname {
|
||||
s.hostnameCache.Store(client.Name, client.IP)
|
||||
s.clientHostnameCache.Store(client.Name, client)
|
||||
}
|
||||
|
||||
return &client
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// Client represents a DNS client with associated metadata.
|
||||
// It includes the client's IP address, hostname, MAC address, and an ignored flag.'
|
||||
// The 'bypass' field indicates whether the client should be allowed to bypass blacklist rules.
|
||||
type Client struct {
|
||||
IP string `json:"ip"`
|
||||
Name string `json:"name"`
|
||||
MAC string `json:"mac"`
|
||||
IP string `json:"ip"`
|
||||
LastSeen time.Time `json:"lastSeen"`
|
||||
Name string `json:"name"`
|
||||
Mac string `json:"mac"`
|
||||
Vendor string `json:"vendor"`
|
||||
Bypass bool `json:"bypass"`
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ type RequestLogEntry struct {
|
||||
|
||||
func (r *RequestLogEntry) String() string {
|
||||
return fmt.Sprintf(
|
||||
"Time: %d, Client: %s, Domain: %s, Status: %s, Type: %s, Protocol: %s, IPs: %+v, ID: %d, ResponseSize: %d, ResponseTime: %dns, Blocked: %t, Cached: %t",
|
||||
"Time: %d, Client: %v, Domain: %s, Status: %s, Type: %s, Protocol: %s, IPs: %+v, ID: %d, ResponseSize: %d, ResponseTime: %dns, Blocked: %t, Cached: %t",
|
||||
r.Timestamp.Unix(),
|
||||
r.ClientInfo,
|
||||
r.Domain,
|
||||
|
||||
@@ -28,18 +28,42 @@ var (
|
||||
log = logging.GetLogger()
|
||||
)
|
||||
|
||||
// DNSServer encapsulates the DNS handling logic and the runtime state used by
|
||||
// the various DNS transports (UDP/TCP), secure transports (DoT) and HTTP-based frontends (DoH).
|
||||
type DNSServer struct {
|
||||
DBConn *gorm.DB
|
||||
dnsClient *dns.Client
|
||||
Config *settings.Config
|
||||
logEntryChannel chan model.RequestLogEntry
|
||||
WSQueries *websocket.Conn
|
||||
WSCommunication *websocket.Conn
|
||||
hostnameCache sync.Map
|
||||
clientCache sync.Map
|
||||
DomainCache sync.Map
|
||||
// Database connection used by services for persistence
|
||||
dbConn *gorm.DB
|
||||
|
||||
// Client used when querying upstream servers
|
||||
dnsClient *dns.Client
|
||||
|
||||
// Application level settings, mostly used for DNS behaviour
|
||||
Config *settings.Config
|
||||
|
||||
// Central channel where processed request log entries are pushed
|
||||
logEntryChannel chan model.RequestLogEntry
|
||||
|
||||
// Websocket connection used to stream query logs to the web UI
|
||||
WSQueries *websocket.Conn
|
||||
|
||||
// Websocket connection used to stream communication events to the UI
|
||||
// Used to visualize client/upstream/DNS activity
|
||||
WSCommunication *websocket.Conn
|
||||
|
||||
// Guards writes to WSCommunication.
|
||||
WSCommunicationLock sync.Mutex
|
||||
|
||||
// Cache mapping hostnames to client metadata to avoid repeated lookups when resolving PTR/hostnames
|
||||
clientHostnameCache sync.Map
|
||||
|
||||
// Cache mapping IP -> client info (name, mac) for quick lookup during request processing
|
||||
clientIPCache sync.Map
|
||||
|
||||
// In-memory cache for resolved DNS records to speed up responses and reduce upstream queries
|
||||
DomainCache sync.Map
|
||||
|
||||
// DNSServer delegates database-backed lookups and persistence to these services,
|
||||
// rather than performing raw DB operations itself.
|
||||
RequestService *request.Service
|
||||
AuditService *audit.Service
|
||||
UserService *user.Service
|
||||
@@ -85,7 +109,7 @@ func NewDNSServer(config *settings.Config, dbconn *gorm.DB, cert tls.Certificate
|
||||
|
||||
server := &DNSServer{
|
||||
Config: config,
|
||||
DBConn: dbconn,
|
||||
dbConn: dbconn,
|
||||
logEntryChannel: make(chan model.RequestLogEntry, 1000),
|
||||
dnsClient: &client,
|
||||
DomainCache: sync.Map{},
|
||||
@@ -157,13 +181,20 @@ func (s *DNSServer) detectProtocol(w dns.ResponseWriter) model.Protocol {
|
||||
return model.UDP
|
||||
}
|
||||
|
||||
func (s *DNSServer) PopulateHostnameCache() error {
|
||||
uniqueClients := s.RequestService.GetUniqueClientNameAndIP()
|
||||
for _, client := range uniqueClients {
|
||||
_, _ = s.hostnameCache.LoadOrStore(client.IP, client.Name)
|
||||
func (s *DNSServer) PopulateClientCaches() error {
|
||||
clients, err := s.RequestService.FetchAllClients()
|
||||
|
||||
if err != nil {
|
||||
log.Warning("Could not populate client caches, reason: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("Populated hostname cache with %d client(s)", len(uniqueClients))
|
||||
for _, client := range clients {
|
||||
s.clientHostnameCache.Store(client.Name, client)
|
||||
s.clientIPCache.Store(client.IP, client)
|
||||
}
|
||||
|
||||
log.Debug("Populated client caches with %d client(s)", len(clients))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ func (b *BackgroundJobs) Start(readyChan <-chan struct{}) {
|
||||
}
|
||||
|
||||
func (b *BackgroundJobs) startHostnameCachePopulation() {
|
||||
if err := b.registry.Context.DNSServer.PopulateHostnameCache(); err != nil {
|
||||
if err := b.registry.Context.DNSServer.PopulateClientCaches(); err != nil {
|
||||
log.Warning("Unable to populate hostname cache: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
package request
|
||||
|
||||
import "time"
|
||||
|
||||
type Client struct {
|
||||
LastSeen time.Time
|
||||
Name string
|
||||
Mac string
|
||||
Vendor string
|
||||
}
|
||||
|
||||
type ClientNameAndIP struct {
|
||||
Name string
|
||||
IP string
|
||||
}
|
||||
@@ -14,21 +14,23 @@ import (
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
SaveRequestLog(entries []model.RequestLogEntry) error
|
||||
|
||||
GetClientName(ip string) string
|
||||
GetDistinctRequestIP() int
|
||||
GetRequestSummaryByInterval(interval int) ([]model.RequestLogIntervalSummary, error)
|
||||
GetResponseSizeSummaryByInterval(intervalMinutes int) ([]model.ResponseSizeSummary, error)
|
||||
GetUniqueQueryTypes() ([]models.QueryTypeCount, error)
|
||||
FetchQueries(q models.QueryParams) ([]model.RequestLogEntry, error)
|
||||
GetUniqueClientNameAndIP() []database.RequestLog
|
||||
FetchAllClients() (map[string]Client, error)
|
||||
FetchClient(ip string) (*model.Client, error)
|
||||
FetchAllClients() (map[string]model.Client, error)
|
||||
GetClientDetailsWithDomains(clientIP string) (ClientRequestDetails, string, map[string]int, error)
|
||||
GetClientHistory(clientIP string) ([]models.DomainHistory, error)
|
||||
GetTopBlockedDomains(blockedRequests int) ([]map[string]interface{}, error)
|
||||
GetTopClients() ([]map[string]interface{}, error)
|
||||
CountQueries(search string) (int, error)
|
||||
|
||||
SaveRequestLog(entries []model.RequestLogEntry) error
|
||||
UpdateClientBypass(ip string, bypass bool) error
|
||||
|
||||
DeleteRequestLogsTimebased(vacuum vacuumFunc, requestThreshold, maxRetries int, retryDelay time.Duration) error
|
||||
}
|
||||
@@ -51,6 +53,38 @@ func NewRepository(db *gorm.DB) *repository {
|
||||
return &repository{db: db}
|
||||
}
|
||||
|
||||
func (r *repository) SaveRequestLog(entries []model.RequestLogEntry) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
for _, entry := range entries {
|
||||
rl := database.RequestLog{
|
||||
Timestamp: entry.Timestamp,
|
||||
Domain: entry.Domain,
|
||||
Blocked: entry.Blocked,
|
||||
Cached: entry.Cached,
|
||||
ResponseTimeNs: entry.ResponseTime.Nanoseconds(),
|
||||
ClientIP: entry.ClientInfo.IP,
|
||||
ClientName: entry.ClientInfo.Name,
|
||||
Status: entry.Status,
|
||||
QueryType: entry.QueryType,
|
||||
ResponseSizeBytes: entry.ResponseSizeBytes,
|
||||
Protocol: string(entry.Protocol),
|
||||
}
|
||||
|
||||
for _, resolvedIP := range entry.IP {
|
||||
rl.IPs = append(rl.IPs, database.RequestLogIP{
|
||||
IP: resolvedIP.IP,
|
||||
RecordType: resolvedIP.RType,
|
||||
})
|
||||
}
|
||||
|
||||
if err := tx.Create(&rl).Error; err != nil {
|
||||
return fmt.Errorf("could not save request log: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *repository) GetClientName(ip string) string {
|
||||
var hostname string
|
||||
err := r.db.Model(&database.RequestLog{}).
|
||||
@@ -239,13 +273,53 @@ func (r *repository) FetchQueries(q models.QueryParams) ([]model.RequestLogEntry
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (r *repository) FetchAllClients() (map[string]Client, error) {
|
||||
func (r *repository) FetchClient(ip string) (*model.Client, error) {
|
||||
var row struct {
|
||||
ClientIP string `gorm:"column:client_ip"`
|
||||
ClientName string `gorm:"column:client_name"`
|
||||
Timestamp time.Time `gorm:"column:timestamp"`
|
||||
Mac sql.NullString `gorm:"column:mac"`
|
||||
Vendor sql.NullString `gorm:"column:vendor"`
|
||||
Bypass sql.NullBool `gorm:"column:bypass"`
|
||||
}
|
||||
|
||||
subquery := r.db.Table("request_logs").
|
||||
Select("MAX(timestamp)").
|
||||
Where("client_ip = ?", ip)
|
||||
|
||||
if err := r.db.Table("request_logs r").
|
||||
Select("r.client_ip, r.client_name, r.timestamp, m.mac, m.vendor, m.bypass").
|
||||
Joins("LEFT JOIN mac_addresses m ON r.client_ip = m.ip").
|
||||
Where("r.client_ip = ?", ip).
|
||||
Where("r.timestamp = (?)", subquery).
|
||||
Scan(&row).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if row.ClientIP == "" {
|
||||
return nil, fmt.Errorf("client with ip '%s' was not found", ip)
|
||||
}
|
||||
|
||||
client := &model.Client{
|
||||
IP: row.ClientIP,
|
||||
Name: row.ClientName,
|
||||
LastSeen: row.Timestamp,
|
||||
Mac: row.Mac.String,
|
||||
Vendor: row.Vendor.String,
|
||||
Bypass: row.Bypass.Bool,
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (r *repository) FetchAllClients() (map[string]model.Client, error) {
|
||||
var rows []struct {
|
||||
ClientIP string `gorm:"column:client_ip"`
|
||||
ClientName string `gorm:"column:client_name"`
|
||||
Timestamp time.Time `gorm:"column:timestamp"`
|
||||
Mac sql.NullString `gorm:"column:mac"`
|
||||
Vendor sql.NullString `gorm:"column:vendor"`
|
||||
Bypass sql.NullBool `gorm:"column:bypass"`
|
||||
}
|
||||
|
||||
subquery := r.db.Table("request_logs").
|
||||
@@ -253,46 +327,28 @@ func (r *repository) FetchAllClients() (map[string]Client, error) {
|
||||
Group("client_ip")
|
||||
|
||||
if err := r.db.Table("request_logs r").
|
||||
Select("r.client_ip, r.client_name, r.timestamp, m.mac, m.vendor").
|
||||
Select("r.client_ip, r.client_name, r.timestamp, m.mac, m.vendor, m.bypass").
|
||||
Joins("INNER JOIN (?) latest ON r.client_ip = latest.client_ip AND r.timestamp = latest.max_timestamp", subquery).
|
||||
Joins("LEFT JOIN mac_addresses m ON r.client_ip = m.ip").
|
||||
Scan(&rows).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uniqueClients := make(map[string]Client, len(rows))
|
||||
uniqueClients := make(map[string]model.Client, len(rows))
|
||||
for _, row := range rows {
|
||||
macStr := ""
|
||||
vendorStr := ""
|
||||
if row.Mac.Valid {
|
||||
macStr = row.Mac.String
|
||||
}
|
||||
if row.Vendor.Valid {
|
||||
vendorStr = row.Vendor.String
|
||||
}
|
||||
|
||||
uniqueClients[row.ClientIP] = Client{
|
||||
uniqueClients[row.ClientIP] = model.Client{
|
||||
IP: row.ClientIP,
|
||||
Name: row.ClientName,
|
||||
LastSeen: row.Timestamp,
|
||||
Mac: macStr,
|
||||
Vendor: vendorStr,
|
||||
Mac: row.Mac.String,
|
||||
Vendor: row.Vendor.String,
|
||||
Bypass: row.Bypass.Bool,
|
||||
}
|
||||
}
|
||||
|
||||
return uniqueClients, nil
|
||||
}
|
||||
|
||||
func (r *repository) GetUniqueClientNameAndIP() []database.RequestLog {
|
||||
var results []database.RequestLog
|
||||
|
||||
r.db.Model(&database.RequestLog{}).
|
||||
Select("DISTINCT client_ip, client_name").
|
||||
Where("client_name IS NOT NULL AND client_name != ?", "unknown").
|
||||
Find(&results)
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
func (r *repository) GetClientDetailsWithDomains(clientIP string) (ClientRequestDetails, string, map[string]int, error) {
|
||||
var crd ClientRequestDetails
|
||||
err := r.db.Table("request_logs").
|
||||
@@ -432,36 +488,19 @@ func (r *repository) CountQueries(search string) (int, error) {
|
||||
return int(total), err
|
||||
}
|
||||
|
||||
func (r *repository) SaveRequestLog(entries []model.RequestLogEntry) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
for _, entry := range entries {
|
||||
rl := database.RequestLog{
|
||||
Timestamp: entry.Timestamp,
|
||||
Domain: entry.Domain,
|
||||
Blocked: entry.Blocked,
|
||||
Cached: entry.Cached,
|
||||
ResponseTimeNs: entry.ResponseTime.Nanoseconds(),
|
||||
ClientIP: entry.ClientInfo.IP,
|
||||
ClientName: entry.ClientInfo.Name,
|
||||
Status: entry.Status,
|
||||
QueryType: entry.QueryType,
|
||||
ResponseSizeBytes: entry.ResponseSizeBytes,
|
||||
Protocol: string(entry.Protocol),
|
||||
}
|
||||
func (r *repository) UpdateClientBypass(ip string, bypass bool) error {
|
||||
result := r.db.Model(&database.MacAddress{}).
|
||||
Where("ip = ?", ip).
|
||||
Updates(map[string]any{
|
||||
"bypass": bypass,
|
||||
"updated_at": time.Now(),
|
||||
})
|
||||
|
||||
for _, resolvedIP := range entry.IP {
|
||||
rl.IPs = append(rl.IPs, database.RequestLogIP{
|
||||
IP: resolvedIP.IP,
|
||||
RecordType: resolvedIP.RType,
|
||||
})
|
||||
}
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to update client bypass: %w", result.Error)
|
||||
}
|
||||
|
||||
if err := tx.Create(&rl).Error; err != nil {
|
||||
return fmt.Errorf("could not save request log: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) DeleteRequestLogsTimebased(vacuum vacuumFunc, requestThreshold, maxRetries int, retryDelay time.Duration) error {
|
||||
|
||||
@@ -18,6 +18,15 @@ func NewService(repo Repository) *Service {
|
||||
return &Service{repository: repo}
|
||||
}
|
||||
|
||||
func (s *Service) SaveRequestLog(entries []model.RequestLogEntry) error {
|
||||
if err := s.repository.SaveRequestLog(entries); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("Saved %d new request log(s)", len(entries))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) GetClientNameFromIP(ip string) string {
|
||||
return s.repository.GetClientName(ip)
|
||||
}
|
||||
@@ -42,22 +51,12 @@ func (s *Service) FetchQueries(q models.QueryParams) ([]model.RequestLogEntry, e
|
||||
return s.repository.FetchQueries(q)
|
||||
}
|
||||
|
||||
func (s *Service) FetchAllClients() (map[string]Client, error) {
|
||||
return s.repository.FetchAllClients()
|
||||
func (s *Service) FetchClient(ip string) (*model.Client, error) {
|
||||
return s.repository.FetchClient(ip)
|
||||
}
|
||||
|
||||
func (s *Service) GetUniqueClientNameAndIP() []ClientNameAndIP {
|
||||
queryResult := s.repository.GetUniqueClientNameAndIP()
|
||||
var uniqueClients []ClientNameAndIP
|
||||
|
||||
for _, client := range queryResult {
|
||||
uniqueClients = append(uniqueClients, ClientNameAndIP{
|
||||
Name: client.ClientName,
|
||||
IP: client.ClientIP,
|
||||
})
|
||||
}
|
||||
|
||||
return uniqueClients
|
||||
func (s *Service) FetchAllClients() (map[string]model.Client, error) {
|
||||
return s.repository.FetchAllClients()
|
||||
}
|
||||
|
||||
func (s *Service) GetClientDetailsWithDomains(clientIP string) (ClientRequestDetails, string, map[string]int, error) {
|
||||
@@ -80,8 +79,13 @@ func (s *Service) CountQueries(search string) (int, error) {
|
||||
return s.repository.CountQueries(search)
|
||||
}
|
||||
|
||||
func (s *Service) SaveRequestLog(entries []model.RequestLogEntry) error {
|
||||
return s.repository.SaveRequestLog(entries)
|
||||
func (s *Service) UpdateClientBypass(ip string, bypass bool) error {
|
||||
if err := s.repository.UpdateClientBypass(ip, bypass); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("Bypass toggled to %t for %s", bypass, ip)
|
||||
return nil
|
||||
}
|
||||
|
||||
type vacuumFunc func(ctx context.Context)
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Dialog, DialogContent, DialogTitle } from "@/components/ui/dialog";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { ClientEntry } from "@/pages/clients";
|
||||
import { GetRequest } from "@/util";
|
||||
import { GetRequest, PutRequest } from "@/util";
|
||||
import {
|
||||
CaretDownIcon,
|
||||
ClockCounterClockwiseIcon,
|
||||
@@ -18,6 +19,7 @@ import {
|
||||
import { useEffect, useState } from "react";
|
||||
import TimeAgo from "react-timeago";
|
||||
import { toast } from "sonner";
|
||||
import { SettingRow } from "../settings/SettingsRow";
|
||||
|
||||
type AllDomains = {
|
||||
[domain: string]: number;
|
||||
@@ -28,11 +30,17 @@ type ClientEntryDetails = {
|
||||
avgResponseTimeMs: number;
|
||||
blockedRequests: number;
|
||||
cachedRequests: number;
|
||||
ip: string;
|
||||
lastSeen: string;
|
||||
mostQueriedDomain: string;
|
||||
totalRequests: number;
|
||||
uniqueDomains: number;
|
||||
clientInfo: {
|
||||
name: string;
|
||||
ip: string;
|
||||
mac: string;
|
||||
vendor: string;
|
||||
lastSeen: string;
|
||||
bypass: boolean;
|
||||
};
|
||||
};
|
||||
|
||||
export function CardDetails({
|
||||
@@ -50,7 +58,7 @@ export function CardDetails({
|
||||
setIsLoading(true);
|
||||
try {
|
||||
const [code, response] = await GetRequest(
|
||||
`clientDetails?clientIP=${clientEntry.ip}`
|
||||
`client/${clientEntry.ip}/details`
|
||||
);
|
||||
if (code !== 200) {
|
||||
toast.warning("Unable to fetch client details");
|
||||
@@ -79,7 +87,7 @@ export function CardDetails({
|
||||
async function getClientHistory() {
|
||||
try {
|
||||
const [code, response] = await GetRequest(
|
||||
`clientHistory?ip=${clientEntry.ip}`
|
||||
`client/${clientEntry.ip}/history`
|
||||
);
|
||||
if (code !== 200) {
|
||||
toast.warning("Unable to fetch client history");
|
||||
@@ -97,6 +105,41 @@ export function CardDetails({
|
||||
getClientHistory();
|
||||
}, [clientEntry.ip]);
|
||||
|
||||
async function updateClientBypass(enabled: boolean) {
|
||||
try {
|
||||
const [code, response] = await PutRequest(
|
||||
`client/${clientEntry.ip}/bypass/${enabled}`,
|
||||
null,
|
||||
false
|
||||
);
|
||||
|
||||
if (code !== 200) {
|
||||
toast.error("Failed to update bypass setting", {
|
||||
description: response.error
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
setClientDetails((prev) =>
|
||||
prev
|
||||
? {
|
||||
...prev,
|
||||
clientInfo: {
|
||||
...prev.clientInfo,
|
||||
bypass: enabled
|
||||
}
|
||||
}
|
||||
: prev
|
||||
);
|
||||
|
||||
toast.success(
|
||||
enabled ? "Client bypass enabled" : "Client bypass disabled"
|
||||
);
|
||||
} catch {
|
||||
toast.error("Error updating bypass setting");
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog open onOpenChange={onClose}>
|
||||
<DialogContent className="border-none bg-accent rounded-lg w-full max-w-6xl max-h-3/4 overflow-y-auto">
|
||||
@@ -242,16 +285,22 @@ export function CardDetails({
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-1 sm:grid-cols-3 gap-3 mt-4">
|
||||
<Button variant="outline" disabled={true}>
|
||||
[WIP] View Details
|
||||
</Button>
|
||||
<Button variant="outline" disabled={true}>
|
||||
[WIP] Block Device
|
||||
</Button>
|
||||
<Button variant="outline" disabled={true}>
|
||||
[WIP] Device Settings
|
||||
</Button>
|
||||
<Separator className="mb-4" />
|
||||
|
||||
<div className="bg-muted-foreground/10 rounded-md p-2 shadow-sm">
|
||||
<SettingRow
|
||||
title="Bypass"
|
||||
description="Allow this client to bypass any blacklist rules."
|
||||
action={
|
||||
<Switch
|
||||
id="logging-enabled"
|
||||
checked={clientDetails.clientInfo.bypass}
|
||||
onCheckedChange={(checked) =>
|
||||
updateClientBypass(checked)
|
||||
}
|
||||
/>
|
||||
}
|
||||
/>{" "}
|
||||
</div>
|
||||
</TabsContent>
|
||||
<TabsContent value="domains">
|
||||
|
||||
@@ -12,9 +12,10 @@ import { toast } from "sonner";
|
||||
interface ClientEntry {
|
||||
ip: string;
|
||||
lastSeen: string;
|
||||
mac: string;
|
||||
name: string;
|
||||
mac: string;
|
||||
vendor: string;
|
||||
bypass: boolean;
|
||||
}
|
||||
|
||||
interface NetworkNode {
|
||||
|
||||
@@ -8,9 +8,10 @@ import { toast } from "sonner";
|
||||
export type ClientEntry = {
|
||||
ip: string;
|
||||
lastSeen: string;
|
||||
mac: string;
|
||||
name: string;
|
||||
mac: string;
|
||||
vendor: string;
|
||||
bypass: boolean;
|
||||
x?: number;
|
||||
y?: number;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user