fix: restructure codebase, make setup and flow easier

This commit is contained in:
pommee
2025-10-31 18:29:58 +01:00
committed by Hugo
parent fcf0007763
commit 75fb86cc00
167 changed files with 8092 additions and 5861 deletions

View File

@@ -5,7 +5,7 @@ tmp_dir = "tmp"
bin = "./tmp/goaway"
args_bin = [ "--dns-port=6121", "--dot-port=6122", "--doh-port=9012", "--webserver-port=8080", "--log-level=0", "--logging=true", "--auth=false", "--statistics-retention=7", "--dashboard=false", "--ansi=true" ]
cmd = 'go build -o goaway -ldflags="-X main.version=0.0.0 -X main.commit=ead2d7830add26d53ecab3c907a290f0cdc1e078 -X main.date=2025-04-11T13:37:56Z" -o ./tmp/goaway .'
exclude_dir = [ "assets", "tmp", "vendor", "client", "test", "resources" ]
exclude_dir = [ "assets", "tmp", "vendor", "client", "test", "resources", "config", "data" ]
include_ext = [ "go", "tpl", "tmpl", "html", "css", "js", "jsx", "ts", "tsx" ]
[color]

34
.github/workflows/docs.yml vendored Normal file
View File

@@ -0,0 +1,34 @@
name: Build & Deploy Homepage
on:
push:
branches:
- main
permissions:
contents: write
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Configure Git Credentials
run: |
git config user.name github-actions[bot]
git config user.email 41898282+github-actions[bot]@users.noreply.github.com
- uses: actions/setup-python@v5
with:
python-version: 3.x
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
- uses: actions/cache@v4
with:
key: mkdocs-material-${{ env.cache_id }}
path: ~/.cache
restore-keys: |
mkdocs-material-
- run: pip install mkdocs-material mkdocs-git-revision-date-localized-plugin
- run: |
mkdocs gh-deploy --force \
--config-file docs/mkdocs.yml \
-m "docs: update website"

View File

@@ -41,7 +41,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v6
with:
go-version: "1.25.1"
go-version: "1.25.3"
- name: Golangci-lint
uses: golangci/golangci-lint-action@v8.0.0

4
.gitignore vendored
View File

@@ -2,7 +2,7 @@ tmp
requests.json
counters.json
main
goaway*
goaway
*.db**
.vite
dist
@@ -38,3 +38,5 @@ benchmark.prof
*.crt
*.key
docs/site/**

View File

@@ -1,4 +1,4 @@
FROM golang:1.25.1-alpine
FROM golang:1.25.3-alpine
WORKDIR /app

View File

@@ -1,5 +1,3 @@
# GoAway - DNS Sinkhole
![GitHub Release](https://img.shields.io/github/v/release/pommee/goaway)
![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/pommee/goaway/release.yml)
![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/pommee/goaway/total?color=cornflowerblue)
@@ -7,11 +5,19 @@
A lightweight DNS sinkhole for blocking unwanted domains at the network level. Block ads, trackers, and malicious domains before they reach your devices.
![goaway Dashboard Preview](./resources/dashboard.png)
![goaway banner](./resources/preview.png)
**[View more screenshots](./resources/PREVIEW.md)**
<<<<<<< Updated upstream
## 🌟 Features
=======
## Getting started
Instructions for installation, configuration and more can be found on the homepage: https://pommee.github.io/goaway
## Features
>>>>>>> Stashed changes
- DNS-level domain blocking
- Web-based admin dashboard
@@ -219,7 +225,7 @@ Contributions are welcomed! Here's how you can help:
3. **Submit PRs:** Before any work is started, create a new issue explaining what is wanted, why it would fit, how it can be done, so on and so forth...
Once the topic has been discussed with a maintainer then either you or a maintainer starts with the implementation. This is done to prevent any collisions, save time and confusion. [Read more here](./CONTRIBUTING.md)
## ⚠️ Platform Support
## Platform Support
| Platform | Architecture | Support Level |
| -------- | ------------ | ------------- |
@@ -233,7 +239,7 @@ Contributions are welcomed! Here's how you can help:
> **Note**: Primary testing is conducted on Linux (amd64). While the aim is to support all listed platforms, functionality on macOS and Windows may vary.
## 🔍 Troubleshooting
## Troubleshooting
### Common Issues
@@ -252,7 +258,7 @@ Contributions are welcomed! Here's how you can help:
- Check device DNS settings point to GoAway's IP address
- Test with `nslookup google.com <goaway-ip>` or `dig @<goaway-ip> google.com.`
## 📈 Performance
## Performance
GoAway is designed to be lightweight and efficient:
@@ -261,10 +267,10 @@ GoAway is designed to be lightweight and efficient:
- **Network:** Low latency DNS resolution
- **Storage:** Logs and statistics use minimal disk space
## 📜 License
## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
## 🙏 Acknowledgments
## Acknowledgments
This project is heavily inspired by [Pi-hole](https://github.com/pi-hole/pi-hole). Thanks to all people involved for their work.

View File

@@ -16,8 +16,8 @@ type DiscordConfig struct {
}
type DiscordService struct {
config DiscordConfig
client *http.Client
config DiscordConfig
}
type DiscordWebhookPayload struct {
@@ -28,12 +28,12 @@ type DiscordWebhookPayload struct {
}
type DiscordEmbed struct {
Author *DiscordEmbedAuthor `json:"author,omitempty"`
Title string `json:"title,omitempty"`
Description string `json:"description,omitempty"`
Color int `json:"color,omitempty"`
Fields []DiscordEmbedField `json:"fields,omitempty"`
Timestamp string `json:"timestamp,omitempty"`
Author *DiscordEmbedAuthor `json:"author,omitempty"`
Fields []DiscordEmbedField `json:"fields,omitempty"`
Color int `json:"color,omitempty"`
}
type DiscordEmbedField struct {
@@ -81,7 +81,7 @@ func (d *DiscordService) SendMessage(ctx context.Context, msg Message) error {
return fmt.Errorf("failed to marshal Discord payload: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", d.config.WebhookURL, bytes.NewBuffer(jsonData))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, d.config.WebhookURL, bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to create Discord request: %w", err)
}

View File

@@ -0,0 +1,61 @@
package alert
import (
"fmt"
"goaway/backend/database"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type Repository interface {
SaveAlert(alert database.Alert) error
GetAllAlerts() ([]database.Alert, error)
RemoveAlert(alertType string) error
}
type repository struct {
db *gorm.DB
}
func NewRepository(db *gorm.DB) Repository {
return &repository{db: db}
}
func (r *repository) SaveAlert(alert database.Alert) error {
result := r.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "type"}},
DoUpdates: clause.AssignmentColumns([]string{"enabled", "name", "webhook"}),
}).Create(&alert)
if result.Error != nil {
return fmt.Errorf("failed to save alert: %w", result.Error)
}
return nil
}
func (r *repository) GetAllAlerts() ([]database.Alert, error) {
var alerts []database.Alert
result := r.db.Find(&alerts)
return alerts, result.Error
}
func (r *repository) RemoveAlert(alertType string) error {
if alertType == "" {
return fmt.Errorf("alert type cannot be empty")
}
result := r.db.Where("type = ?", alertType).Delete(&database.Alert{})
if result.Error != nil {
return fmt.Errorf("failed to remove alert: %w", result.Error)
}
if result.RowsAffected == 0 {
log.Warning("No alert found with type: %s", alertType)
return fmt.Errorf("no alert found with type: %s", alertType)
}
return nil
}

View File

@@ -3,14 +3,14 @@ package alert
import (
"context"
"fmt"
"goaway/backend/dns/database"
"goaway/backend/database"
"goaway/backend/logging"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
var log = logging.GetLogger()
type Service struct {
repository Repository
services []messageSender
}
type Message struct {
Content string
@@ -19,32 +19,50 @@ type Message struct {
Severity string
}
type MessageSender interface {
type messageSender interface {
SendMessage(ctx context.Context, msg Message) error
IsEnabled() bool
GetServiceName() string
}
type Manager struct {
DB *gorm.DB
services []MessageSender
}
var log = logging.GetLogger()
func NewManager(db *gorm.DB) *Manager {
return &Manager{
DB: db,
services: make([]MessageSender, 0),
func NewService(repo Repository) *Service {
return &Service{
repository: repo,
services: make([]messageSender, 0),
}
}
func (m *Manager) Reset() {
m.services = make([]MessageSender, 0)
func (s *Service) reset() {
s.services = make([]messageSender, 0)
}
func (m *Manager) Load() {
func (s *Service) registerService(service messageSender) {
log.Debug("Registering alert service: %s", service.GetServiceName())
s.services = append(s.services, service)
}
func (s *Service) SaveAlert(alert database.Alert) error {
err := s.repository.SaveAlert(alert)
if err != nil {
log.Error("Failed to save alert settings: %v", err)
return err
}
s.Load()
log.Info("Alert settings saved for type: %s", alert.Type)
return nil
}
func (s *Service) GetAllAlerts() ([]database.Alert, error) {
return s.repository.GetAllAlerts()
}
func (s *Service) Load() {
discordService := NewDiscordService(DiscordConfig{})
alerts, err := m.GetAllAlerts()
alerts, err := s.GetAllAlerts()
if err != nil {
log.Warning("Failed to load alerts from database: %v", err)
@@ -62,61 +80,15 @@ func (m *Manager) Load() {
}
}
m.Reset()
m.RegisterService(discordService)
log.Debug("Alert Manager loaded with %d services", len(m.services))
s.reset()
s.registerService(discordService)
log.Debug("Alert Manager loaded with %d services", len(s.services))
}
func (m *Manager) SaveAlert(alert database.Alert) error {
result := m.DB.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "type"}},
DoUpdates: clause.AssignmentColumns([]string{"enabled", "name", "webhook"}),
}).Create(&alert)
if result.Error != nil {
return fmt.Errorf("failed to save alert: %w", result.Error)
}
log.Info("Alert settings saved for type: %s", alert.Type)
m.Load()
return nil
}
func (m *Manager) RemoveAlert(alertType string) error {
if alertType == "" {
return fmt.Errorf("alert type cannot be empty")
}
result := m.DB.Where("type = ?", alertType).Delete(&database.Alert{})
if result.Error != nil {
return fmt.Errorf("failed to remove alert: %w", result.Error)
}
if result.RowsAffected == 0 {
log.Warning("No alert found with type: %s", alertType)
return fmt.Errorf("no alert found with type: %s", alertType)
}
m.Load()
return nil
}
func (m *Manager) GetAllAlerts() ([]database.Alert, error) {
ctx := context.Background()
return gorm.G[database.Alert](m.DB).Find(ctx)
}
func (m *Manager) RegisterService(service MessageSender) {
log.Debug("Registering alert service: %s", service.GetServiceName())
m.services = append(m.services, service)
}
func (m *Manager) SendToAll(ctx context.Context, msg Message) error {
func (s *Service) SendToAll(ctx context.Context, msg Message) error {
var errors []error
for _, service := range m.services {
for _, service := range s.services {
if !service.IsEnabled() {
continue
}
@@ -135,8 +107,20 @@ func (m *Manager) SendToAll(ctx context.Context, msg Message) error {
return nil
}
func (m *Manager) SendTest(ctx context.Context, alertType, name, webhook string) error {
err := m.SaveAlert(database.Alert{
func (s *Service) RemoveAlert(alertType string) error {
err := s.repository.RemoveAlert(alertType)
if err != nil {
log.Error("Failed to remove alert: %v", err)
return err
}
s.Load()
log.Info("Alert removed for type: %s", alertType)
return nil
}
func (s *Service) SendTest(ctx context.Context, alertType, name, webhook string) error {
err := s.SaveAlert(database.Alert{
Type: alertType,
Enabled: true,
Name: name,
@@ -147,7 +131,7 @@ func (m *Manager) SendTest(ctx context.Context, alertType, name, webhook string)
return err
}
for _, service := range m.services {
for _, service := range s.services {
if service.GetServiceName() == alertType {
log.Info("Sending test alert via %s", service.GetServiceName())
err = service.SendMessage(ctx, Message{
@@ -162,7 +146,7 @@ func (m *Manager) SendTest(ctx context.Context, alertType, name, webhook string)
}
}
err = m.RemoveAlert(alertType)
err = s.RemoveAlert(alertType)
if err != nil {
log.Error("Failed to remove test alert settings: %v", err)
return err

View File

@@ -2,7 +2,7 @@ package api
import (
"context"
"goaway/backend/dns/database"
"goaway/backend/database"
"net/http"
"github.com/gin-gonic/gin"
@@ -15,14 +15,14 @@ const (
SeverityError = "error"
)
type DiscordSettings struct {
Enabled bool `json:"enabled"`
type discordSettings struct {
Name string `json:"name"`
Webhook string `json:"webhook"`
Enabled bool `json:"enabled"`
}
type AlertSettings struct {
Discord DiscordSettings `json:"discord"`
type alertSettings struct {
Discord discordSettings `json:"discord"`
}
func (api *API) registerAlertRoutes() {
@@ -33,7 +33,7 @@ func (api *API) registerAlertRoutes() {
}
func (api *API) setAlert(c *gin.Context) {
var request AlertSettings
var request alertSettings
err := c.Bind(&request)
if err != nil {
log.Error("Failed to parse alert settings: %v", err)
@@ -41,7 +41,7 @@ func (api *API) setAlert(c *gin.Context) {
return
}
err = api.DNSServer.Alerts.SaveAlert(database.Alert{
err = api.DNSServer.AlertService.SaveAlert(database.Alert{
Type: "discord",
Enabled: request.Discord.Enabled,
Name: request.Discord.Name,
@@ -57,17 +57,17 @@ func (api *API) setAlert(c *gin.Context) {
}
func (api *API) getAlert(c *gin.Context) {
alerts, err := api.DNSServer.Alerts.GetAllAlerts()
alerts, err := api.DNSServer.AlertService.GetAllAlerts()
if err != nil {
log.Error("Failed to retrieve alert settings: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve alert settings"})
return
}
var response AlertSettings
var response alertSettings
for _, alert := range alerts {
if alert.Type == "discord" {
response.Discord = DiscordSettings{
response.Discord = discordSettings{
Enabled: alert.Enabled,
Name: alert.Name,
Webhook: alert.Webhook,
@@ -79,7 +79,7 @@ func (api *API) getAlert(c *gin.Context) {
}
func (api *API) testAlert(c *gin.Context) {
var request AlertSettings
var request alertSettings
err := c.Bind(&request)
if err != nil {
log.Error("Failed to parse alert settings: %v", err)
@@ -87,7 +87,7 @@ func (api *API) testAlert(c *gin.Context) {
return
}
err = api.DNSServer.Alerts.SendTest(
err = api.DNSServer.AlertService.SendTest(
context.Background(),
"discord",
request.Discord.Name,

View File

@@ -5,6 +5,18 @@ import (
"embed"
"encoding/base64"
"fmt"
"goaway/backend/api/key"
"goaway/backend/api/ratelimit"
"goaway/backend/blacklist"
"goaway/backend/dns/server"
"goaway/backend/logging"
"goaway/backend/notification"
"goaway/backend/prefetch"
"goaway/backend/request"
"goaway/backend/resolution"
"goaway/backend/settings"
"goaway/backend/user"
"goaway/backend/whitelist"
"io/fs"
"mime"
"net"
@@ -18,65 +30,53 @@ import (
"github.com/gin-contrib/gzip"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"goaway/backend/api/key"
"goaway/backend/api/ratelimit"
"goaway/backend/api/user"
"goaway/backend/dns/database"
"goaway/backend/dns/lists"
"goaway/backend/dns/server"
"goaway/backend/dns/server/prefetch"
"goaway/backend/logging"
notification "goaway/backend/notifications"
"goaway/backend/settings"
"gorm.io/gorm"
)
var log = logging.GetLogger()
const (
maxRetries = 10
retryDelay = 10 * time.Second
)
type API struct {
Authentication bool
Config *settings.Config
DNSPort int
Version string
Commit string
Date string
router *gin.Engine
routes *gin.RouterGroup
DNSServer *server.DNSServer
DBManager *database.DatabaseManager
Blacklist *lists.Blacklist
Whitelist *lists.Whitelist
KeyManager *key.ApiKeyManager
PrefetchedDomainsManager *prefetch.Manager
Notifications *notification.Manager
WSQueries *websocket.Conn
DNS *server.DNSServer
RateLimiter *ratelimit.RateLimiter
DBConn *gorm.DB
WSCommunication *websocket.Conn
WSQueries *websocket.Conn
router *gin.Engine
routes *gin.RouterGroup
Config *settings.Config
DNSServer *server.DNSServer
Version string
Date string
Commit string
DNSPort int
Authentication bool
RateLimiter *ratelimit.RateLimiter
RequestService *request.Service
UserService *user.Service
KeyService *key.Service
PrefetchService *prefetch.Service
ResolutionService *resolution.Service
NotificationService *notification.Service
BlacklistService *blacklist.Service
WhitelistService *whitelist.Service
}
func (api *API) Start(content embed.FS, errorChannel chan struct{}) {
api.initializeRouter()
api.configureCORS()
api.KeyManager = key.NewApiKeyManager(api.DBManager)
api.RateLimiter = ratelimit.NewRateLimiter(
api.Config.API.RateLimiterConfig.Enabled,
api.Config.API.RateLimiterConfig.MaxTries,
api.Config.API.RateLimiterConfig.Window,
)
api.setupRoutes()
api.RateLimiter = ratelimit.NewRateLimiter(
api.Config.API.RateLimit.Enabled,
api.Config.API.RateLimit.MaxTries,
api.Config.API.RateLimit.Window,
)
if api.Config.Dashboard {
api.ServeEmbeddedContent(content)
if api.Config.Misc.Dashboard {
api.serveEmbeddedContent(content)
}
api.startServer(errorChannel)
@@ -104,7 +104,7 @@ func (api *API) configureCORS() {
}
)
if api.Config.Dashboard {
if api.Config.Misc.Dashboard {
corsConfig.AllowOrigins = append(corsConfig.AllowOrigins, "*")
} else {
log.Warning("Dashboard UI is disabled")
@@ -134,23 +134,19 @@ func (api *API) setupRoutes() {
func (api *API) setupAuthAndMiddleware() {
if api.Authentication {
api.SetupAuth()
api.setupAuth()
api.routes.Use(api.authMiddleware())
} else {
log.Warning("Authentication is disabled.")
log.Warning("Dashboard authentication is disabled.")
}
}
func (api *API) SetupAuth() {
newUser := &user.User{Username: "admin"}
if newUser.Exists(api.DBManager.Conn) {
func (api *API) setupAuth() {
if api.UserService.Exists("admin") {
return
}
password := api.getOrGeneratePassword()
newUser.Password = password
if err := newUser.Create(api.DBManager.Conn); err != nil {
if err := api.UserService.CreateUser("admin", api.getOrGeneratePassword()); err != nil {
log.Error("Unable to create new user: %v", err)
}
}
@@ -204,7 +200,7 @@ func (api *API) startServer(errorChannel chan struct{}) {
}
}
func (api *API) ServeEmbeddedContent(content embed.FS) {
func (api *API) serveEmbeddedContent(content embed.FS) {
ipAddress, err := GetServerIP()
if err != nil {
log.Error("Error getting IP address: %v", err)
@@ -291,6 +287,7 @@ func injectServerConfig(htmlContent, serverIP string, port int) string {
)
}
// GetServerIP retrieves the first non-loopback IPv4 address of the server.
func GetServerIP() (string, error) {
addrs, err := net.InterfaceAddrs()
if err != nil {

View File

@@ -11,7 +11,7 @@ func (api *API) registerAuditRoutes() {
}
func (api *API) getAudits(c *gin.Context) {
audits, err := api.DNSServer.Audits.ReadAudits()
audits, err := api.DNSServer.AuditService.ReadAudits()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return

View File

@@ -3,18 +3,14 @@ package api
import (
"context"
"encoding/json"
"errors"
"fmt"
"goaway/backend/alert"
"goaway/backend/api/user"
"goaway/backend/audit"
"goaway/backend/dns/database"
"goaway/backend/user"
"io"
"net/http"
"github.com/gin-gonic/gin"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
func (api *API) registerAuthRoutes() {
@@ -27,11 +23,6 @@ func (api *API) registerAuthRoutes() {
api.routes.GET("/deleteApiKey", api.deleteAPIKey)
}
func (api *API) validateCredentials(username, password string) bool {
existingUser := &user.User{Username: username, Password: password}
return existingUser.Authenticate(api.DBManager.Conn)
}
func (api *API) handleLogin(c *gin.Context) {
allowed, timeUntilReset := api.RateLimiter.CheckLimit(c.ClientIP())
if !allowed {
@@ -42,21 +33,21 @@ func (api *API) handleLogin(c *gin.Context) {
return
}
var creds user.Credentials
if err := c.BindJSON(&creds); err != nil {
var loginUser user.User
if err := c.BindJSON(&loginUser); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
return
}
if err := creds.Validate(); err != nil {
if err := api.UserService.ValidateCredentials(loginUser); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid input"})
return
}
if api.authenticateUser(creds.Username, creds.Password) {
token, err := generateToken(creds.Username)
if api.UserService.Authenticate(loginUser.Username, loginUser.Password) {
token, err := generateToken(loginUser.Username)
if err != nil {
log.Info("Token generation failed for user %s: %v", creds.Username, err)
log.Info("Token generation failed for user %s: %v", loginUser.Username, err)
c.JSON(http.StatusInternalServerError, gin.H{
"error": "Authentication service temporarily unavailable",
})
@@ -75,27 +66,6 @@ func (api *API) handleLogin(c *gin.Context) {
}
}
func (api *API) authenticateUser(username, password string) bool {
var user database.User
if err := api.DBManager.Conn.Where("username = ?", username).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Info("Authentication attempt for non-existent or invalid credentials")
} else {
log.Warning("Database error during authentication: %v", err)
}
return false
}
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
log.Info("Password comparison failed for user: %s", username)
return false
}
log.Info("Successful authentication for user: %s", username)
return true
}
func (api *API) getAuthentication(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"enabled": api.Authentication})
}
@@ -112,24 +82,23 @@ func (api *API) updatePassword(c *gin.Context) {
return
}
if !api.validateCredentials("admin", newCredentials.CurrentPassword) {
if !api.UserService.Authenticate("admin", newCredentials.CurrentPassword) {
c.JSON(http.StatusBadRequest, gin.H{"error": "Current password is not valid"})
return
}
existingUser := user.User{Username: "admin", Password: newCredentials.NewPassword}
if err := existingUser.UpdatePassword(api.DBManager.Conn); err != nil {
if err := api.UserService.UpdatePassword("admin", newCredentials.NewPassword); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Unable to update password"})
return
}
logMsg := fmt.Sprintf("Password changed for user '%s'", existingUser.Username)
api.DNSServer.Audits.CreateAudit(&audit.Entry{
logMsg := "Password changed for user 'admin'"
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
Topic: audit.TopicUser,
Message: logMsg,
})
go func() {
_ = api.DNSServer.Alerts.SendToAll(context.Background(), alert.Message{
_ = api.DNSServer.AlertService.SendToAll(context.Background(), alert.Message{
Title: "System",
Content: logMsg,
Severity: SeverityWarning,
@@ -141,7 +110,7 @@ func (api *API) updatePassword(c *gin.Context) {
}
func (api *API) createAPIKey(c *gin.Context) {
type NewApiKeyName struct {
type NewAPIKeyName struct {
Name string `json:"name"`
}
@@ -152,21 +121,21 @@ func (api *API) createAPIKey(c *gin.Context) {
return
}
var request NewApiKeyName
var request NewAPIKeyName
if err := json.Unmarshal(body, &request); err != nil {
log.Error("Failed to parse JSON: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid JSON format"})
return
}
apiKey, err := api.KeyManager.CreateKey(request.Name)
apiKey, err := api.KeyService.CreateKey(request.Name)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
go func() {
_ = api.DNSServer.Alerts.SendToAll(context.Background(), alert.Message{
_ = api.DNSServer.AlertService.SendToAll(context.Background(), alert.Message{
Title: "System",
Content: fmt.Sprintf("New API key created with the name '%s'", request.Name),
Severity: SeverityWarning,
@@ -177,7 +146,7 @@ func (api *API) createAPIKey(c *gin.Context) {
}
func (api *API) getAPIKeys(c *gin.Context) {
apiKeys, err := api.KeyManager.GetAllKeys()
apiKeys, err := api.KeyService.GetAllKeys()
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
@@ -188,7 +157,7 @@ func (api *API) getAPIKeys(c *gin.Context) {
func (api *API) deleteAPIKey(c *gin.Context) {
keyName := c.Query("name")
err := api.KeyManager.DeleteKey(keyName)
err := api.KeyService.DeleteKey(keyName)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return

View File

@@ -1,11 +1,11 @@
package api
import (
"context"
"encoding/json"
"fmt"
"goaway/backend/audit"
"goaway/backend/dns/database"
"goaway/backend/dns/server/prefetch"
"goaway/backend/database"
"io"
"net/http"
"strconv"
@@ -44,13 +44,13 @@ func (api *API) createPrefetchedDomain(c *gin.Context) {
return
}
err = api.PrefetchedDomainsManager.AddPrefetchedDomain(prefetchedDomain.Domain, prefetchedDomain.Refresh, prefetchedDomain.QType)
err = api.PrefetchService.AddPrefetchedDomain(prefetchedDomain.Domain, prefetchedDomain.Refresh, prefetchedDomain.QType)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
api.DNSServer.Audits.CreateAudit(&audit.Entry{
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
Topic: audit.TopicPrefetch,
Message: fmt.Sprintf("Added new prefetch '%s'", prefetchedDomain.Domain),
})
@@ -58,8 +58,8 @@ func (api *API) createPrefetchedDomain(c *gin.Context) {
}
func (api *API) fetchPrefetchedDomains(c *gin.Context) {
prefetchedDomains := make([]prefetch.PrefetchedDomain, 0)
for _, b := range api.PrefetchedDomainsManager.Domains {
prefetchedDomains := make([]database.Prefetch, 0)
for _, b := range api.PrefetchService.Domains {
prefetchedDomains = append(prefetchedDomains, b)
}
c.JSON(http.StatusOK, prefetchedDomains)
@@ -72,9 +72,9 @@ func (api *API) removeDomainFromCustom(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "Empty domain name"})
}
err := api.Blacklist.RemoveCustomDomain(domain)
err := api.BlacklistService.RemoveCustomDomain(context.Background(), domain)
if err != nil {
log.Debug("Error occured while removing domain from custom list: %v", err)
log.Debug("Error occurred while removing domain from custom list: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to update custom blocklist."})
return
}
@@ -98,7 +98,7 @@ func (api *API) getBlacklistedDomains(c *gin.Context) {
pageSizeInt = 10
}
domains, total, err := api.Blacklist.LoadPaginatedBlacklist(pageInt, pageSizeInt, search)
domains, total, err := api.BlacklistService.LoadPaginatedBlacklist(context.Background(), pageInt, pageSizeInt, search)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -113,8 +113,8 @@ func (api *API) getBlacklistedDomains(c *gin.Context) {
}
func (api *API) getTopBlockedDomains(c *gin.Context) {
_, blocked, _ := api.Blacklist.GetAllowedAndBlocked()
topBlockedDomains, err := database.GetTopBlockedDomains(api.DBManager.Conn, blocked)
_, blocked, _ := api.BlacklistService.GetAllowedAndBlocked(context.Background())
topBlockedDomains, err := api.RequestService.GetTopBlockedDomains(blocked)
if err != nil {
log.Error("%v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -131,7 +131,7 @@ func (api *API) getDomainsForList(c *gin.Context) {
return
}
domains, err := api.Blacklist.GetDomainsForList(list)
domains, _, err := api.BlacklistService.FetchDBHostsList(context.Background(), list)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -143,19 +143,19 @@ func (api *API) getDomainsForList(c *gin.Context) {
func (api *API) deletePrefetchedDomain(c *gin.Context) {
domainPrefetchToDelete := c.Query("domain")
domain := api.PrefetchedDomainsManager.Domains[domainPrefetchToDelete]
if (domain == prefetch.PrefetchedDomain{}) {
domain := api.PrefetchService.Domains[domainPrefetchToDelete]
if (domain == database.Prefetch{}) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not exist", domainPrefetchToDelete)})
return
}
err := api.PrefetchedDomainsManager.RemovePrefetchedDomain(domainPrefetchToDelete)
err := api.PrefetchService.RemovePrefetchedDomain(domainPrefetchToDelete)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
api.DNSServer.Audits.CreateAudit(&audit.Entry{
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
Topic: audit.TopicPrefetch,
Message: fmt.Sprintf("Removed prefetched domain '%s'", domainPrefetchToDelete),
})

View File

@@ -1,7 +1,6 @@
package api
import (
"goaway/backend/dns/database"
"net/http"
"github.com/gin-gonic/gin"
@@ -14,7 +13,7 @@ func (api *API) registerClientRoutes() {
}
func (api *API) getClients(c *gin.Context) {
uniqueClients, err := database.FetchAllClients(api.DBManager.Conn)
uniqueClients, err := api.RequestService.FetchAllClients()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -37,7 +36,7 @@ func (api *API) getClients(c *gin.Context) {
func (api *API) getClientDetails(c *gin.Context) {
clientIP := c.DefaultQuery("clientIP", "")
clientRequestDetails, mostQueriedDomain, domainQueryCounts, err := database.GetClientDetailsWithDomains(api.DBManager.Conn, clientIP)
clientRequestDetails, mostQueriedDomain, domainQueryCounts, err := api.RequestService.GetClientDetailsWithDomains(clientIP)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -57,7 +56,7 @@ func (api *API) getClientDetails(c *gin.Context) {
}
func (api *API) getTopClients(c *gin.Context) {
topClients, err := database.GetTopClients(api.DBManager.Conn)
topClients, err := api.RequestService.GetTopClients()
if err != nil {
log.Error("%v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})

View File

@@ -1,9 +1,10 @@
package api
import (
"context"
"goaway/backend/api/models"
"goaway/backend/audit"
"goaway/backend/dns/database"
"goaway/backend/database"
"goaway/backend/dns/server"
model "goaway/backend/dns/server/models"
"goaway/backend/settings"
@@ -18,7 +19,7 @@ import (
)
func (api *API) registerDNSRoutes() {
api.setupWSLiveQueries(api.PrefetchedDomainsManager.DNS)
api.setupWSLiveQueries(api.DNS)
api.routes.POST("/pause", api.pauseBlocking)
api.routes.GET("/pause", api.getBlocking)
@@ -44,52 +45,63 @@ func (api *API) pauseBlocking(c *gin.Context) {
return
}
api.Config.DNS.Status = settings.Status{
Paused: true,
PausedAt: time.Now(),
PauseTime: blockTime.Time,
now := time.Now()
if blockTime.Time <= 0 {
c.JSON(http.StatusBadRequest, gin.H{
"error": "Time must be greater than 0",
})
return
}
duration := time.Duration(blockTime.Time) * time.Second
pauseTime := now.Add(duration)
api.Config.DNS.Status.Paused = true
api.Config.DNS.Status.PausedAt = now
api.Config.DNS.Status.PauseTime = pauseTime
log.Info("DNS blocking paused for %d seconds", blockTime.Time)
c.Status(http.StatusOK)
}
func (api *API) getBlocking(c *gin.Context) {
if api.Config.DNS.Status.Paused {
elapsed := time.Since(api.Config.DNS.Status.PausedAt).Seconds()
remainingTime := api.Config.DNS.Status.PauseTime - int(elapsed)
now := time.Now()
remainingTime := api.Config.DNS.Status.PauseTime.Sub(now)
if remainingTime <= 0 {
api.Config.DNS.Status.Paused = false
c.JSON(http.StatusOK, gin.H{"paused": false})
return
} else {
c.JSON(http.StatusOK, gin.H{"paused": true, "timeLeft": remainingTime})
secondsLeft := int(remainingTime.Seconds())
c.JSON(http.StatusOK, gin.H{"paused": true, "timeLeft": secondsLeft})
return
}
}
if !api.Config.DNS.Status.Paused {
c.JSON(http.StatusOK, gin.H{"paused": false})
}
c.JSON(http.StatusOK, gin.H{"paused": false})
}
func (api *API) getQueries(c *gin.Context) {
query := parseQueryParams(c)
type result struct {
err error
queries []model.RequestLogEntry
total int
err error
}
queryCh := make(chan result, 1)
countCh := make(chan result, 1)
go func() {
queries, err := database.FetchQueries(api.DBManager.Conn, query)
queries, err := api.RequestService.FetchQueries(query)
queryCh <- result{queries: queries, err: err}
}()
go func() {
total, err := database.CountQueries(api.DBManager.Conn, query.Search)
total, err := api.RequestService.CountQueries(query.Search)
countCh <- result{total: total, err: err}
}()
@@ -169,7 +181,7 @@ func (api *API) getQueryTimestamps(c *gin.Context) {
return
}
timestamps, err := database.GetRequestSummaryByInterval(interval, api.DBManager.Conn)
timestamps, err := api.RequestService.GetRequestSummaryByInterval(interval)
if err != nil {
log.Error("Failed to get request summary: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"})
@@ -187,7 +199,7 @@ func (api *API) getResponseSizeTimestamps(c *gin.Context) {
return
}
timestamps, err := database.GetResponseSizeSummaryByInterval(interval, api.DBManager.Conn)
timestamps, err := api.RequestService.GetResponseSizeSummaryByInterval(interval)
if err != nil {
log.Error("Error fetching response size timestamps: %v", err.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -198,7 +210,7 @@ func (api *API) getResponseSizeTimestamps(c *gin.Context) {
}
func (api *API) getQueryTypes(c *gin.Context) {
queries, err := database.GetUniqueQueryTypes(api.DBManager.Conn)
queries, err := api.RequestService.GetUniqueQueryTypes()
if err != nil {
log.Error("%v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -209,19 +221,19 @@ func (api *API) getQueryTypes(c *gin.Context) {
}
func (api *API) clearQueries(c *gin.Context) {
if err := api.DBManager.Conn.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&database.RequestLog{}).Error; err != nil {
if err := api.DBConn.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&database.RequestLog{}).Error; err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Could not clear query logs", "reason": err.Error()})
return
}
if err := api.DBManager.Conn.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&database.RequestLogIP{}).Error; err != nil {
if err := api.DBConn.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&database.RequestLogIP{}).Error; err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Could not clear IP logs", "reason": err.Error()})
return
}
api.Blacklist.Vacuum()
api.BlacklistService.Vacuum(context.Background())
api.DNSServer.Audits.CreateAudit(&audit.Entry{
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
Topic: audit.TopicLogs,
Message: "Logs were cleared",
})

View File

@@ -1,177 +0,0 @@
package key
import (
"crypto/rand"
"encoding/hex"
"fmt"
"goaway/backend/dns/database"
"goaway/backend/logging"
"sort"
"sync"
"time"
)
type ApiKey struct {
Name string `json:"name"`
Key string `json:"key"`
CreatedAt time.Time `json:"createdAt"`
}
type ApiKeyManager struct {
dbManager *database.DatabaseManager
keyCache map[string]ApiKey
cacheMu sync.RWMutex
cacheTime time.Time
cacheTTL time.Duration
}
var log = logging.GetLogger()
func NewApiKeyManager(dbManager *database.DatabaseManager) *ApiKeyManager {
return &ApiKeyManager{
dbManager: dbManager,
keyCache: make(map[string]ApiKey),
cacheTTL: 1 * time.Hour,
}
}
func (m *ApiKeyManager) refreshCache() error {
m.cacheMu.RLock()
if time.Since(m.cacheTime) < m.cacheTTL && len(m.keyCache) > 0 {
m.cacheMu.RUnlock()
return nil
}
m.cacheMu.RUnlock()
m.cacheMu.Lock()
defer m.cacheMu.Unlock()
if time.Since(m.cacheTime) < m.cacheTTL && len(m.keyCache) > 0 {
return nil
}
var apiKeys []database.APIKey
result := m.dbManager.Conn.Find(&apiKeys)
if result.Error != nil {
return result.Error
}
newCache := make(map[string]ApiKey)
for _, apiKey := range apiKeys {
newCache[apiKey.Key] = ApiKey{
Name: apiKey.Name,
Key: apiKey.Key,
CreatedAt: apiKey.CreatedAt,
}
}
m.keyCache = newCache
m.cacheTime = time.Now()
return nil
}
func (m *ApiKeyManager) VerifyKey(apiKey string) bool {
if err := m.refreshCache(); err != nil {
log.Warning("Failed to refresh API key cache: %v", err)
var count int64
result := m.dbManager.Conn.Model(&database.APIKey{}).Where("key = ?", apiKey).Count(&count)
if result.Error != nil {
log.Warning("Failed to verify API key in database: %v", result.Error)
return false
}
return count > 0
}
m.cacheMu.RLock()
defer m.cacheMu.RUnlock()
for _, value := range m.keyCache {
if value.Key == apiKey {
return true
}
}
return false
}
func generateKey() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
func (m *ApiKeyManager) CreateKey(name string) (string, error) {
apiKey, err := generateKey()
if err != nil {
return "", err
}
newAPIKey := database.APIKey{
Name: name,
Key: apiKey,
CreatedAt: time.Now(),
}
result := m.dbManager.Conn.Create(&newAPIKey)
if result.Error != nil {
return "", fmt.Errorf("key with name '%s' already exists", name)
}
m.cacheMu.Lock()
m.keyCache[apiKey] = ApiKey{
Name: name,
Key: apiKey,
CreatedAt: newAPIKey.CreatedAt,
}
m.cacheMu.Unlock()
log.Info("Created new API key with name: %s", name)
return apiKey, nil
}
func (m *ApiKeyManager) DeleteKey(keyName string) error {
result := m.dbManager.Conn.Where("name = ?", keyName).Delete(&database.APIKey{})
if result.Error != nil {
return result.Error
}
m.cacheMu.Lock()
for key, value := range m.keyCache {
if value.Name == keyName {
delete(m.keyCache, key)
break
}
}
m.cacheMu.Unlock()
if err := m.refreshCache(); err != nil {
log.Warning("%v", err)
}
return nil
}
func (m *ApiKeyManager) GetAllKeys() ([]ApiKey, error) {
if err := m.refreshCache(); err != nil {
return nil, err
}
m.cacheMu.RLock()
defer m.cacheMu.RUnlock()
keys := make([]ApiKey, 0, len(m.keyCache))
for _, apiKey := range m.keyCache {
keyCopy := apiKey
keyCopy.Key = "redacted"
keys = append(keys, keyCopy)
}
sort.Slice(keys, func(i, j int) bool {
return keys[j].CreatedAt.Before(keys[i].CreatedAt)
})
return keys, nil
}

View File

@@ -0,0 +1,9 @@
package key
import "time"
type APIKey struct {
CreatedAt time.Time `json:"createdAt"`
Name string `json:"name"`
Key string `json:"key"`
}

View File

@@ -0,0 +1,55 @@
package key
import (
"goaway/backend/database"
"gorm.io/gorm"
)
// Repository handles all database operations for API keys
type Repository interface {
Create(apiKey *database.APIKey) error
FindByKey(key string) (*database.APIKey, error)
FindAll() ([]database.APIKey, error)
DeleteByName(name string) error
CountByKey(key string) (int64, error)
}
type repository struct {
db *gorm.DB
}
func NewRepository(db *gorm.DB) Repository {
return &repository{
db: db,
}
}
func (r *repository) Create(apiKey *database.APIKey) error {
return r.db.Create(apiKey).Error
}
func (r *repository) FindByKey(key string) (*database.APIKey, error) {
var apiKey database.APIKey
err := r.db.Where("key = ?", key).First(&apiKey).Error
if err != nil {
return nil, err
}
return &apiKey, nil
}
func (r *repository) FindAll() ([]database.APIKey, error) {
var apiKeys []database.APIKey
err := r.db.Find(&apiKeys).Error
return apiKeys, err
}
func (r *repository) DeleteByName(name string) error {
return r.db.Where("name = ?", name).Delete(&database.APIKey{}).Error
}
func (r *repository) CountByKey(key string) (int64, error) {
var count int64
err := r.db.Model(&database.APIKey{}).Where("key = ?", key).Count(&count).Error
return count, err
}

176
backend/api/key/service.go Normal file
View File

@@ -0,0 +1,176 @@
package key
import (
"crypto/rand"
"encoding/hex"
"fmt"
"goaway/backend/database"
"goaway/backend/logging"
"sort"
"sync"
"time"
)
// Service handles business logic for API keys
type Service struct {
repository Repository
cacheTime time.Time
cacheTTL time.Duration
cacheMu sync.RWMutex
keyCache map[string]APIKey
}
var log = logging.GetLogger()
func NewService(repo Repository) *Service {
return &Service{
repository: repo,
keyCache: make(map[string]APIKey),
cacheTTL: 1 * time.Hour,
}
}
func (s *Service) VerifyKey(apiKey string) bool {
if err := s.refreshCache(); err != nil {
log.Warning("Failed to refresh API key cache: %v", err)
count, err := s.repository.CountByKey(apiKey)
if err != nil {
log.Warning("Failed to verify API key in database: %v", err)
return false
}
return count > 0
}
s.cacheMu.RLock()
defer s.cacheMu.RUnlock()
for _, value := range s.keyCache {
if value.Key == apiKey {
return true
}
}
return false
}
// CreateKey generates and stores a new API key
func (s *Service) CreateKey(name string) (string, error) {
apiKey, err := generateKey()
if err != nil {
return "", err
}
newAPIKey := database.APIKey{
Name: name,
Key: apiKey,
CreatedAt: time.Now(),
}
if err := s.repository.Create(&newAPIKey); err != nil {
return "", fmt.Errorf("key with name '%s' already exists", name)
}
s.cacheMu.Lock()
s.keyCache[apiKey] = APIKey{
Name: name,
Key: apiKey,
CreatedAt: newAPIKey.CreatedAt,
}
s.cacheMu.Unlock()
log.Info("Created new API key with name: %s", name)
return apiKey, nil
}
// DeleteKey removes an API key by name
func (s *Service) DeleteKey(keyName string) error {
if err := s.repository.DeleteByName(keyName); err != nil {
return err
}
s.cacheMu.Lock()
for key, value := range s.keyCache {
if value.Name == keyName {
delete(s.keyCache, key)
break
}
}
s.cacheMu.Unlock()
if err := s.refreshCache(); err != nil {
log.Warning("%v", err)
}
return nil
}
// GetAllKeys returns all API keys with redacted key values
func (s *Service) GetAllKeys() ([]APIKey, error) {
if err := s.refreshCache(); err != nil {
return nil, err
}
s.cacheMu.RLock()
defer s.cacheMu.RUnlock()
keys := make([]APIKey, 0, len(s.keyCache))
for _, apiKey := range s.keyCache {
keyCopy := apiKey
keyCopy.Key = "redacted"
keys = append(keys, keyCopy)
}
sort.Slice(keys, func(i, j int) bool {
return keys[j].CreatedAt.Before(keys[i].CreatedAt)
})
return keys, nil
}
// refreshCache updates the in-memory cache from the database
func (s *Service) refreshCache() error {
s.cacheMu.RLock()
if time.Since(s.cacheTime) < s.cacheTTL && len(s.keyCache) > 0 {
s.cacheMu.RUnlock()
return nil
}
s.cacheMu.RUnlock()
s.cacheMu.Lock()
defer s.cacheMu.Unlock()
// Double-check after acquiring write lock
if time.Since(s.cacheTime) < s.cacheTTL && len(s.keyCache) > 0 {
return nil
}
apiKeys, err := s.repository.FindAll()
if err != nil {
return err
}
newCache := make(map[string]APIKey)
for _, apiKey := range apiKeys {
newCache[apiKey.Key] = APIKey{
Name: apiKey.Name,
Key: apiKey.Key,
CreatedAt: apiKey.CreatedAt,
}
}
s.keyCache = newCache
s.cacheTime = time.Now()
return nil
}
// generateKey creates a random hex-encoded API key
func generateKey() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}

View File

@@ -48,13 +48,13 @@ func (api *API) updateCustom(c *gin.Context) {
return
}
err = api.Blacklist.AddCustomDomains(request.Domains)
err = api.BlacklistService.AddCustomDomains(context.Background(), request.Domains)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to update custom blocklist."})
return
}
api.DNSServer.Audits.CreateAudit(&audit.Entry{
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
Topic: audit.TopicList,
Message: fmt.Sprintf("Added %d domains to custom blacklist", len(request.Domains)),
})
@@ -62,7 +62,7 @@ func (api *API) updateCustom(c *gin.Context) {
}
func (api *API) getLists(c *gin.Context) {
lists, err := api.Blacklist.GetAllListStatistics()
lists, err := api.BlacklistService.GetAllListStatistics(context.Background())
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -85,46 +85,46 @@ func (api *API) addList(c *gin.Context) {
return
}
err = api.ValidateURLAndName(newList.URL, newList.Name, c)
err = api.validateURLAndName(newList.URL, newList.Name)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err = api.Blacklist.FetchAndLoadHosts(newList.URL, newList.Name); err != nil {
if err = api.BlacklistService.FetchAndLoadHosts(context.Background(), newList.URL, newList.Name); err != nil {
log.Error("Failed to fetch and load hosts: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if _, err := api.Blacklist.PopulateBlocklistCache(); err != nil {
if err := api.BlacklistService.PopulateCache(context.Background()); err != nil {
log.Error("Failed to populate blocklist cache: %v", err)
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
if err := api.Blacklist.AddSource(newList.Name, newList.URL); err != nil {
if err := api.BlacklistService.AddSource(context.Background(), newList.Name, newList.URL); err != nil {
log.Error("Failed to add source: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if !newList.Active {
if err := api.Blacklist.ToggleBlocklistStatus(newList.Name); err != nil {
if err := api.BlacklistService.ToggleBlocklistStatus(context.Background(), newList.Name); err != nil {
log.Error("Failed to toggle blocklist status: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to toggle status for " + newList.Name})
return
}
}
_, addedList, err := api.Blacklist.GetListStatistics(newList.Name)
_, addedList, err := api.BlacklistService.GetListStatistics(context.Background(), newList.Name)
if err != nil {
log.Error("Failed to get list statistics: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get list statistics"})
return
}
api.DNSServer.Audits.CreateAudit(&audit.Entry{
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
Topic: audit.TopicList,
Message: fmt.Sprintf("New blacklist with name '%s' was added", addedList.Name),
})
@@ -150,31 +150,31 @@ func (api *API) addLists(c *gin.Context) {
var addedList []NewList
var ignoredList []NewList
for _, list := range payload.Lists {
if api.Blacklist.URLExists(list.URL) {
if api.BlacklistService.URLExists(list.URL) {
ignoredList = append(ignoredList, list)
continue
}
if err := api.Blacklist.FetchAndLoadHosts(list.URL, list.Name); err != nil {
if err := api.BlacklistService.FetchAndLoadHosts(context.Background(), list.URL, list.Name); err != nil {
log.Error("Failed to fetch and load hosts: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if _, err := api.Blacklist.PopulateBlocklistCache(); err != nil {
if err := api.BlacklistService.PopulateCache(context.Background()); err != nil {
log.Error("Failed to populate blocklist cache: %v", err)
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
if err := api.Blacklist.AddSource(list.Name, list.URL); err != nil {
if err := api.BlacklistService.AddSource(context.Background(), list.Name, list.URL); err != nil {
log.Error("Failed to add source: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if !list.Active {
if err := api.Blacklist.ToggleBlocklistStatus(list.Name); err != nil {
if err := api.BlacklistService.ToggleBlocklistStatus(context.Background(), list.Name); err != nil {
log.Error("Failed to toggle blocklist status: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to toggle status for " + list.Name})
return
@@ -185,7 +185,7 @@ func (api *API) addLists(c *gin.Context) {
}
if len(addedList) > 0 {
api.DNSServer.Audits.CreateAudit(&audit.Entry{
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
Topic: audit.TopicList,
Message: fmt.Sprintf("Added %d new blacklists in bulk", len(addedList)),
})
@@ -197,19 +197,19 @@ func (api *API) addLists(c *gin.Context) {
func (api *API) updateListName(c *gin.Context) {
oldName := c.Query("old")
newName := c.Query("new")
url := c.Query("url")
listURL := c.Query("url")
if oldName == "" || newName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "New and old names are required"})
return
}
if !api.Blacklist.NameExists(oldName, url) {
if !api.BlacklistService.NameExists(oldName, listURL) {
c.JSON(http.StatusBadRequest, gin.H{"error": "List with that name and url combination does not exist"})
return
}
err := api.Blacklist.UpdateSourceName(oldName, newName, url)
err := api.BlacklistService.UpdateSourceName(context.Background(), oldName, newName, listURL)
if err != nil {
log.Warning("%s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -221,19 +221,19 @@ func (api *API) updateListName(c *gin.Context) {
func (api *API) fetchUpdatedList(c *gin.Context) {
name := c.Query("name")
url := c.Query("url")
listURL := c.Query("url")
if !api.Blacklist.NameExists(name, url) {
if !api.BlacklistService.NameExists(name, listURL) {
c.JSON(http.StatusBadRequest, gin.H{"error": "List with that name and url combination does not exist"})
return
}
if name == "" || url == "" {
if name == "" || listURL == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing 'name' or 'url' query parameter"})
return
}
availableUpdate, err := api.Blacklist.CheckIfUpdateAvailable(url, name)
availableUpdate, err := api.BlacklistService.CheckIfUpdateAvailable(context.Background(), listURL, name)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -249,33 +249,33 @@ func (api *API) fetchUpdatedList(c *gin.Context) {
func (api *API) runUpdateList(c *gin.Context) {
name := c.Query("name")
url := c.Query("url")
listURL := c.Query("url")
if !api.Blacklist.NameExists(name, url) {
if !api.BlacklistService.NameExists(name, listURL) {
c.JSON(http.StatusBadRequest, gin.H{"error": "List does not exist"})
return
}
if name == "" || url == "" {
if name == "" || listURL == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing 'name' or 'url' query parameter"})
return
}
err := api.Blacklist.RemoveSourceAndDomains(name, url)
err := api.BlacklistService.RemoveSourceAndDomains(context.Background(), name, listURL)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
err = api.Blacklist.FetchAndLoadHosts(url, name)
err = api.BlacklistService.FetchAndLoadHosts(context.Background(), listURL, name)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
go func() {
_ = api.DNSServer.Alerts.SendToAll(context.Background(), alert.Message{
_ = api.DNSServer.AlertService.SendToAll(context.Background(), alert.Message{
Title: "System",
Content: fmt.Sprintf("List '%s' with url '%s' was updated! ", name, url),
Content: fmt.Sprintf("List '%s' with url '%s' was updated! ", name, listURL),
Severity: SeveritySuccess,
})
}()
@@ -291,7 +291,7 @@ func (api *API) toggleBlocklist(c *gin.Context) {
return
}
err := api.Blacklist.ToggleBlocklistStatus(blocklist)
err := api.BlacklistService.ToggleBlocklistStatus(context.Background(), blocklist)
if err != nil {
log.Error("Failed to toggle blocklist status: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Failed to toggle status for %s", blocklist)})
@@ -309,9 +309,9 @@ func (api *API) handleUpdateBlockStatus(c *gin.Context) {
return
}
action := map[string]func(string) error{
"true": api.Blacklist.AddBlacklistedDomain,
"false": api.Blacklist.RemoveDomain,
action := map[string]func(context.Context, string) error{
"true": api.BlacklistService.AddBlacklistedDomain,
"false": api.BlacklistService.RemoveDomain,
}[blocked]
if action == nil {
@@ -319,7 +319,7 @@ func (api *API) handleUpdateBlockStatus(c *gin.Context) {
return
}
if err := action(domain); err != nil {
if err := action(context.Background(), domain); err != nil {
c.JSON(http.StatusOK, gin.H{"message": err.Error()})
return
}
@@ -334,42 +334,42 @@ func (api *API) handleUpdateBlockStatus(c *gin.Context) {
func (api *API) removeList(c *gin.Context) {
name := c.Query("name")
url := c.Query("url")
listURL := c.Query("url")
if !api.Blacklist.NameExists(name, url) {
if !api.BlacklistService.NameExists(name, listURL) {
c.JSON(http.StatusBadRequest, gin.H{"error": "List does not exist"})
return
}
err := api.Blacklist.RemoveSourceAndDomains(name, url)
err := api.BlacklistService.RemoveSourceAndDomains(context.Background(), name, listURL)
if err != nil {
log.Error("%v", err.Error())
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
}
if removed := api.Blacklist.RemoveSourceByNameAndURL(name, url); !removed {
log.Error("Failed to remove source with name '%s' and url '%s'", name, url)
if removed := api.BlacklistService.RemoveSourceByNameAndURL(name, listURL); !removed {
log.Error("Failed to remove source with name '%s' and url '%s'", name, listURL)
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to remove the list"})
return
}
api.DNSServer.Audits.CreateAudit(&audit.Entry{
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
Topic: audit.TopicList,
Message: fmt.Sprintf("Blacklist with name '%s' was deleted", name),
})
c.Status(http.StatusOK)
}
func (api *API) ValidateURLAndName(URL, name string, c *gin.Context) error {
if name == "" || URL == "" {
func (api *API) validateURLAndName(listURL, name string) error {
if name == "" || listURL == "" {
return fmt.Errorf("name and URL are required")
}
if _, err := url.ParseRequestURI(URL); err != nil {
if _, err := url.ParseRequestURI(listURL); err != nil {
return fmt.Errorf("invalid URL format")
}
if api.Blacklist.URLExists(URL) {
if api.BlacklistService.URLExists(listURL) {
return fmt.Errorf("list with the same URL already exists")
}

View File

@@ -11,8 +11,8 @@ import (
)
const (
TokenDuration = 5 * time.Minute
Secret = "kMNSRwKip7Yet4rb2z8"
tokenDuration = 5 * time.Minute
secret = "kMNSRwKip7Yet4rb2z8"
)
func (api *API) authMiddleware() gin.HandlerFunc {
@@ -23,7 +23,7 @@ func (api *API) authMiddleware() gin.HandlerFunc {
}
if apiKey := c.GetHeader("api-key"); apiKey != "" {
if api.KeyManager.VerifyKey(apiKey) {
if api.KeyService.VerifyKey(apiKey) {
c.Next()
return
}
@@ -62,7 +62,7 @@ func (api *API) authMiddleware() gin.HandlerFunc {
return
}
halfDurationSeconds := int64(TokenDuration.Seconds() / 2)
halfDurationSeconds := int64(tokenDuration.Seconds() / 2)
timeUntilExpiration := expiration - now
if timeUntilExpiration <= halfDurationSeconds {
@@ -85,7 +85,7 @@ func parseToken(tokenString string) (jwt.MapClaims, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return []byte(Secret), nil
return []byte(secret), nil
})
if err != nil || !token.Valid {
return nil, err
@@ -103,11 +103,11 @@ func generateToken(username string) (string, error) {
now := time.Now()
claims := jwt.MapClaims{
"username": username,
"exp": now.Add(TokenDuration).Unix(),
"exp": now.Add(tokenDuration).Unix(),
"iat": now.Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(Secret))
return token.SignedString([]byte(secret))
}
func setAuthCookie(w http.ResponseWriter, token string) {
@@ -118,7 +118,7 @@ func setAuthCookie(w http.ResponseWriter, token string) {
HttpOnly: true,
Secure: false,
SameSite: http.SameSiteLaxMode,
Expires: time.Now().Add(TokenDuration),
MaxAge: int(TokenDuration.Seconds()),
Expires: time.Now().Add(tokenDuration),
MaxAge: int(tokenDuration.Seconds()),
})
}

View File

@@ -1,11 +1,11 @@
package models
type QueryParams struct {
Page int
PageSize int
Offset int
Search string
Column string
Direction string
FilterClient string
Page int
PageSize int
Offset int
}

View File

@@ -1,9 +1,7 @@
package api
import (
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
@@ -16,7 +14,7 @@ func (api *API) registerNotificationRoutes() {
}
func (api *API) fetchNotifications(c *gin.Context) {
notifications, err := api.Notifications.ReadNotifications()
notifications, err := api.NotificationService.GetNotifications()
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
}
@@ -28,21 +26,14 @@ func (api *API) markNotificationAsRead(c *gin.Context) {
NotificationIDs []int `json:"notificationIds"`
}
notificationsRead, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Error("Failed to read request body: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
var request NotificationsRead
if err := json.Unmarshal(notificationsRead, &request); err != nil {
log.Error("Failed to parse JSON: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid JSON format"})
err := c.BindJSON(&request)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Unable to parse request"})
return
}
err = api.Notifications.MarkNotificationsAsRead(request.NotificationIDs)
err = api.NotificationService.MarkNotificationsAsRead(request.NotificationIDs)
if err != nil {
log.Warning("Unable to mark notifications as read %v", err)
c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("Unable to mark notifications as read %v", err.Error())})

View File

@@ -15,12 +15,12 @@ type RateLimiterConfig struct {
}
type RateLimiter struct {
Config RateLimiterConfig
mutex *sync.RWMutex
attempts map[string][]time.Time
Config RateLimiterConfig
}
func NewRateLimiter(enabled bool, maxTries int, window int) *RateLimiter {
func NewRateLimiter(enabled bool, maxTries, window int) *RateLimiter {
config := RateLimiterConfig{
Enabled: enabled,
MaxTries: maxTries,

View File

@@ -3,7 +3,6 @@ package api
import (
"fmt"
"goaway/backend/audit"
"goaway/backend/dns/database"
"net/http"
"github.com/gin-gonic/gin"
@@ -31,14 +30,14 @@ func (api *API) createResolution(c *gin.Context) {
return
}
err := database.CreateNewResolution(api.DBManager.Conn, newResolution.IP, newResolution.Domain)
err := api.ResolutionService.CreateResolution(newResolution.IP, newResolution.Domain)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
api.DNSServer.RemoveCachedDomain(newResolution.Domain)
api.DNSServer.Audits.CreateAudit(&audit.Entry{
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
Topic: audit.TopicResolution,
Message: fmt.Sprintf("Added new resolution '%s'", newResolution.Domain),
})
@@ -46,7 +45,7 @@ func (api *API) createResolution(c *gin.Context) {
}
func (api *API) getResolutions(c *gin.Context) {
resolutions, err := database.FetchResolutions(api.DBManager.Conn)
resolutions, err := api.ResolutionService.GetResolutions()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -59,7 +58,7 @@ func (api *API) deleteResolution(c *gin.Context) {
domain := c.Query("domain")
ip := c.Query("ip")
rowsAffected, err := database.DeleteResolution(api.DBManager.Conn, ip, domain)
rowsAffected, err := api.ResolutionService.DeleteResolution(ip, domain)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": err.Error(),
@@ -74,7 +73,7 @@ func (api *API) deleteResolution(c *gin.Context) {
api.DNSServer.RemoveCachedDomain(domain)
api.DNSServer.Audits.CreateAudit(&audit.Entry{
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
Topic: audit.TopicResolution,
Message: fmt.Sprintf("Removed resolution '%s'", domain),
})

View File

@@ -1,8 +1,8 @@
package api
import (
"context"
"fmt"
"goaway/backend/dns/database"
"goaway/backend/dns/server"
"goaway/backend/updater"
"net/http"
@@ -10,6 +10,7 @@ import (
"path/filepath"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
@@ -18,7 +19,7 @@ import (
)
func (api *API) registerServerRoutes() {
api.setupWSLiveCommunication(api.PrefetchedDomainsManager.DNS)
api.setupWSLiveCommunication(api.DNS)
api.router.GET("/api/server", api.handleServer)
api.router.GET("/api/dnsMetrics", api.handleMetrics)
@@ -47,7 +48,7 @@ func (api *API) handleServer(c *gin.Context) {
}
c.JSON(http.StatusOK, gin.H{
"portDNS": api.Config.DNS.Port,
"portDNS": api.Config.DNS.Ports.TCPUDP,
"portWebsite": api.DNSPort,
"totalMem": float64(vMem.Total) / 1024 / 1024 / 1024,
"usedMem": float64(vMem.Used) / 1024 / 1024 / 1024,
@@ -56,7 +57,7 @@ func (api *API) handleServer(c *gin.Context) {
"cpuTemp": temp,
"dbSize": dbSize,
"version": api.Version,
"inAppUpdate": api.Config.InAppUpdate,
"inAppUpdate": api.Config.Misc.InAppUpdate,
"commit": api.Commit,
"date": api.Date,
})
@@ -110,7 +111,7 @@ func getDBSizeMB() (float64, error) {
}
func (api *API) handleMetrics(c *gin.Context) {
allowed, blocked, err := api.Blacklist.GetAllowedAndBlocked()
allowed, blocked, err := api.BlacklistService.GetAllowedAndBlocked(context.Background())
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
@@ -122,14 +123,14 @@ func (api *API) handleMetrics(c *gin.Context) {
percentageBlocked = (float64(blocked) / float64(total)) * 100
}
domainsLength, _ := api.Blacklist.CountDomains()
domainsLength, _ := api.BlacklistService.CountDomains(context.Background())
c.JSON(http.StatusOK, gin.H{
"allowed": allowed,
"blocked": blocked,
"total": total,
"percentageBlocked": percentageBlocked,
"domainBlockLen": domainsLength,
"clients": database.GetDistinctRequestIP(api.DBManager.Conn),
"clients": api.RequestService.GetDistinctRequestIP(),
})
}
@@ -169,16 +170,47 @@ func (api *API) setupWSLiveCommunication(dnsServer *server.DNSServer) {
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(_ *http.Request) bool {
return true
},
}
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
return
}
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
conn.SetPongHandler(func(string) error {
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
return nil
})
api.WSCommunication = conn
if dnsServer != nil {
dnsServer.WSCommunication = conn
}
go func() {
defer func() {
_ = conn.Close()
if dnsServer != nil {
dnsServer.WSCommunicationLock.Lock()
dnsServer.WSCommunication = nil
dnsServer.WSCommunicationLock.Unlock()
}
}()
for {
_, _, err := conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Warning("Websocket closed unexpectedly: %v", err)
}
break
}
}
}()
})
}

View File

@@ -36,11 +36,11 @@ func (api *API) updateSettings(c *gin.Context) {
return
}
api.Config.UpdateSettings(updatedSettings)
settingsJson, _ := json.MarshalIndent(updatedSettings, "", " ")
log.Debug("%s", string(settingsJson))
api.Config.Update(updatedSettings)
settingsJSON, _ := json.MarshalIndent(updatedSettings, "", " ")
log.Debug("%s", string(settingsJSON))
api.DNSServer.Audits.CreateAudit(&audit.Entry{
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
Topic: audit.TopicSettings,
Message: "Settings was updated",
})
@@ -65,7 +65,7 @@ func (api *API) exportDatabase(c *gin.Context) {
_ = os.Remove(tempExport)
// Create a new connection to a temp file and vacuum into it
if err := api.DBManager.Conn.Exec(fmt.Sprintf("VACUUM INTO '%s';", tempExport)).Error; err != nil {
if err := api.DBConn.Exec(fmt.Sprintf("VACUUM INTO '%s';", tempExport)).Error; err != nil {
log.Error("Failed to write WAL to temp export: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to prepare database for export"})
return
@@ -116,7 +116,7 @@ func (api *API) exportDatabase(c *gin.Context) {
return n > 0
})
api.DNSServer.Audits.CreateAudit(&audit.Entry{
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
Topic: audit.TopicDatabase,
Message: "Database was exported",
})
@@ -125,7 +125,7 @@ func (api *API) exportDatabase(c *gin.Context) {
func validateSQLiteFile(filePath string) error {
file, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("cannot open file: %v", err)
return fmt.Errorf("cannot open file: %w", err)
}
go func() {
_ = file.Close()
@@ -133,7 +133,7 @@ func validateSQLiteFile(filePath string) error {
stat, err := file.Stat()
if err != nil {
return fmt.Errorf("cannot stat file: %v", err)
return fmt.Errorf("cannot stat file: %w", err)
}
if stat.Size() < 50 {
@@ -143,7 +143,7 @@ func validateSQLiteFile(filePath string) error {
header := make([]byte, 16)
_, err = file.Read(header)
if err != nil {
return fmt.Errorf("cannot read file header: %v", err)
return fmt.Errorf("cannot read file header: %w", err)
}
expectedHeader := "SQLite format 3\x00"
@@ -184,7 +184,7 @@ func (api *API) importDatabase(c *gin.Context) {
}
_, err = io.Copy(tempFile, file)
defer func(tempfile *os.File) {
defer func(tempFile *os.File) {
_ = tempFile.Close()
}(tempFile)
@@ -229,7 +229,7 @@ func (api *API) importDatabase(c *gin.Context) {
return
}
sqlDB, err := api.DBManager.Conn.DB()
sqlDB, err := api.DBConn.DB()
if err != nil {
log.Error("Failed to get underlying sql.DB: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to close current database"})
@@ -260,15 +260,15 @@ func (api *API) importDatabase(c *gin.Context) {
if err != nil {
log.Error("Failed to open imported database with GORM: %v", err)
_ = copyFile(backupPath, currentDBPath)
api.DBManager.Conn, _ = gorm.Open(sqlite.Open(currentDBPath), &gorm.Config{})
api.DBConn, _ = gorm.Open(sqlite.Open(currentDBPath), &gorm.Config{})
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to open imported database, restored from backup"})
return
}
api.DBManager.Conn = newDB
*api.DBConn = *newDB
log.Info("Database imported successfully from %s", header.Filename)
api.DNSServer.Audits.CreateAudit(&audit.Entry{
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
Topic: audit.TopicDatabase,
Message: "Database was imported",
})

View File

@@ -18,14 +18,14 @@ import (
probing "github.com/prometheus-community/pro-bing"
)
type PingResult struct {
Duration time.Duration
type pingResult struct {
Error error
Method string
Duration time.Duration
Successful bool
}
func (pr PingResult) String() string {
func (pr pingResult) String() string {
if !pr.Successful {
return fmt.Sprintf("Failed (%s)", pr.Method)
}
@@ -68,16 +68,16 @@ func (api *API) createUpstream(c *gin.Context) {
upstream += ":53"
}
if slices.Contains(api.Config.DNS.UpstreamDNS, upstream) {
if slices.Contains(api.Config.DNS.Upstream.Fallback, upstream) {
c.JSON(http.StatusBadRequest, gin.H{"error": "Upstream already exists"})
return
}
api.Config.DNS.UpstreamDNS = append(api.Config.DNS.UpstreamDNS, upstream)
api.Config.DNS.Upstream.Fallback = append(api.Config.DNS.Upstream.Fallback, upstream)
api.Config.Save()
log.Info("Added %s as a new upstream", upstream)
api.DNSServer.Audits.CreateAudit(&audit.Entry{
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
Topic: audit.TopicUpstream,
Message: fmt.Sprintf("Added a new upstream '%s'", request.Upstream),
})
@@ -87,9 +87,9 @@ func (api *API) createUpstream(c *gin.Context) {
func (api *API) getUpstreams(c *gin.Context) {
var (
upstreams = api.Config.DNS.UpstreamDNS
upstreams = api.Config.DNS.Upstream.Fallback
results = make([]map[string]any, len(upstreams))
preferredUpstream = api.Config.DNS.PreferredUpstream
preferredUpstream = api.Config.DNS.Upstream.Preferred
wg sync.WaitGroup
)
wg.Add(len(upstreams))
@@ -158,7 +158,7 @@ func resolveHostname(host string) string {
return "No IP found"
}
func measureDNSPing(upstream string) PingResult {
func measureDNSPing(upstream string) pingResult {
var (
testDomains = []string{"google.com", "cloudflare.com", "quad9.net"}
client = &dns.Client{
@@ -193,7 +193,7 @@ func measureDNSPing(upstream string) PingResult {
}
if successCount == 0 {
return PingResult{
return pingResult{
Duration: 0,
Error: lastError,
Method: "dns",
@@ -202,7 +202,7 @@ func measureDNSPing(upstream string) PingResult {
}
avgDuration := totalDuration / time.Duration(successCount)
return PingResult{
return pingResult{
Duration: avgDuration,
Error: nil,
Method: "dns",
@@ -210,7 +210,7 @@ func measureDNSPing(upstream string) PingResult {
}
}
func measureICMPPing(host string) PingResult {
func measureICMPPing(host string) pingResult {
icmpResult := tryICMPPing(host)
if icmpResult.Successful {
return icmpResult
@@ -220,10 +220,10 @@ func measureICMPPing(host string) PingResult {
return tcpResult
}
func tryICMPPing(host string) PingResult {
func tryICMPPing(host string) pingResult {
pinger, err := probing.NewPinger(host)
if err != nil {
return PingResult{
return pingResult{
Duration: 0,
Error: err,
Method: "icmp",
@@ -237,7 +237,7 @@ func tryICMPPing(host string) PingResult {
err = pinger.Run()
if err != nil {
return PingResult{
return pingResult{
Duration: 0,
Error: err,
Method: "icmp",
@@ -247,7 +247,7 @@ func tryICMPPing(host string) PingResult {
stats := pinger.Statistics()
if stats.PacketsRecv == 0 {
return PingResult{
return pingResult{
Duration: 0,
Error: fmt.Errorf("no packets received"),
Method: "icmp",
@@ -255,7 +255,7 @@ func tryICMPPing(host string) PingResult {
}
}
return PingResult{
return pingResult{
Duration: stats.AvgRtt,
Error: nil,
Method: "icmp",
@@ -263,12 +263,12 @@ func tryICMPPing(host string) PingResult {
}
}
func tryTCPPing(host string) PingResult {
func tryTCPPing(host string) pingResult {
start := time.Now()
conn, err := net.DialTimeout("tcp", net.JoinHostPort(host, "53"), 2*time.Second)
if err != nil {
return PingResult{
return pingResult{
Duration: 0,
Error: err,
Method: "tcp",
@@ -281,7 +281,7 @@ func tryTCPPing(host string) PingResult {
_ = conn.Close()
}()
return PingResult{
return pingResult{
Duration: duration,
Error: nil,
Method: "tcp",
@@ -333,22 +333,22 @@ func (api *API) updatePreferredUpstream(c *gin.Context) {
return
}
if !slices.Contains(api.Config.DNS.UpstreamDNS, request.Upstream) {
if !slices.Contains(api.Config.DNS.Upstream.Fallback, request.Upstream) {
c.JSON(http.StatusNotFound, gin.H{"error": "Upstream not found"})
return
}
if api.Config.DNS.PreferredUpstream == request.Upstream {
if api.Config.DNS.Upstream.Preferred == request.Upstream {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Preferred upstream already set to %s", request.Upstream)})
return
}
api.Config.DNS.PreferredUpstream = request.Upstream
api.Config.DNS.Upstream.Preferred = request.Upstream
message := fmt.Sprintf("Preferred upstream set to %s", request.Upstream)
log.Info("%s", message)
api.Config.Save()
api.DNSServer.Audits.CreateAudit(&audit.Entry{
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
Topic: audit.TopicUpstream,
Message: fmt.Sprintf("New preferred upstream '%s'", request.Upstream),
})
@@ -364,23 +364,23 @@ func (api *API) deleteUpstream(c *gin.Context) {
return
}
if !slices.Contains(api.Config.DNS.UpstreamDNS, upstreamToDelete) {
if !slices.Contains(api.Config.DNS.Upstream.Fallback, upstreamToDelete) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("Upstream %s not found", upstreamToDelete)})
return
}
var updatedUpstreams []string
for _, upstream := range api.Config.DNS.UpstreamDNS {
for _, upstream := range api.Config.DNS.Upstream.Fallback {
if upstream != upstreamToDelete {
updatedUpstreams = append(updatedUpstreams, upstream)
}
}
api.Config.DNS.UpstreamDNS = updatedUpstreams
api.Config.DNS.Upstream.Fallback = updatedUpstreams
api.Config.Save()
log.Info("Removed upstream: %s", upstreamToDelete)
api.DNSServer.Audits.CreateAudit(&audit.Entry{
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
Topic: audit.TopicUpstream,
Message: fmt.Sprintf("Removed upstream '%s'", upstreamToDelete),
})

View File

@@ -1,6 +0,0 @@
package user
type User struct {
Username string
Password string
}

View File

@@ -1,119 +0,0 @@
package user
import (
"context"
"errors"
"goaway/backend/dns/database"
"goaway/backend/logging"
"strings"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
var log = logging.GetLogger()
type Credentials struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
}
func (user *User) Create(db *gorm.DB) error {
log.Info("Creating a new user with name '%s'", user.Username)
hashedPassword, err := hashPassword(user.Password)
if err != nil {
log.Error("Failed to hash password: %v", err)
return err
}
user.Password = hashedPassword
result := db.Create(user)
if result.Error != nil {
log.Error("Failed to create user: %v", result.Error)
return result.Error
}
if result.RowsAffected == 0 {
log.Error("User creation failed: no rows affected")
return errors.New("user creation failed: no rows affected")
}
log.Debug("User created successfully")
return nil
}
func hashPassword(password string) (string, error) {
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
return string(hashed), err
}
func (user *User) Exists(db *gorm.DB) bool {
query := database.User{}
db.Where("username = ?", user.Username).Find(&query)
return query.Username != ""
}
func (user *User) Authenticate(db *gorm.DB) bool {
var dbUser User
if err := db.Where("username = ?", user.Username).First(&dbUser).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Error("User not found: %s", user.Username)
return false
}
log.Error("Query error: %v", err)
return false
}
if err := bcrypt.CompareHashAndPassword([]byte(dbUser.Password), []byte(user.Password)); err != nil {
return false
}
return true
}
func (user *User) UpdatePassword(db *gorm.DB) error {
hashedPassword, err := hashPassword(user.Password)
if err != nil {
log.Error("Failed to hash new password: %v", err)
return err
}
affected, err := gorm.G[database.User](db).Where("username = ?", user.Username).Update(context.Background(), "password", hashedPassword)
if err != nil {
log.Error("Failed to update password: %v", err)
return err
}
if affected == 0 {
log.Error("Password update failed: no rows affected")
return errors.New("password update failed: no rows affected")
}
log.Debug("Password updated successfully")
return nil
}
func (c *Credentials) Validate() error {
c.Username = strings.TrimSpace(c.Username)
c.Password = strings.TrimSpace(c.Password)
if c.Username == "" || c.Password == "" {
return errors.New("username and password cannot be empty")
}
if len(c.Username) > 60 {
return errors.New("username too long")
}
if len(c.Password) > 120 {
return errors.New("password too long")
}
for _, r := range c.Username {
if r < 32 || r == 127 {
return errors.New("username contains invalid characters")
}
}
return nil
}

View File

@@ -2,7 +2,7 @@ package api
import (
"encoding/json"
"goaway/backend/dns/database"
"goaway/backend/database"
"io"
"net/http"
@@ -34,7 +34,7 @@ func (api *API) addWhitelisted(c *gin.Context) {
return
}
err = api.Whitelist.AddDomain(newDomain.Domain)
err = api.WhitelistService.AddDomain(newDomain.Domain)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
@@ -46,7 +46,7 @@ func (api *API) addWhitelisted(c *gin.Context) {
func (api *API) getWhitelistedDomains(c *gin.Context) {
var domains []string
err := api.DBManager.Conn.Model(&database.Whitelist{}).Pluck("domain", &domains).Error
err := api.DBConn.Model(&database.Whitelist{}).Pluck("domain", &domains).Error
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "Failed to retrieve whitelisted domains"})
return
@@ -58,7 +58,7 @@ func (api *API) getWhitelistedDomains(c *gin.Context) {
func (api *API) deleteWhitelistedDomain(c *gin.Context) {
newDomain := c.Query("domain")
err := api.Whitelist.RemoveDomain(newDomain)
err := api.WhitelistService.RemoveDomain(newDomain)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return

111
backend/app.go Normal file
View File

@@ -0,0 +1,111 @@
package app
import (
"context"
"embed"
"fmt"
"goaway/backend/alert"
"goaway/backend/api/key"
"goaway/backend/audit"
"goaway/backend/blacklist"
"goaway/backend/lifecycle"
"goaway/backend/logging"
"goaway/backend/mac"
"goaway/backend/notification"
"goaway/backend/prefetch"
"goaway/backend/request"
"goaway/backend/resolution"
"goaway/backend/services"
"goaway/backend/settings"
"goaway/backend/setup"
"goaway/backend/user"
"goaway/backend/whitelist"
)
var log = logging.GetLogger()
type Application struct {
config *settings.Config
context *services.AppContext
services *services.ServiceRegistry
lifecycle *lifecycle.Manager
content embed.FS
version string
commit string
date string
}
func New(setFlags *setup.SetFlags, version, commit, date string, content embed.FS) *Application {
config := setup.InitializeSettings(setFlags)
return &Application{
config: config,
version: version,
commit: commit,
date: date,
content: content,
}
}
func (a *Application) Start() error {
ctx, err := services.NewAppContext(a.config)
if err != nil {
return fmt.Errorf("failed to initialize application context: %w", err)
}
a.context = ctx
dbConn := a.context.DBConn
alertService := alert.NewService(alert.NewRepository(dbConn))
auditService := audit.NewService(audit.NewRepository(dbConn))
blacklistService := blacklist.NewService(blacklist.NewRepository(dbConn))
keyService := key.NewService(key.NewRepository(dbConn))
macService := mac.NewService(mac.NewRepository(dbConn))
notificationService := notification.NewService(notification.NewRepository(dbConn))
prefetchService := prefetch.NewService(prefetch.NewRepository(dbConn), a.context.DNSServer)
requestService := request.NewService(request.NewRepository(dbConn))
resolutionService := resolution.NewService(resolution.NewRepository(dbConn))
userService := user.NewService(user.NewRepository(dbConn))
whitelistService := whitelist.NewService(whitelist.NewRepository(dbConn))
a.context.DNSServer.AlertService = alertService
a.context.DNSServer.AuditService = auditService
a.context.DNSServer.BlacklistService = blacklistService
a.context.DNSServer.MACService = macService
a.context.DNSServer.NotificationService = notificationService
a.context.DNSServer.RequestService = requestService
a.context.DNSServer.UserService = userService
a.context.DNSServer.ResolutionService = resolutionService
a.context.DNSServer.WhitelistService = whitelistService
a.displayStartupInfo()
a.services = services.NewServiceRegistry(a.context, a.version, a.commit, a.date, a.content)
a.services.ResolutionService = resolutionService
a.services.BlacklistService = blacklistService
a.services.NotificationService = notificationService
a.services.PrefetchService = prefetchService
a.services.RequestService = requestService
a.services.UserService = userService
a.services.KeyService = keyService
a.services.WhitelistService = whitelistService
a.lifecycle = lifecycle.NewManager(a.services)
runServices := a.lifecycle.Run()
return runServices
}
func (a *Application) displayStartupInfo() {
domains, err := a.context.DNSServer.BlacklistService.CountDomains(context.Background())
if err != nil {
log.Warning("Failed to count blacklist domains: %v", err)
}
currentVersion := setup.GetVersionOrDefault(a.version)
ASCIIArt(
a.config,
domains,
currentVersion.Original(),
a.config.API.Authentication,
)
}

View File

@@ -1,7 +1,8 @@
package asciiart
package app
import (
"fmt"
"goaway/backend/logging"
"goaway/backend/settings"
)
@@ -15,9 +16,19 @@ const (
Magenta = "\033[35m"
)
func AsciiArt(config *settings.Config, blockedDomains int, version string, disableAuth bool, ansi bool) {
var art = `
__ _ ___ __ ___ ____ _ _ _ DNS port: %s
/ _' |/ _ \ / _' \ \ /\ / / _' | | | | Web port: %s
| (_| | (_) | (_| |\ V V / (_| | |_| | Upstream: %s
\__, |\___/ \__,_| \_/\_/ \__,_|\__, | Authentication: %s
__/ | __/ | Cache TTL: %s
|___/ %s |___/ Blocked Domains: %s
`
func ASCIIArt(config *settings.Config, blockedDomains int, version string, disableAuth bool) {
const versionSpace = 7
var ansi = logging.GetLogger().Ansi
colorize := func(color, text string) string {
if !ansi {
return text
@@ -28,20 +39,21 @@ func AsciiArt(config *settings.Config, blockedDomains int, version string, disab
versionFormatted := fmt.Sprintf("%-*s%s%-*s", (versionSpace-len(version))/2, "",
colorize(Cyan, version), (versionSpace-len(version)+1)/2, "")
portFormatted := colorize(Green, fmt.Sprintf("%d", config.DNS.Port))
portFormatted := colorize(Green, fmt.Sprintf("%d", config.DNS.Ports.TCPUDP))
adminPanelFormatted := colorize(Red, fmt.Sprintf("%d", config.API.Port))
upstreamFormatted := colorize(Cyan, config.DNS.PreferredUpstream)
upstreamFormatted := colorize(Cyan, config.DNS.Upstream.Preferred)
authFormatted := colorize(Yellow, fmt.Sprintf("%v", disableAuth))
cacheTTLFormatted := colorize(Blue, fmt.Sprintf("%d", config.DNS.CacheTTL))
blockedDomainsFormatted := colorize(Magenta, fmt.Sprintf("%d", blockedDomains))
fmt.Printf(`
__ _ ___ __ ___ ____ _ _ _ DNS port: %s
/ _' |/ _ \ / _' \ \ /\ / / _' | | | | Web port: %s
| (_| | (_) | (_| |\ V V / (_| | |_| | Upstream: %s
\__, |\___/ \__,_| \_/\_/ \__,_|\__, | Authentication: %s
__/ | __/ | Cache TTL: %s
|___/ %s |___/ Blocked Domains: %s
`, portFormatted, adminPanelFormatted, upstreamFormatted, authFormatted, cacheTTLFormatted, versionFormatted, blockedDomainsFormatted)
fmt.Printf(art,
portFormatted,
adminPanelFormatted,
upstreamFormatted,
authFormatted,
cacheTTLFormatted,
versionFormatted,
blockedDomainsFormatted,
)
fmt.Println()
}

View File

@@ -1,64 +1,40 @@
package audit
import (
"goaway/backend/dns/database"
"goaway/backend/logging"
"goaway/backend/database"
"time"
"gorm.io/gorm"
)
type Manager struct {
dbManager *database.DatabaseManager
dbConn *gorm.DB
}
type Topic string
const (
TopicServer Topic = "server"
TopicDNS Topic = "dns"
TopicAPI Topic = "api"
TopicResolution Topic = "resolution"
TopicPrefetch Topic = "prefetch"
TopicUpstream Topic = "upstream"
TopicUser Topic = "user"
TopicList Topic = "list"
TopicLogs Topic = "logs"
TopicSettings Topic = "settings"
TopicDatabase Topic = "database"
)
type Entry struct {
Id int `json:"id"`
Topic Topic `json:"topic"`
Message string `json:"message"`
CreatedAt time.Time `json:"createdAt"`
func NewAuditManager(dbconn *gorm.DB) *Manager {
return &Manager{dbConn: dbconn}
}
var logger = logging.GetLogger()
func NewAuditManager(dbManager *database.DatabaseManager) *Manager {
return &Manager{dbManager: dbManager}
}
func (nm *Manager) CreateAudit(newAudit *Entry) {
func (m *Manager) CreateAudit(newAudit *Entry) {
audit := database.Audit{
Topic: string(newAudit.Topic),
Message: newAudit.Message,
CreatedAt: time.Now(),
}
result := nm.dbManager.Conn.Create(&audit)
result := m.dbConn.Create(&audit)
if result.Error != nil {
logger.Warning("Unable to create new audit, error: %v", result.Error)
log.Warning("Unable to create new audit, error: %v", result.Error)
return
}
logger.Debug("Created new audit, %+v", newAudit)
log.Debug("Created new audit, %+v", newAudit)
}
func (nm *Manager) ReadAudits() ([]Entry, error) {
func (m *Manager) ReadAudits() ([]Entry, error) {
var audits []database.Audit
result := nm.dbManager.Conn.Order("created_at DESC").Find(&audits)
result := m.dbConn.Order("created_at DESC").Find(&audits)
if result.Error != nil {
return nil, result.Error
}
@@ -66,7 +42,7 @@ func (nm *Manager) ReadAudits() ([]Entry, error) {
entries := make([]Entry, len(audits))
for i, audit := range audits {
entries[i] = Entry{
Id: int(audit.ID),
ID: audit.ID,
Topic: Topic(audit.Topic),
Message: audit.Message,
CreatedAt: audit.CreatedAt,

View File

@@ -0,0 +1,51 @@
package audit
import (
"goaway/backend/database"
"gorm.io/gorm"
)
type Repository interface {
CreateAudit(audit *Entry) error
ReadAudits() ([]Entry, error)
}
type repository struct {
db *gorm.DB
}
func NewRepository(db *gorm.DB) Repository {
return &repository{db: db}
}
func (r *repository) CreateAudit(audit *Entry) error {
dbAudit := database.Audit{
Topic: string(audit.Topic),
Message: audit.Message,
CreatedAt: audit.CreatedAt,
}
result := r.db.Create(&dbAudit)
return result.Error
}
func (r *repository) ReadAudits() ([]Entry, error) {
var dbAudits []database.Audit
result := r.db.Order("created_at DESC").Find(&dbAudits)
if result.Error != nil {
return nil, result.Error
}
audits := make([]Entry, len(dbAudits))
for i, dbAudit := range dbAudits {
audits[i] = Entry{
ID: dbAudit.ID,
Topic: Topic(dbAudit.Topic),
Message: dbAudit.Message,
CreatedAt: dbAudit.CreatedAt,
}
}
return audits, nil
}

50
backend/audit/service.go Normal file
View File

@@ -0,0 +1,50 @@
package audit
import (
"goaway/backend/logging"
"time"
)
type Topic string
const (
TopicServer Topic = "server"
TopicDNS Topic = "dns"
TopicAPI Topic = "api"
TopicResolution Topic = "resolution"
TopicPrefetch Topic = "prefetch"
TopicUpstream Topic = "upstream"
TopicUser Topic = "user"
TopicList Topic = "list"
TopicLogs Topic = "logs"
TopicSettings Topic = "settings"
TopicDatabase Topic = "database"
)
type Entry struct {
CreatedAt time.Time `json:"createdAt"`
Topic Topic `json:"topic"`
Message string `json:"message"`
ID uint `json:"id"`
}
var log = logging.GetLogger()
type Service struct {
repository Repository
}
func NewService(repo Repository) *Service {
return &Service{repository: repo}
}
func (s *Service) CreateAudit(entry *Entry) {
err := s.repository.CreateAudit(entry)
if err != nil {
log.Warning("Could not create audit: %v", err)
}
}
func (s *Service) ReadAudits() ([]Entry, error) {
return s.repository.ReadAudits()
}

View File

@@ -0,0 +1,16 @@
package blacklist
type ListUpdateAvailable struct {
RemoteChecksum string `json:"remoteChecksum"`
DBChecksum string `json:"dbChecksum"`
RemoteDomains []string `json:"remoteDomains"`
DBDomains []string `json:"dbDomains"`
DiffAdded []string `json:"diffAdded"`
DiffRemoved []string `json:"diffRemoved"`
UpdateAvailable bool `json:"updateAvailable"`
}
type BlocklistSource struct {
Name string `json:"name"`
URL string `json:"url"`
}

View File

@@ -0,0 +1,383 @@
package blacklist
import (
"context"
"errors"
"fmt"
"goaway/backend/database"
"strings"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type SourceRepository interface {
GetSources(ctx context.Context, excludeCustom bool) ([]database.Source, error)
GetSourceByName(ctx context.Context, name string) (*database.Source, error)
GetSourceByNameAndURL(ctx context.Context, name, url string) (*database.Source, error)
CreateOrUpdateSource(ctx context.Context, source *database.Source) error
UpdateSourceName(ctx context.Context, oldName, newName, url string) error
UpdateSourceLastUpdated(ctx context.Context, url string, timestamp time.Time) error
ToggleSourceActive(ctx context.Context, name string) error
DeleteSource(ctx context.Context, name, url string) error
UpsertSource(ctx context.Context, source *database.Source) error
}
type DomainRepository interface {
GetAllDomains(ctx context.Context) ([]string, error)
GetDomainsForSource(ctx context.Context, sourceName string) ([]string, error)
GetPaginatedDomains(ctx context.Context, page, pageSize int, search string) ([]database.Blacklist, int64, error)
CountDomains(ctx context.Context) (int64, error)
CreateDomain(ctx context.Context, domain *database.Blacklist) error
CreateDomainsInBatches(ctx context.Context, domains []database.Blacklist, batchSize int) error
DeleteDomain(ctx context.Context, domain string) error
DeleteDomainsBySourceID(ctx context.Context, sourceID uint) error
DeleteCustomDomain(ctx context.Context, domain string, sourceID uint) error
}
type StatsRepository interface {
GetAllSourceStats(ctx context.Context) ([]SourceWithCount, error)
GetSourceStats(ctx context.Context, listname string) (*SourceWithCount, error)
GetRequestStats(ctx context.Context) ([]RequestStats, error)
}
type TransactionRepository interface {
WithTransaction(ctx context.Context, fn func(*gorm.DB) error) error
Vacuum(ctx context.Context) error
}
type Repository interface {
SourceRepository
DomainRepository
StatsRepository
TransactionRepository
}
type SourceWithCount struct {
Name string `json:"name"`
URL string `json:"url"`
ID uint `json:"id"`
LastUpdated time.Time `json:"lastUpdated"`
BlockedCount int `json:"blockedCount"`
Active bool `json:"active"`
}
type RequestStats struct {
Blocked bool `json:"blocked"`
Count int `json:"count"`
}
type repository struct {
db *gorm.DB
}
func NewRepository(db *gorm.DB) Repository {
return &repository{db: db}
}
func (r *repository) GetSources(ctx context.Context, excludeCustom bool) ([]database.Source, error) {
var sources []database.Source
query := r.db.WithContext(ctx)
if excludeCustom {
query = query.Where("name != ?", "Custom")
}
if err := query.Find(&sources).Error; err != nil {
return nil, fmt.Errorf("failed to query sources: %w", err)
}
return sources, nil
}
func (r *repository) GetSourceByName(ctx context.Context, name string) (*database.Source, error) {
var source database.Source
if err := r.db.WithContext(ctx).Where("name = ?", name).First(&source).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("source '%s' not found", name)
}
return nil, fmt.Errorf("failed to get source: %w", err)
}
return &source, nil
}
func (r *repository) GetSourceByNameAndURL(ctx context.Context, name, url string) (*database.Source, error) {
var source database.Source
if err := r.db.WithContext(ctx).Where("name = ? AND url = ?", name, url).First(&source).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("source '%s' with URL '%s' not found", name, url)
}
return nil, fmt.Errorf("failed to get source: %w", err)
}
return &source, nil
}
func (r *repository) CreateOrUpdateSource(ctx context.Context, source *database.Source) error {
result := r.db.WithContext(ctx).Where(database.Source{Name: source.Name, URL: source.URL}).FirstOrCreate(source)
if result.Error != nil {
return fmt.Errorf("failed to create or update source: %w", result.Error)
}
return nil
}
func (r *repository) UpdateSourceName(ctx context.Context, oldName, newName, url string) error {
if strings.TrimSpace(newName) == "" {
return fmt.Errorf("new name cannot be empty")
}
result := r.db.WithContext(ctx).Model(&database.Source{}).
Where("name = ? AND url = ?", oldName, url).
Update("name", newName)
if result.Error != nil {
return fmt.Errorf("failed to update source name: %w", result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("list with name '%s' not found", oldName)
}
return nil
}
func (r *repository) UpdateSourceLastUpdated(ctx context.Context, url string, timestamp time.Time) error {
result := r.db.WithContext(ctx).Model(&database.Source{}).
Where("url = ?", url).
Update("last_updated", timestamp)
if result.Error != nil {
return fmt.Errorf("failed to update source: %w", result.Error)
}
return nil
}
func (r *repository) ToggleSourceActive(ctx context.Context, name string) error {
var source database.Source
if err := r.db.WithContext(ctx).Where("name = ?", name).First(&source).Error; err != nil {
return fmt.Errorf("failed to find source %s: %w", name, err)
}
result := r.db.WithContext(ctx).Model(&source).Update("active", !source.Active)
if result.Error != nil {
return fmt.Errorf("failed to toggle status for %s: %w", name, result.Error)
}
return nil
}
func (r *repository) DeleteSource(ctx context.Context, name, url string) error {
var source database.Source
if err := r.db.WithContext(ctx).Where("name = ? AND url = ?", name, url).First(&source).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("source '%s' not found", name)
}
return fmt.Errorf("failed to get source: %w", err)
}
if err := r.db.WithContext(ctx).Delete(&source).Error; err != nil {
return fmt.Errorf("failed to remove source '%s': %w", name, err)
}
return nil
}
func (r *repository) GetAllDomains(ctx context.Context) ([]string, error) {
var domains []string
result := r.db.WithContext(ctx).Model(&database.Blacklist{}).
Distinct("domain").
Pluck("domain", &domains)
if result.Error != nil {
return nil, fmt.Errorf("failed to query blacklist: %w", result.Error)
}
return domains, nil
}
func (r *repository) GetDomainsForSource(ctx context.Context, sourceName string) ([]string, error) {
var blacklistEntries []database.Blacklist
result := r.db.WithContext(ctx).Select("blacklists.domain").
Joins("JOIN sources ON blacklists.source_id = sources.id").
Where("sources.name = ?", sourceName).
Find(&blacklistEntries)
if result.Error != nil {
return nil, fmt.Errorf("failed to query domains for list: %w", result.Error)
}
domains := make([]string, len(blacklistEntries))
for i, entry := range blacklistEntries {
domains[i] = entry.Domain
}
return domains, nil
}
func (r *repository) GetPaginatedDomains(ctx context.Context, page, pageSize int, search string) ([]database.Blacklist, int64, error) {
searchPattern := "%" + search + "%"
offset := (page - 1) * pageSize
var blacklistEntries []database.Blacklist
result := r.db.WithContext(ctx).Select("domain").
Where("domain LIKE ?", searchPattern).
Order("domain DESC").
Limit(pageSize).
Offset(offset).
Find(&blacklistEntries)
if result.Error != nil {
return nil, 0, fmt.Errorf("failed to query blacklist: %w", result.Error)
}
var total int64
countResult := r.db.WithContext(ctx).Model(&database.Blacklist{}).
Where("domain LIKE ?", searchPattern).
Count(&total)
if countResult.Error != nil {
return nil, 0, fmt.Errorf("failed to count domains: %w", countResult.Error)
}
return blacklistEntries, total, nil
}
func (r *repository) CountDomains(ctx context.Context) (int64, error) {
var count int64
result := r.db.WithContext(ctx).Model(&database.Blacklist{}).Count(&count)
if result.Error != nil {
return 0, fmt.Errorf("failed to count domains: %w", result.Error)
}
return count, nil
}
func (r *repository) CreateDomain(ctx context.Context, domain *database.Blacklist) error {
result := r.db.WithContext(ctx).Create(domain)
if result.Error != nil {
if strings.Contains(result.Error.Error(), "UNIQUE constraint failed") ||
strings.Contains(result.Error.Error(), "duplicate key") {
return fmt.Errorf("%s is already blacklisted", domain.Domain)
}
return fmt.Errorf("failed to add domain to blacklist: %w", result.Error)
}
return nil
}
func (r *repository) CreateDomainsInBatches(ctx context.Context, domains []database.Blacklist, batchSize int) error {
if len(domains) == 0 {
return nil
}
if err := r.db.WithContext(ctx).CreateInBatches(domains, batchSize).Error; err != nil {
if !strings.Contains(err.Error(), "UNIQUE constraint failed") &&
!strings.Contains(err.Error(), "duplicate key") {
return fmt.Errorf("failed to add domains: %w", err)
}
}
return nil
}
func (r *repository) DeleteDomain(ctx context.Context, domain string) error {
result := r.db.WithContext(ctx).Where("domain = ?", domain).Delete(&database.Blacklist{})
if result.Error != nil {
return fmt.Errorf("failed to remove domain from blacklist: %w", result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("domain not found: %s", domain)
}
return nil
}
func (r *repository) DeleteDomainsBySourceID(ctx context.Context, sourceID uint) error {
if err := r.db.WithContext(ctx).Where("source_id = ?", sourceID).Delete(&database.Blacklist{}).Error; err != nil {
return fmt.Errorf("failed to remove domains for source ID %d: %w", sourceID, err)
}
return nil
}
func (r *repository) DeleteCustomDomain(ctx context.Context, domain string, sourceID uint) error {
result := r.db.WithContext(ctx).
Where("domain = ? AND source_id = ?", domain, sourceID).
Delete(&database.Blacklist{})
if result.Error != nil {
return fmt.Errorf("failed to delete domain '%s': %w", domain, result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("domain '%s' not found in custom blacklist", domain)
}
return nil
}
func (r *repository) GetAllSourceStats(ctx context.Context) ([]SourceWithCount, error) {
var results []SourceWithCount
result := r.db.WithContext(ctx).Table("sources s").
Select("s.id, s.name, s.url, s.last_updated, s.active, COALESCE(bc.blocked_count, 0) as blocked_count").
Joins("LEFT JOIN (SELECT source_id, COUNT(*) as blocked_count FROM blacklists GROUP BY source_id) bc ON s.id = bc.source_id").
Order("s.name, s.id").
Scan(&results)
if result.Error != nil {
return nil, fmt.Errorf("failed to query source statistics: %w", result.Error)
}
return results, nil
}
func (r *repository) GetSourceStats(ctx context.Context, listname string) (*SourceWithCount, error) {
var result SourceWithCount
err := r.db.WithContext(ctx).Table("sources s").
Select("s.name, s.url, s.last_updated, s.active, COALESCE(bc.blocked_count, 0) as blocked_count").
Joins("LEFT JOIN (SELECT source_id, COUNT(*) as blocked_count FROM blacklists GROUP BY source_id) bc ON s.id = bc.source_id").
Where("s.name = ?", listname).
First(&result).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("list not found")
}
return nil, fmt.Errorf("failed to query list statistics: %w", err)
}
return &result, nil
}
func (r *repository) GetRequestStats(ctx context.Context) ([]RequestStats, error) {
var stats []RequestStats
result := r.db.WithContext(ctx).Model(&database.RequestLog{}).
Select("blocked, COUNT(*) as count").
Group("blocked").
Scan(&stats)
if result.Error != nil {
return nil, fmt.Errorf("failed to query request_logs: %w", result.Error)
}
return stats, nil
}
func (r *repository) Vacuum(ctx context.Context) error {
if err := r.db.WithContext(ctx).Exec("VACUUM").Error; err != nil {
return fmt.Errorf("error while vacuuming database: %w", err)
}
return nil
}
func (r *repository) WithTransaction(ctx context.Context, fn func(*gorm.DB) error) error {
return r.db.WithContext(ctx).Transaction(fn)
}
func (r *repository) UpsertSource(ctx context.Context, source *database.Source) error {
if err := r.db.WithContext(ctx).Clauses(
clause.OnConflict{
Columns: []clause.Column{{Name: "url"}},
DoUpdates: clause.AssignmentColumns([]string{"name", "last_updated", "active"}),
},
).Create(source).Error; err != nil {
return fmt.Errorf("failed to upsert source: %w", err)
}
return nil
}

View File

@@ -0,0 +1,693 @@
package blacklist
import (
"bufio"
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"goaway/backend/database"
"goaway/backend/logging"
"io"
"net/http"
"sort"
"strings"
"sync"
"time"
"gorm.io/gorm"
)
type HTTPClient interface {
Get(url string) (*http.Response, error)
}
type Config struct {
DefaultSources []BlocklistSource
CacheTTL time.Duration
BatchSize int
UpdateInterval time.Duration
}
type Service struct {
repository Repository
httpClient HTTPClient
cache map[string]bool
cacheMu sync.RWMutex
blocklistURL []BlocklistSource
config Config
}
var log = logging.GetLogger()
const (
blacklistedIP = "0.0.0.0"
IPv4Loopback = "127.0.0.1"
defaultBatchSize = 1000
)
var defaultConfig = Config{
DefaultSources: []BlocklistSource{
{
Name: "StevenBlack",
URL: "https://raw.githubusercontent.com/StevenBlack/hosts/refs/heads/master/hosts",
},
},
CacheTTL: 24 * time.Hour,
BatchSize: defaultBatchSize,
UpdateInterval: 24 * time.Hour,
}
func NewService(repo Repository) *Service {
config := defaultConfig
service := &Service{
repository: repo,
httpClient: http.DefaultClient,
cache: make(map[string]bool),
config: config,
}
if len(service.blocklistURL) == 0 {
service.blocklistURL = config.DefaultSources
}
if err := service.initialize(context.Background()); err != nil {
log.Error("Could not initialize blacklist: %v", err)
}
return service
}
func (s *Service) initialize(ctx context.Context) error {
count, err := s.repository.CountDomains(ctx)
if err != nil {
return fmt.Errorf("failed to count domains: %w", err)
}
if count == 0 {
log.Info("No domains in blacklist. Running initialization...")
if err := s.initializeBlockedDomains(ctx); err != nil {
return fmt.Errorf("failed to initialize blocked domains: %w", err)
}
}
if err := s.InitializeBlocklist(ctx, "Custom", ""); err != nil {
return fmt.Errorf("failed to initialize custom blocklist: %w", err)
}
if _, err := s.GetBlocklistUrls(ctx); err != nil {
log.Error("Failed to fetch blocklist URLs: %v", err)
return fmt.Errorf("failed to fetch blocklist URLs: %w", err)
}
if err := s.PopulateCache(ctx); err != nil {
log.Error("Failed to initialize blocklist cache: %v", err)
return fmt.Errorf("failed to initialize blocklist cache: %w", err)
}
return nil
}
func (s *Service) initializeBlockedDomains(ctx context.Context) error {
start := time.Now()
for _, source := range s.blocklistURL {
if source.Name == "Custom" {
continue
}
if err := s.FetchAndLoadHosts(ctx, source.URL, source.Name); err != nil {
return err
}
}
log.Info("Blocked domains initialized in %.2fs", time.Since(start).Seconds())
return nil
}
func (s *Service) PopulateCache(ctx context.Context) error {
domains, err := s.repository.GetAllDomains(ctx)
if err != nil {
return fmt.Errorf("failed to populate cache: %w", err)
}
s.cacheMu.Lock()
defer s.cacheMu.Unlock()
s.cache = make(map[string]bool, len(domains))
for _, domain := range domains {
s.cache[domain] = true
}
return nil
}
func (s *Service) IsBlacklisted(domain string) bool {
s.cacheMu.RLock()
defer s.cacheMu.RUnlock()
if exists, found := s.cache[domain]; found {
return exists
}
return false
}
func (s *Service) GetBlocklistUrls(ctx context.Context) ([]BlocklistSource, error) {
sources, err := s.repository.GetSources(ctx, true)
if err != nil {
return nil, fmt.Errorf("failed to get blocklist URLs: %w", err)
}
blocklistURL := make([]BlocklistSource, len(sources))
for i, source := range sources {
blocklistURL[i] = BlocklistSource{
Name: source.Name,
URL: source.URL,
}
}
s.blocklistURL = blocklistURL
return blocklistURL, nil
}
func (s *Service) NameExists(name, url string) bool {
for _, source := range s.blocklistURL {
if source.Name == name && source.URL == url {
return true
}
}
return false
}
func (s *Service) URLExists(url string) bool {
for _, source := range s.blocklistURL {
if source.URL == url {
return true
}
}
return false
}
func (s *Service) CheckIfUpdateAvailable(ctx context.Context, remoteListURL, listName string) (ListUpdateAvailable, error) {
listUpdateAvailable := ListUpdateAvailable{}
remoteDomains, remoteChecksum, err := s.FetchRemoteHostsList(ctx, remoteListURL)
if err != nil {
log.Warning("Failed to fetch remote hosts list: %v", err)
return listUpdateAvailable, fmt.Errorf("failed to fetch remote hosts list: %w", err)
}
dbDomains, dbChecksum, err := s.FetchDBHostsList(ctx, listName)
if err != nil {
log.Warning("Failed to fetch database hosts list: %v", err)
return listUpdateAvailable, fmt.Errorf("failed to fetch database hosts list: %w", err)
}
if remoteChecksum == dbChecksum {
log.Debug("No updates available for %s", listName)
return listUpdateAvailable, nil
}
diff := func(a, b []string) []string {
mb := make(map[string]struct{}, len(b))
for _, x := range b {
mb[x] = struct{}{}
}
diff := make([]string, 0)
for _, x := range a {
if _, found := mb[x]; !found {
diff = append(diff, x)
}
}
return diff
}
return ListUpdateAvailable{
RemoteDomains: remoteDomains,
DBDomains: dbDomains,
RemoteChecksum: remoteChecksum,
DBChecksum: dbChecksum,
UpdateAvailable: true,
DiffAdded: diff(remoteDomains, dbDomains),
DiffRemoved: diff(dbDomains, remoteDomains),
}, nil
}
func (s *Service) FetchRemoteHostsList(ctx context.Context, url string) ([]string, string, error) {
resp, err := s.httpClient.Get(url)
if err != nil {
return nil, "", fmt.Errorf("failed to fetch hosts file from %s: %w", url, err)
}
defer func(Body io.ReadCloser) {
_ = Body.Close()
}(resp.Body)
domains, err := s.ExtractDomains(resp.Body)
if err != nil {
return nil, "", fmt.Errorf("failed to extract domains from %s: %w", url, err)
}
return domains, calculateDomainsChecksum(domains), nil
}
func (s *Service) FetchDBHostsList(ctx context.Context, name string) ([]string, string, error) {
domains, err := s.repository.GetDomainsForSource(ctx, name)
if err != nil {
return nil, "", fmt.Errorf("could not fetch domains from database: %w", err)
}
return domains, calculateDomainsChecksum(domains), nil
}
func calculateDomainsChecksum(domains []string) string {
sort.Strings(domains)
data := strings.Join(domains, "\n")
hash := sha256.Sum256([]byte(data))
return hex.EncodeToString(hash[:])
}
func (s *Service) FetchAndLoadHosts(ctx context.Context, url, name string) error {
resp, err := s.httpClient.Get(url)
if err != nil {
return fmt.Errorf("failed to fetch hosts file from %s: %w", url, err)
}
defer func(Body io.ReadCloser) {
_ = Body.Close()
}(resp.Body)
domains, err := s.ExtractDomains(resp.Body)
if err != nil {
return fmt.Errorf("failed to extract domains from %s: %w", url, err)
}
if err := s.InitializeBlocklist(ctx, name, url); err != nil {
return fmt.Errorf("failed to initialize blocklist: %w", err)
}
if err := s.AddDomains(ctx, name, domains, url); err != nil {
return fmt.Errorf("failed to add domains to database: %w", err)
}
log.Info("Added %d domains from list '%s' with url '%s'", len(domains), name, url)
return nil
}
func (s *Service) isValidDomain(domain string) bool {
invalidDomains := map[string]bool{
"localhost": true,
"localhost.localdomain": true,
"broadcasthost": true,
"local": true,
blacklistedIP: true,
}
return !invalidDomains[domain]
}
func (s *Service) ExtractDomains(body io.Reader) ([]string, error) {
scanner := bufio.NewScanner(body)
domainSet := make(map[string]struct{})
var domains []string
for scanner.Scan() {
fields := strings.Fields(scanner.Text())
if len(fields) == 0 || strings.HasPrefix(fields[0], "#") {
continue
}
domain := fields[0]
if (domain == blacklistedIP || domain == IPv4Loopback) && len(fields) > 1 {
domain = fields[1]
if !s.isValidDomain(domain) {
continue
}
} else if domain == blacklistedIP || domain == IPv4Loopback {
continue
}
if _, exists := domainSet[domain]; !exists {
domainSet[domain] = struct{}{}
domains = append(domains, domain)
}
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error reading hosts file: %w", err)
}
if len(domains) == 0 {
return nil, errors.New("zero results when parsing")
}
return domains, nil
}
func (s *Service) updateCache(domains []string, add bool) {
s.cacheMu.Lock()
defer s.cacheMu.Unlock()
for _, domain := range domains {
if add {
s.cache[domain] = true
} else {
delete(s.cache, domain)
}
}
}
func (s *Service) AddBlacklistedDomain(ctx context.Context, domain string) error {
blacklistEntry := &database.Blacklist{Domain: domain}
if err := s.repository.CreateDomain(ctx, blacklistEntry); err != nil {
return err
}
s.updateCache([]string{domain}, true)
return nil
}
func (s *Service) AddDomains(ctx context.Context, name string, domains []string, url string) error {
return s.repository.WithTransaction(ctx, func(tx *gorm.DB) error {
currentTime := time.Now()
if err := s.repository.UpdateSourceLastUpdated(ctx, url, currentTime); err != nil {
return err
}
source, err := s.repository.GetSourceByNameAndURL(ctx, name, url)
if err != nil {
return err
}
blacklistEntries := make([]database.Blacklist, 0, len(domains))
for _, domain := range domains {
blacklistEntries = append(blacklistEntries, database.Blacklist{
Domain: domain,
SourceID: source.ID,
})
}
if len(blacklistEntries) > 0 {
batchSize := s.config.BatchSize
if batchSize == 0 {
batchSize = defaultBatchSize
}
if err := s.repository.CreateDomainsInBatches(ctx, blacklistEntries, batchSize); err != nil {
return err
}
}
s.updateCache(domains, true)
return nil
})
}
func (s *Service) RemoveDomain(ctx context.Context, domain string) error {
if err := s.repository.DeleteDomain(ctx, domain); err != nil {
return err
}
s.updateCache([]string{domain}, false)
return nil
}
func (s *Service) AddCustomDomains(ctx context.Context, domains []string) error {
return s.repository.WithTransaction(ctx, func(tx *gorm.DB) error {
currentTime := time.Now()
source, err := s.repository.GetSourceByName(ctx, "Custom")
if err != nil {
if strings.Contains(err.Error(), "not found") {
newSource := &database.Source{
Name: "Custom",
LastUpdated: currentTime,
Active: true,
}
if err := s.repository.CreateOrUpdateSource(ctx, newSource); err != nil {
return fmt.Errorf("failed to create custom source: %w", err)
}
source = newSource
} else {
return fmt.Errorf("failed to get custom source: %w", err)
}
} else {
if err := s.repository.UpdateSourceLastUpdated(ctx, "", currentTime); err != nil {
return fmt.Errorf("failed to update custom source: %w", err)
}
}
for _, domain := range domains {
entry := &database.Blacklist{
Domain: domain,
SourceID: source.ID,
}
// Ignore duplicate errors
if err := s.repository.CreateDomain(ctx, entry); err != nil &&
!strings.Contains(err.Error(), "already blacklisted") {
log.Warning("Failed to add domain %s: %v", domain, err)
} else {
s.updateCache([]string{domain}, true)
}
}
return nil
})
}
func (s *Service) RemoveCustomDomain(ctx context.Context, domain string) error {
source, err := s.repository.GetSourceByName(ctx, "Custom")
if err != nil {
return fmt.Errorf("custom source not found: %w", err)
}
if err := s.repository.DeleteCustomDomain(ctx, domain, source.ID); err != nil {
return err
}
s.updateCache([]string{domain}, false)
currentTime := time.Now()
if err := s.repository.UpdateSourceLastUpdated(ctx, "", currentTime); err != nil {
log.Warning("Failed to update custom source timestamp: %v", err)
}
return nil
}
func (s *Service) InitializeBlocklist(ctx context.Context, name, url string) error {
source := &database.Source{
Name: name,
URL: url,
LastUpdated: time.Now(),
Active: true,
}
return s.repository.CreateOrUpdateSource(ctx, source)
}
func (s *Service) AddSource(ctx context.Context, name, url string) error {
if strings.TrimSpace(name) == "" || strings.TrimSpace(url) == "" {
return fmt.Errorf("name and url cannot be empty")
}
source := &database.Source{
Name: name,
URL: url,
LastUpdated: time.Now(),
Active: true,
}
if err := s.repository.UpsertSource(ctx, source); err != nil {
return err
}
// Update in-memory list
found := false
for _, existing := range s.blocklistURL {
if existing.Name == name && existing.URL == url {
found = true
break
}
}
if !found {
s.blocklistURL = append(s.blocklistURL, BlocklistSource{Name: name, URL: url})
}
return nil
}
func (s *Service) UpdateSourceName(ctx context.Context, oldName, newName, url string) error {
if oldName == newName {
return fmt.Errorf("new name is the same as the old name")
}
if err := s.repository.UpdateSourceName(ctx, oldName, newName, url); err != nil {
return err
}
// Update in-memory list
for i, source := range s.blocklistURL {
if source.Name == oldName {
s.blocklistURL[i].Name = newName
}
}
log.Info("Updated blocklist name from '%s' to '%s'", oldName, newName)
return nil
}
func (s *Service) ToggleBlocklistStatus(ctx context.Context, name string) error {
return s.repository.ToggleSourceActive(ctx, name)
}
func (s *Service) RemoveSourceAndDomains(ctx context.Context, name, url string) error {
return s.repository.WithTransaction(ctx, func(tx *gorm.DB) error {
source, err := s.repository.GetSourceByNameAndURL(ctx, name, url)
if err != nil {
return err
}
if err := s.repository.DeleteDomainsBySourceID(ctx, source.ID); err != nil {
return fmt.Errorf("failed to remove domains for source '%s': %w", name, err)
}
if err := s.repository.DeleteSource(ctx, name, url); err != nil {
return err
}
return nil
})
}
func (s *Service) RemoveSourceByNameAndURL(name, url string) bool {
for i := len(s.blocklistURL) - 1; i >= 0; i-- {
if s.blocklistURL[i].Name == name && s.blocklistURL[i].URL == url {
s.blocklistURL = append(s.blocklistURL[:i], s.blocklistURL[i+1:]...)
return true
}
}
return false
}
func (s *Service) CountDomains(ctx context.Context) (int, error) {
count, err := s.repository.CountDomains(ctx)
if err != nil {
return 0, err
}
return int(count), nil
}
func (s *Service) GetAllowedAndBlocked(ctx context.Context) (allowed, blocked int, err error) {
stats, err := s.repository.GetRequestStats(ctx)
if err != nil {
return 0, 0, err
}
for _, stat := range stats {
if stat.Blocked {
blocked = stat.Count
} else {
allowed = stat.Count
}
}
return allowed, blocked, nil
}
func (s *Service) GetAllListStatistics(ctx context.Context) ([]SourceWithCount, error) {
results, err := s.repository.GetAllSourceStats(ctx)
if err != nil {
return nil, err
}
stats := make([]SourceWithCount, len(results))
for i, r := range results {
stats[i] = SourceWithCount{
Name: r.Name,
URL: r.URL,
BlockedCount: r.BlockedCount,
LastUpdated: r.LastUpdated,
Active: r.Active,
}
}
return stats, nil
}
func (s *Service) GetListStatistics(ctx context.Context, listname string) (string, SourceWithCount, error) {
result, err := s.repository.GetSourceStats(ctx, listname)
if err != nil {
return "", SourceWithCount{}, err
}
stats := SourceWithCount{
URL: result.URL,
BlockedCount: result.BlockedCount,
LastUpdated: result.LastUpdated,
Active: result.Active,
}
return result.Name, stats, nil
}
func (s *Service) LoadPaginatedBlacklist(ctx context.Context, page, pageSize int, search string) ([]string, int, error) {
blacklistEntries, total, err := s.repository.GetPaginatedDomains(ctx, page, pageSize, search)
if err != nil {
return nil, 0, err
}
domains := make([]string, len(blacklistEntries))
for i, entry := range blacklistEntries {
domains[i] = entry.Domain
}
return domains, int(total), nil
}
func (s *Service) Vacuum(ctx context.Context) {
log.Debug("Vacuuming database...")
if err := s.repository.Vacuum(ctx); err != nil {
log.Warning("Error while vacuuming database: %v", err)
}
}
func (s *Service) ScheduleAutomaticListUpdates() {
ticker := time.NewTicker(s.config.UpdateInterval)
defer ticker.Stop()
for range ticker.C {
ctx := context.Background()
log.Info("Starting automatic list updates...")
for _, source := range s.blocklistURL {
if source.Name == "Custom" {
continue
}
log.Info("Checking for updates for blocklist %s from %s", source.Name, source.URL)
availableUpdate, err := s.CheckIfUpdateAvailable(ctx, source.URL, source.Name)
if err != nil {
log.Warning("Failed to check for updates for %s: %v", source.Name, err)
continue
}
if !availableUpdate.UpdateAvailable {
log.Info("No updates available for %s", source.Name)
continue
}
if err := s.RemoveSourceAndDomains(ctx, source.Name, source.URL); err != nil {
log.Warning("Failed to remove old domains for %s: %v", source.Name, err)
continue
}
if err := s.FetchAndLoadHosts(ctx, source.URL, source.Name); err != nil {
log.Warning("Failed to fetch and load hosts for %s: %v", source.Name, err)
continue
}
log.Info("Successfully updated %s with %d new domains", source.Name, len(availableUpdate.DiffAdded))
}
if err := s.PopulateCache(ctx); err != nil {
log.Warning("Failed to populate blocklist cache after auto-update: %v", err)
}
}
}

View File

@@ -0,0 +1,46 @@
package database
import (
"log"
"os"
"path/filepath"
"github.com/glebarez/sqlite"
"gorm.io/gorm"
)
func Initialize() *gorm.DB {
if err := os.MkdirAll("data", 0755); err != nil {
log.Fatal("failed to create data directory: %w", err)
}
databasePath := filepath.Join("data", "database.db")
db, err := gorm.Open(sqlite.Open(databasePath), &gorm.Config{})
if err != nil {
log.Fatal("failed while initializing database: %w", err)
}
if err := AutoMigrate(db); err != nil {
log.Fatal("auto migrate failed: %w", err)
}
return db
}
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&Source{},
&Blacklist{},
&Whitelist{},
&RequestLog{},
&RequestLogIP{},
&Resolution{},
&MacAddress{},
&User{},
&APIKey{},
&Notification{},
&Prefetch{},
&Audit{},
&Alert{},
)
}

118
backend/database/model.go Normal file
View File

@@ -0,0 +1,118 @@
package database
import (
"time"
)
type Source struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Name string `json:"name" validate:"required"`
URL string `gorm:"unique;not null" json:"url" validate:"required,url"`
Active bool `gorm:"default:true" json:"active"`
LastUpdated time.Time `gorm:"not null" json:"lastUpdated"`
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
}
type Blacklist struct {
Domain string `gorm:"primaryKey" json:"domain" validate:"required,fqdn"`
SourceID uint `gorm:"primaryKey;not null" json:"sourceID" validate:"required"`
Source Source `gorm:"foreignKey:SourceID;constraint:OnDelete:CASCADE" json:"source"`
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
}
type Whitelist struct {
Domain string `gorm:"primaryKey" json:"domain" validate:"required,fqdn"`
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
}
type RequestLog struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Timestamp time.Time `gorm:"not null;index:idx_timestamp_response_size,priority:1" json:"timestamp"`
Domain string `gorm:"type:varchar(255);not null;index:idx_domain_timestamp,priority:1;index:idx_client_ip_domain,priority:2" json:"domain" validate:"required"`
ClientIP string `gorm:"type:varchar(45);not null;index:idx_client_ip;index:idx_client_ip_domain,priority:1" json:"clientIP" validate:"required,ip"`
ClientName string `gorm:"type:varchar(255)" json:"clientName"`
QueryType string `gorm:"type:varchar(10)" json:"queryType"`
Status string `gorm:"type:varchar(50)" json:"status"`
Protocol string `gorm:"type:varchar(10)" json:"protocol"`
ResponseTimeNs int64 `gorm:"not null" json:"repsonseTimeNS"`
ResponseSizeBytes int `gorm:"index:idx_timestamp_response_size,priority:2" json:"responseSizeBytes"`
Blocked bool `gorm:"not null;index:idx_timestamp_covering,priority:2;default:false" json:"blocked"`
Cached bool `gorm:"not null;index:idx_timestamp_covering,priority:3;default:false" json:"cached"`
IPs []RequestLogIP `gorm:"foreignKey:RequestLogID;constraint:OnDelete:CASCADE" json:"ips"`
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
}
type RequestLogIP struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
RequestLogID uint `gorm:"not null;index" json:"requestLogID" validate:"required"`
IP string `gorm:"type:varchar(45);not null" json:"ip" validate:"required,ip"`
RecordType string `gorm:"type:varchar(10);not null" json:"recordType" validate:"required"`
RequestLog RequestLog `gorm:"foreignKey:RequestLogID;constraint:OnDelete:CASCADE" json:"requestLog"`
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
}
type Resolution struct {
Domain string `gorm:"primaryKey" json:"domain" validate:"required,fqdn"`
IP string `gorm:"index" json:"ip" validate:"required,ip"`
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
}
type MacAddress struct {
MAC string `gorm:"primaryKey" json:"mac" validate:"required,mac"`
IP string `gorm:"index" json:"ip" validate:"required,ip"`
Vendor string `json:"vendor"`
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
}
type User struct {
Username string `gorm:"primaryKey" json:"username" validate:"required,min=3,max=50"`
Password string `json:"password" validate:"required,min=8"`
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
}
type APIKey struct {
Name string `gorm:"primaryKey" json:"name" validate:"required"`
Key string `gorm:"unique;not null" json:"key" validate:"required"`
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
}
type Notification struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Severity string `gorm:"type:varchar(20);not null" json:"severity" validate:"required,oneof=info warning error critical"`
Category string `gorm:"type:varchar(50);not null" json:"category" validate:"required"`
Text string `gorm:"type:text;not null" json:"text" validate:"required"`
Read bool `gorm:"default:false;index" json:"read"`
CreatedAt time.Time `gorm:"not null;index" json:"createdAt"`
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
}
type Prefetch struct {
Domain string `gorm:"primaryKey" json:"domain" validate:"required,fqdn"`
QueryType int `gorm:"not null" json:"queryType" validate:"required,min=1"`
Refresh int `gorm:"not null" json:"refresh" validate:"required,min=1"`
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
}
type Audit struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Topic string `gorm:"type:varchar(100);not null;index" json:"topic" validate:"required"`
Message string `gorm:"type:text;not null" json:"message" validate:"required"`
CreatedAt time.Time `gorm:"not null;index" json:"createdAt"`
}
type Alert struct {
Type string `gorm:"primaryKey" json:"type" validate:"required"`
Name string `gorm:"not null" json:"name" validate:"required"`
Webhook string `json:"webhook" validate:"omitempty,url"`
Enabled bool `gorm:"default:false" json:"enabled"`
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
}

View File

@@ -17,18 +17,34 @@ import (
var log = logging.GetLogger()
type vendorResponse struct {
Company string `json:"company"`
Success bool `json:"success"`
Found bool `json:"found"`
Company string `json:"company"`
}
type ARPCache struct {
mu sync.RWMutex
type Cache struct {
table map[string]string
mu sync.RWMutex
}
type vendorCacheEntry struct {
vendor string
err error
timestamp time.Time
}
type VendorCache struct {
entries map[string]*vendorCacheEntry
mu sync.RWMutex
ttl time.Duration
}
var (
cache = &ARPCache{table: make(map[string]string)}
cache = &Cache{table: make(map[string]string)}
vendorCache = &VendorCache{
entries: make(map[string]*vendorCacheEntry),
ttl: 60 * time.Second,
}
httpClient = &http.Client{Timeout: 5 * time.Second}
)
@@ -44,6 +60,14 @@ func ProcessARPTable() {
}
}
func CleanVendorResponseCache() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
vendorCache.cleanup()
}
}
func updateARPTable() {
ctx, cancel := context.WithTimeout(context.Background(), 40*time.Second)
defer cancel()
@@ -130,39 +154,56 @@ func GetMacVendor(mac string) (string, error) {
mac = strings.ReplaceAll(mac, "-", "")
mac = strings.ToLower(mac)
if vendor, err, found := vendorCache.get(mac); found {
return vendor, err
}
url := fmt.Sprintf("https://api.maclookup.app/v2/macs/%s", mac)
req, err := http.NewRequest(http.MethodGet, url, nil)
req, err := http.NewRequest(http.MethodGet, url, http.NoBody)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
reqErr := fmt.Errorf("failed to create request: %w", err)
vendorCache.set(mac, "", reqErr)
return "", reqErr
}
resp, err := httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("failed to fetch MAC vendor: %w", err)
apiErr := fmt.Errorf("failed to fetch MAC vendor: %w", err)
vendorCache.set(mac, "", apiErr)
return "", apiErr
}
defer func(Body io.ReadCloser) {
_ = Body.Close()
}(resp.Body)
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
statusErr := fmt.Errorf("unexpected status code: %d", resp.StatusCode)
vendorCache.set(mac, "", statusErr)
return "", statusErr
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response body: %w", err)
readErr := fmt.Errorf("failed to read response body: %w", err)
vendorCache.set(mac, "", readErr)
return "", readErr
}
var result vendorResponse
if err := json.Unmarshal(body, &result); err != nil {
return "", fmt.Errorf("failed to unmarshal response: %w", err)
unmarshalErr := fmt.Errorf("failed to unmarshal response: %w", err)
vendorCache.set(mac, "", unmarshalErr)
return "", unmarshalErr
}
if result.Found {
vendorCache.set(mac, result.Company, nil)
return result.Company, nil
}
return "", fmt.Errorf("vendor not found")
notFoundErr := fmt.Errorf("vendor not found for mac %s", mac)
vendorCache.set(mac, "", notFoundErr)
return "", notFoundErr
}
func isValidMAC(mac string) bool {
@@ -171,3 +212,42 @@ func isValidMAC(mac string) bool {
return len(cleanMAC) == 12 && cleanMAC != "000000000000"
}
func (vc *VendorCache) get(mac string) (string, error, bool) {
vc.mu.RLock()
defer vc.mu.RUnlock()
entry, exists := vc.entries[mac]
if !exists {
return "", nil, false
}
if time.Since(entry.timestamp) > vc.ttl {
return "", nil, false
}
return entry.vendor, entry.err, true
}
func (vc *VendorCache) set(mac, vendor string, err error) {
vc.mu.Lock()
defer vc.mu.Unlock()
vc.entries[mac] = &vendorCacheEntry{
vendor: vendor,
err: err,
timestamp: time.Now(),
}
}
func (vc *VendorCache) cleanup() {
vc.mu.Lock()
defer vc.mu.Unlock()
now := time.Now()
for mac, entry := range vc.entries {
if now.Sub(entry.timestamp) > vc.ttl {
delete(vc.entries, mac)
}
}
}

View File

@@ -1,148 +0,0 @@
package database
import (
"os"
"path/filepath"
"sync"
"time"
"github.com/glebarez/sqlite"
"gorm.io/gorm"
)
type DatabaseManager struct {
Conn *gorm.DB
Mutex *sync.RWMutex
}
type Source struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Name string `json:"name"`
URL string `gorm:"unique" json:"url"`
Active bool `json:"active"`
LastUpdated int64 `json:"lastUpdated"`
}
type Blacklist struct {
Domain string `gorm:"primaryKey" json:"domain"`
SourceID uint `gorm:"primaryKey" json:"source_id"`
Source Source `gorm:"foreignKey:SourceID;references:ID" json:"source"`
}
type Whitelist struct {
Domain string `gorm:"primaryKey" json:"domain"`
}
type RequestLog struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Timestamp time.Time `gorm:"not null;index:idx_request_log_timestamp_covering,priority:1;index:idx_request_log_timestamp_desc;index:idx_request_log_domain_timestamp,priority:2" json:"timestamp"`
Domain string `gorm:"type:varchar(255);not null;index:idx_request_log_domain_timestamp,priority:1;index:idx_client_ip_domain,priority:2" json:"domain"`
Blocked bool `gorm:"not null;index:idx_request_log_timestamp_covering,priority:2" json:"blocked"`
Cached bool `gorm:"not null;index:idx_request_log_timestamp_covering,priority:3" json:"cached"`
ResponseTimeNs int64 `gorm:"not null" json:"response_time_ns"`
ClientIP string `gorm:"type:varchar(45);index:idx_client_ip;index:idx_client_ip_domain,priority:1" json:"client_ip"`
ClientName string `gorm:"type:varchar(255)" json:"client_name"`
Status string `gorm:"type:varchar(50)" json:"status"`
QueryType string `gorm:"type:varchar(10)" json:"query_type"`
ResponseSizeBytes int `json:"response_size_bytes"`
Protocol string `gorm:"type:varchar(10)" json:"protocol"`
IPs []RequestLogIP `gorm:"foreignKey:RequestLogID;constraint:OnDelete:CASCADE" json:"ips"`
}
type RequestLogIP struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
RequestLogID uint `gorm:"not null;index" json:"request_log_id"`
IP string `gorm:"type:varchar(45);not null" json:"ip"`
RType string `gorm:"type:varchar(10);not null" json:"rtype"`
RequestLog RequestLog `gorm:"foreignKey:RequestLogID;references:ID" json:"request_log"`
}
type Resolution struct {
Domain string `gorm:"primaryKey" json:"domain"`
IP string `json:"ip"`
}
type MacAddress struct {
MAC string `gorm:"primaryKey" json:"mac"`
IP string `json:"ip"`
Vendor string `json:"vendor"`
}
type User struct {
Username string `gorm:"primaryKey" json:"username"`
Password string `json:"password"`
}
type APIKey struct {
Name string `gorm:"primaryKey" json:"name"`
Key string `json:"key"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
}
type Notification struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Severity string `json:"severity"`
Category string `json:"category"`
Text string `json:"text"`
Read bool `json:"read"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
}
type Prefetch struct {
Domain string `gorm:"primaryKey" json:"domain"`
Refresh int `json:"refresh"`
QType int `json:"qtype"`
}
type Audit struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Topic string `json:"topic"`
Message string `json:"message"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
}
type Alert struct {
Type string `gorm:"primaryKey" json:"type"`
Enabled bool `json:"enabled"`
Name string `json:"name"`
Webhook string `json:"webhook"`
}
func Initialize() *DatabaseManager {
if err := os.MkdirAll("data", 0755); err != nil {
log.Fatal("failed to create data directory: %v", err)
}
databasePath := filepath.Join("data", "database.db")
db, err := gorm.Open(sqlite.Open(databasePath), &gorm.Config{})
if err != nil {
log.Fatal("failed while initializing database: %v", err)
}
if err := AutoMigrate(db); err != nil {
log.Fatal("auto migrate failed: %v", err)
}
return &DatabaseManager{
Conn: db,
Mutex: &sync.RWMutex{},
}
}
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&Source{},
&Blacklist{},
&Whitelist{},
&RequestLog{},
&RequestLogIP{},
&Resolution{},
&MacAddress{},
&User{},
&APIKey{},
&Notification{},
&Prefetch{},
&Audit{},
&Alert{},
)
}

View File

@@ -1,34 +0,0 @@
package database
import (
"errors"
"gorm.io/gorm"
)
func FindVendor(db *gorm.DB, mac string) (string, error) {
var query MacAddress
tx := db.Find(&query, "mac = ?", mac)
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
return "", nil
}
if tx.Error != nil {
return "", tx.Error
}
return query.Vendor, nil
}
func SaveMacEntry(db *gorm.DB, clientIP, mac, vendor string) {
entry := MacAddress{
MAC: mac,
IP: clientIP,
Vendor: vendor,
}
tx := db.Create(&entry)
if tx.Error != nil {
log.Warning("Unable to save new MAC entry %v", tx.Error)
}
}

View File

@@ -1,23 +0,0 @@
package models
import "time"
type Client struct {
Name, Mac, Vendor string
LastSeen time.Time
}
type ClientDetails struct {
IP, Name, MAC string
}
type ClientRequestDetails struct {
TotalRequests, UniqueDomains, BlockedRequests, CachedRequests int
AvgResponseTimeMs float64
LastSeen, MostQueriedDomain string
}
type Resolution struct {
IP string `json:"ip"`
Domain string `json:"domain"`
}

View File

@@ -1,67 +0,0 @@
package database
import (
"errors"
"fmt"
"strings"
"gorm.io/gorm"
)
func FetchResolutions(db *gorm.DB) ([]Resolution, error) {
var resolutions []Resolution
if err := db.Find(&resolutions).Error; err != nil {
return nil, fmt.Errorf("failed to fetch resolutions: %w", err)
}
return resolutions, nil
}
func FetchResolution(db *gorm.DB, domain string) (string, error) {
log.Debug("Finding resolution for domain: %s", domain)
var res Resolution
db.Where("domain = ?", domain).Find(&res)
if res.IP != "" {
return res.IP, nil
}
parts := strings.Split(domain, ".")
for i := 1; i < len(parts); i++ {
wildcardDomain := "*." + strings.Join(parts[i:], ".")
if err := db.Where("domain = ?", wildcardDomain).Find(&res).Error; err == nil {
return res.IP, nil
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
return "", err
}
}
return "", nil
}
func CreateNewResolution(db *gorm.DB, ip, domain string) error {
res := Resolution{
Domain: domain,
IP: ip,
}
if err := db.Create(&res).Error; err != nil {
if strings.Contains(err.Error(), "UNIQUE") {
return fmt.Errorf("domain already exists, must be unique")
}
return fmt.Errorf("could not create new resolution: %w", err)
}
return nil
}
func DeleteResolution(db *gorm.DB, ip, domain string) (int, error) {
result := db.Where("ip = ? AND domain = ?", ip, domain).Delete(&Resolution{})
if result.Error != nil {
return 0, fmt.Errorf("could not delete resolution: %w", result.Error)
}
if result.RowsAffected == 0 {
log.Warning("No resolution found with IP: %s and Domain: %s", ip, domain)
}
return int(result.RowsAffected), nil
}

View File

@@ -1,794 +0,0 @@
package lists
import (
"bufio"
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"goaway/backend/dns/database"
"goaway/backend/logging"
"io"
"net/http"
"sort"
"strings"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
var log = logging.GetLogger()
type BlocklistSource struct {
Name string
URL string
}
type Blacklist struct {
DBManager *database.DatabaseManager
BlocklistURL []BlocklistSource
BlacklistCache map[string]bool
}
type SourceStats struct {
Name string `json:"name"`
URL string `json:"url"`
BlockedCount int `json:"blockedCount"`
LastUpdated int64 `json:"lastUpdated"`
Active bool `json:"active"`
}
type ListUpdateAvailable struct {
RemoteDomains []string `json:"remoteDomains"`
DBDomains []string `json:"dbDomains"`
RemoteChecksum string `json:"remoteChecksum"`
DBChecksum string `json:"dbChecksum"`
UpdateAvailable bool `json:"updateAvailable"`
DiffAdded []string `json:"diffAdded"`
DiffRemoved []string `json:"diffRemoved"`
}
func InitializeBlacklist(dbManager *database.DatabaseManager) (*Blacklist, error) {
b := &Blacklist{
DBManager: dbManager,
BlocklistURL: []BlocklistSource{
{Name: "StevenBlack", URL: "https://raw.githubusercontent.com/StevenBlack/hosts/refs/heads/master/hosts"},
},
BlacklistCache: map[string]bool{},
}
if count, _ := b.CountDomains(); count == 0 {
log.Info("No domains in blacklist. Running initialization...")
if err := b.initializeBlockedDomains(); err != nil {
return nil, fmt.Errorf("failed to initialize blocked domains: %w", err)
}
}
if err := b.InitializeBlocklist("Custom", ""); err != nil {
return nil, fmt.Errorf("failed to initialize custom blocklist: %w", err)
}
_, err := b.GetBlocklistUrls()
if err != nil {
log.Error("Failed to fetch blocklist URLs: %v", err)
return nil, fmt.Errorf("failed to fetch blocklist URLs: %w", err)
}
_, err = b.PopulateBlocklistCache()
if err != nil {
log.Error("Failed to initialize blocklist cache")
return nil, fmt.Errorf("failed to initialize blocklist cache: %w", err)
}
return b, nil
}
func (b *Blacklist) initializeBlockedDomains() error {
for _, source := range b.BlocklistURL {
if source.Name == "Custom" {
continue
}
if err := b.FetchAndLoadHosts(source.URL, source.Name); err != nil {
return err
}
}
return nil
}
func (b *Blacklist) Vacuum() {
b.DBManager.Mutex.Lock()
tx := b.DBManager.Conn.Raw("VACUUM")
if err := tx.Error; err != nil {
log.Warning("Error while vacuuming database: %v", err)
}
err := tx.Commit().Error
b.DBManager.Mutex.Unlock()
if err != nil {
log.Warning("Error while vacuuming database: %v", err)
}
}
func (b *Blacklist) GetBlocklistUrls() ([]BlocklistSource, error) {
var sources []database.Source
result := b.DBManager.Conn.Where("name != ?", "Custom").Find(&sources)
if result.Error != nil {
return nil, fmt.Errorf("failed to query sources: %w", result.Error)
}
blocklistURL := make([]BlocklistSource, len(sources))
for i, source := range sources {
blocklistURL[i] = BlocklistSource{
Name: source.Name,
URL: source.URL,
}
}
b.BlocklistURL = blocklistURL
return blocklistURL, nil
}
func (b *Blacklist) CheckIfUpdateAvailable(remoteListURL, listName string) (ListUpdateAvailable, error) {
listUpdateAvailable := ListUpdateAvailable{}
remoteDomains, remoteChecksum, err := b.FetchRemoteHostsList(remoteListURL)
if err != nil {
log.Warning("Failed to fetch remote hosts list: %v", err)
return listUpdateAvailable, fmt.Errorf("failed to fetch remote hosts list: %w", err)
}
dbDomains, dbChecksum, err := b.FetchDBHostsList(listName)
if err != nil {
log.Warning("Failed to fetch database hosts list: %v", err)
return listUpdateAvailable, fmt.Errorf("failed to fetch database hosts list: %w", err)
}
if remoteChecksum == dbChecksum {
log.Debug("No updates available for %s", listName)
return listUpdateAvailable, nil
}
diff := func(a, b []string) []string {
mb := make(map[string]struct{}, len(b))
for _, x := range b {
mb[x] = struct{}{}
}
diff := make([]string, 0)
for _, x := range a {
if _, found := mb[x]; !found {
diff = append(diff, x)
}
}
return diff
}
return ListUpdateAvailable{
RemoteDomains: remoteDomains,
DBDomains: dbDomains,
RemoteChecksum: remoteChecksum,
DBChecksum: dbChecksum,
UpdateAvailable: true,
DiffAdded: diff(remoteDomains, dbDomains),
DiffRemoved: diff(dbDomains, remoteDomains),
}, nil
}
func (b *Blacklist) FetchRemoteHostsList(url string) ([]string, string, error) {
resp, err := http.Get(url)
if err != nil {
return nil, "", fmt.Errorf("failed to fetch hosts file from %s: %w", url, err)
}
defer func(Body io.ReadCloser) {
_ = Body.Close()
}(resp.Body)
domains, err := b.ExtractDomains(resp.Body)
if err != nil {
return nil, "", fmt.Errorf("failed to extract domains from %s: %w", url, err)
}
return domains, calculateDomainsChecksum(domains), nil
}
func (b *Blacklist) FetchDBHostsList(name string) ([]string, string, error) {
domains, err := b.GetDomainsForList(name)
if err != nil {
return nil, "", fmt.Errorf("could not fetch domains from database")
}
return domains, calculateDomainsChecksum(domains), nil
}
func calculateDomainsChecksum(domains []string) string {
sort.Strings(domains)
data := strings.Join(domains, "\n")
hash := sha256.Sum256([]byte(data))
return hex.EncodeToString(hash[:])
}
func (b *Blacklist) FetchAndLoadHosts(url, name string) error {
resp, err := http.Get(url)
if err != nil {
return fmt.Errorf("failed to fetch hosts file from %s: %w", url, err)
}
defer func(Body io.ReadCloser) {
_ = Body.Close()
}(resp.Body)
domains, err := b.ExtractDomains(resp.Body)
if err != nil {
return fmt.Errorf("failed to extract domains from %s: %w", url, err)
}
_ = b.InitializeBlocklist(name, url)
if err := b.AddDomains(domains, url); err != nil {
return fmt.Errorf("failed to add domains to database: %w", err)
}
log.Info("Added %d domains from list '%s' with url '%s'", len(domains), name, url)
return nil
}
func (b *Blacklist) ExtractDomains(body io.Reader) ([]string, error) {
scanner := bufio.NewScanner(body)
domainSet := make(map[string]struct{})
var domains []string
for scanner.Scan() {
fields := strings.Fields(scanner.Text())
if len(fields) == 0 || strings.HasPrefix(fields[0], "#") {
continue
}
domain := fields[0]
if (domain == "0.0.0.0" || domain == "127.0.0.1") && len(fields) > 1 {
domain = fields[1]
switch domain {
case "localhost", "localhost.localdomain", "broadcasthost", "local", "0.0.0.0":
continue
}
} else if domain == "0.0.0.0" || domain == "127.0.0.1" {
continue
}
if _, exists := domainSet[domain]; !exists {
domainSet[domain] = struct{}{}
domains = append(domains, domain)
}
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error reading hosts file: %w", err)
}
if len(domains) == 0 {
return nil, errors.New("zero results when parsing")
}
return domains, nil
}
func (b *Blacklist) AddBlacklistedDomain(domain string) error {
blacklistEntry := database.Blacklist{Domain: domain}
result := b.DBManager.Conn.Create(&blacklistEntry)
if result.Error != nil {
if strings.Contains(result.Error.Error(), "UNIQUE constraint failed") ||
strings.Contains(result.Error.Error(), "duplicate key") {
return fmt.Errorf("%s is already blacklisted", domain)
}
return fmt.Errorf("failed to add domain to blacklist: %w", result.Error)
}
b.BlacklistCache[domain] = true
return nil
}
func (b *Blacklist) AddDomains(domains []string, url string) error {
return b.DBManager.Conn.Transaction(func(tx *gorm.DB) error {
var source database.Source
currentTime := time.Now().Unix()
result := tx.Model(&source).Where("url = ?", url).Update("last_updated", currentTime)
if result.Error != nil {
return fmt.Errorf("failed to update source: %w", result.Error)
}
if err := tx.Where("url = ?", url).First(&source).Error; err != nil {
return fmt.Errorf("failed to find source: %w", err)
}
blacklistEntries := make([]database.Blacklist, 0, len(domains))
for _, domain := range domains {
blacklistEntries = append(blacklistEntries, database.Blacklist{
Domain: domain,
SourceID: source.ID,
})
}
if len(blacklistEntries) > 0 {
if err := tx.CreateInBatches(blacklistEntries, 1000).Error; err != nil {
if !strings.Contains(err.Error(), "UNIQUE constraint failed") &&
!strings.Contains(err.Error(), "duplicate key") {
return fmt.Errorf("failed to add domains: %w", err)
}
}
}
return nil
})
}
func (b *Blacklist) PopulateBlocklistCache() (int, error) {
var databaseDomains []string
result := b.DBManager.Conn.Model(&database.Blacklist{}).
Distinct("domain").
Pluck("domain", &databaseDomains)
if result.Error != nil {
return 0, fmt.Errorf("failed to query blacklist: %w", result.Error)
}
b.BlacklistCache = make(map[string]bool, len(databaseDomains))
for _, domain := range databaseDomains {
b.BlacklistCache[domain] = true
}
return len(b.BlacklistCache), nil
}
func (b *Blacklist) CountDomains() (int, error) {
var count int64
result := b.DBManager.Conn.Model(&database.Blacklist{}).Count(&count)
if result.Error != nil {
return 0, fmt.Errorf("failed to count domains: %w", result.Error)
}
return int(count), nil
}
func (b *Blacklist) GetAllowedAndBlocked() (allowed, blocked int, err error) {
type RequestStats struct {
Blocked bool
Count int
}
var stats []RequestStats
result := b.DBManager.Conn.Model(&database.RequestLog{}).
Select("blocked, COUNT(*) as count").
Group("blocked").
Scan(&stats)
if result.Error != nil {
return 0, 0, fmt.Errorf("failed to query request_logs: %w", result.Error)
}
for _, stat := range stats {
if stat.Blocked {
blocked = stat.Count
} else {
allowed = stat.Count
}
}
return allowed, blocked, nil
}
func (b *Blacklist) RemoveDomain(domain string) error {
result := b.DBManager.Conn.Where("domain = ?", domain).Delete(&database.Blacklist{})
if result.Error != nil {
return fmt.Errorf("failed to remove domain from blacklist: %w", result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("%s is already whitelisted", domain)
}
delete(b.BlacklistCache, domain)
return nil
}
func (b *Blacklist) UpdateSourceName(oldName, newName, url string) error {
if strings.TrimSpace(newName) == "" {
return fmt.Errorf("new name cannot be empty")
}
if oldName == newName {
return fmt.Errorf("new name is the same as the old name")
}
result := b.DBManager.Conn.Model(&database.Source{}).
Where("name = ? AND url = ?", oldName, url).
Update("name", newName)
if result.Error != nil {
return fmt.Errorf("failed to update source name: %w", result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("list with name '%s' not found", oldName)
}
for i, source := range b.BlocklistURL {
if source.Name == oldName {
b.BlocklistURL[i].Name = newName
}
}
log.Info("Updated blocklist name from '%s' to '%s'", oldName, newName)
return nil
}
func (b *Blacklist) NameExists(name, url string) bool {
for _, source := range b.BlocklistURL {
if source.Name == name && source.URL == url {
return true
}
}
return false
}
func (b *Blacklist) URLExists(url string) bool {
for _, source := range b.BlocklistURL {
if source.URL == url {
return true
}
}
return false
}
func (b *Blacklist) IsBlacklisted(domain string) bool {
return b.BlacklistCache[domain]
}
func (b *Blacklist) LoadPaginatedBlacklist(page, pageSize int, search string) ([]string, int, error) {
searchPattern := "%" + search + "%"
offset := (page - 1) * pageSize
var blacklistEntries []database.Blacklist
result := b.DBManager.Conn.Select("domain").
Where("domain LIKE ?", searchPattern).
Order("domain DESC").
Limit(pageSize).
Offset(offset).
Find(&blacklistEntries)
if result.Error != nil {
return nil, 0, fmt.Errorf("failed to query blacklist: %w", result.Error)
}
domains := make([]string, len(blacklistEntries))
for i, entry := range blacklistEntries {
domains[i] = entry.Domain
}
var total int64
countResult := b.DBManager.Conn.Model(&database.Blacklist{}).
Where("domain LIKE ?", searchPattern).
Count(&total)
if countResult.Error != nil {
return nil, 0, fmt.Errorf("failed to count domains: %w", countResult.Error)
}
return domains, int(total), nil
}
func (b *Blacklist) InitializeBlocklist(name, url string) error {
return b.DBManager.Conn.Transaction(func(tx *gorm.DB) error {
source := database.Source{
Name: name,
URL: url,
LastUpdated: time.Now().Unix(),
Active: true,
}
result := tx.Where(database.Source{Name: name, URL: url}).FirstOrCreate(&source)
if result.Error != nil {
return fmt.Errorf("failed to initialize new blocklist: %w", result.Error)
}
return nil
})
}
func (b *Blacklist) AddCustomDomains(domains []string) error {
return b.DBManager.Conn.Transaction(func(tx *gorm.DB) error {
var source database.Source
currentTime := time.Now().Unix()
err := tx.Where("name = ?", "Custom").First(&source).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
source = database.Source{
Name: "Custom",
LastUpdated: currentTime,
}
if err := tx.Create(&source).Error; err != nil {
return fmt.Errorf("failed to insert custom source: %w", err)
}
} else {
return fmt.Errorf("failed to get custom source ID: %w", err)
}
} else {
if err := tx.Model(&source).Update("last_updated", currentTime).Error; err != nil {
return fmt.Errorf("failed to update lastUpdated for custom source: %w", err)
}
}
blacklistEntries := make([]database.Blacklist, 0, len(domains))
for _, domain := range domains {
blacklistEntries = append(blacklistEntries, database.Blacklist{
Domain: domain,
SourceID: source.ID,
})
}
for _, entry := range blacklistEntries {
if err := tx.Where(database.Blacklist{Domain: entry.Domain, SourceID: entry.SourceID}).FirstOrCreate(&entry).Error; err != nil {
return fmt.Errorf("failed to add custom domain '%s': %w", entry.Domain, err)
}
b.BlacklistCache[entry.Domain] = true
}
return nil
})
}
func (b *Blacklist) RemoveCustomDomain(domain string) error {
b.DBManager.Mutex.Lock()
defer b.DBManager.Mutex.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
return b.DBManager.Conn.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
var source database.Source
err := tx.Where("name = ?", "Custom").First(&source).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("custom source not found")
}
return fmt.Errorf("failed to get custom source ID: %w", err)
}
result := tx.Where("domain = ? AND source_id = ?", domain, source.ID).Delete(&database.Blacklist{})
if result.Error != nil {
return fmt.Errorf("failed to delete domain '%s': %w", domain, result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("domain '%s' not found in custom blacklist", domain)
}
delete(b.BlacklistCache, domain)
currentTime := time.Now().Unix()
if err := tx.Model(&source).Update("last_updated", currentTime).Error; err != nil {
return fmt.Errorf("failed to update lastUpdated for custom source: %w", err)
}
return nil
})
}
func (b *Blacklist) GetAllListStatistics() ([]SourceStats, error) {
type SourceWithCount struct {
ID int `json:"id"`
Name string `json:"name"`
URL string `json:"url"`
LastUpdated int64 `json:"last_updated"`
Active bool `json:"active"`
BlockedCount int `json:"blocked_count"`
}
var results []SourceWithCount
result := b.DBManager.Conn.Table("sources s").
Select("s.id, s.name, s.url, s.last_updated, s.active, COALESCE(bc.blocked_count, 0) as blocked_count").
Joins("LEFT JOIN (SELECT source_id, COUNT(*) as blocked_count FROM blacklists GROUP BY source_id) bc ON s.id = bc.source_id").
Order("s.name, s.id").
Scan(&results)
if result.Error != nil {
return nil, fmt.Errorf("failed to query source statistics: %w", result.Error)
}
stats := make([]SourceStats, len(results))
for i, r := range results {
stats[i] = SourceStats{
Name: r.Name,
URL: r.URL,
BlockedCount: r.BlockedCount,
LastUpdated: r.LastUpdated,
Active: r.Active,
}
}
return stats, nil
}
func (b *Blacklist) GetListStatistics(listname string) (string, SourceStats, error) {
type SourceWithCount struct {
Name string `json:"name"`
URL string `json:"url"`
LastUpdated int64 `json:"last_updated"`
Active bool `json:"active"`
BlockedCount int `json:"blocked_count"`
}
var result SourceWithCount
err := b.DBManager.Conn.Table("sources s").
Select("s.name, s.url, s.last_updated, s.active, COALESCE(bc.blocked_count, 0) as blocked_count").
Joins("LEFT JOIN (SELECT source_id, COUNT(*) as blocked_count FROM blacklists GROUP BY source_id) bc ON s.id = bc.source_id").
Where("s.name = ?", listname).
First(&result).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", SourceStats{}, fmt.Errorf("list not found")
}
return "", SourceStats{}, fmt.Errorf("failed to query list statistics: %w", err)
}
stats := SourceStats{
URL: result.URL,
BlockedCount: result.BlockedCount,
LastUpdated: result.LastUpdated,
Active: result.Active,
}
return result.Name, stats, nil
}
func (b *Blacklist) GetDomainsForList(list string) ([]string, error) {
var blacklistEntries []database.Blacklist
result := b.DBManager.Conn.Select("blacklists.domain").
Joins("JOIN sources ON blacklists.source_id = sources.id").
Where("sources.name = ?", list).
Find(&blacklistEntries)
if result.Error != nil {
return nil, fmt.Errorf("failed to query domains for list: %w", result.Error)
}
domains := make([]string, len(blacklistEntries))
for i, entry := range blacklistEntries {
domains[i] = entry.Domain
}
return domains, nil
}
func (b *Blacklist) ToggleBlocklistStatus(name string) error {
var source database.Source
if err := b.DBManager.Conn.Where("name = ?", name).First(&source).Error; err != nil {
return fmt.Errorf("failed to find source %s: %w", name, err)
}
result := b.DBManager.Conn.Model(&source).Update("active", !source.Active)
if result.Error != nil {
return fmt.Errorf("failed to toggle status for %s: %w", name, result.Error)
}
return nil
}
func (b *Blacklist) RemoveSourceAndDomains(name, url string) error {
return b.DBManager.Conn.Transaction(func(tx *gorm.DB) error {
var source database.Source
err := tx.Where("name = ? AND url = ?", name, url).First(&source).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("source '%s' not found", name)
}
return fmt.Errorf("failed to get source ID: %w", err)
}
if err := tx.Where("source_id = ?", source.ID).Delete(&database.Blacklist{}).Error; err != nil {
return fmt.Errorf("failed to remove domains for source '%s': %w", name, err)
}
if err := tx.Delete(&source).Error; err != nil {
return fmt.Errorf("failed to remove source '%s': %w", name, err)
}
return nil
})
}
func (b *Blacklist) RemoveSourceAndDomainsWithCacheRefresh(name, url string) error {
if err := b.RemoveSourceAndDomains(name, url); err != nil {
return err
}
if _, err := b.PopulateBlocklistCache(); err != nil {
log.Warning("Failed to clear blocklist cache after removing source: %v", err)
}
log.Info("Removed all domains and source '%s'", name)
return nil
}
func (b *Blacklist) ScheduleAutomaticListUpdates() {
for {
next := time.Now().Add(24 * time.Hour).Truncate(24 * time.Hour)
log.Info("Next auto-update for lists scheduled for: %s", next.Format(time.DateTime))
time.Sleep(time.Until(next))
for _, source := range b.BlocklistURL {
if source.Name == "Custom" {
continue
}
log.Info("Checking for updates for blocklist %s from %s", source.Name, source.URL)
availableUpdate, err := b.CheckIfUpdateAvailable(source.URL, source.Name)
if err != nil {
log.Warning("Failed to check for updates for %s: %v", source.Name, err)
continue
}
if !availableUpdate.UpdateAvailable {
log.Info("No updates available for %s", source.Name)
continue
}
if err := b.RemoveSourceAndDomains(source.Name, source.URL); err != nil {
log.Warning("Failed to remove old domains for %s: %v", source.Name, err)
continue
}
if err := b.FetchAndLoadHosts(source.URL, source.Name); err != nil {
log.Warning("Failed to fetch and load hosts for %s: %v", source.Name, err)
continue
}
log.Info("Successfully updated %s with %d new domains", source.Name, len(availableUpdate.DiffAdded))
}
if _, err := b.PopulateBlocklistCache(); err != nil {
log.Warning("Failed to populate blocklist cache after auto-update: %v", err)
}
}
}
func (b *Blacklist) AddSource(name, url string) error {
if strings.TrimSpace(name) == "" || strings.TrimSpace(url) == "" {
return fmt.Errorf("name and url cannot be empty")
}
if err := b.DBManager.Conn.Clauses(
clause.OnConflict{
Columns: []clause.Column{{Name: "url"}},
DoUpdates: clause.AssignmentColumns([]string{"name", "last_updated", "active"}),
},
).Create(&database.Source{
Name: name,
URL: url,
LastUpdated: time.Now().Unix(),
Active: true,
}).Error; err != nil {
return fmt.Errorf("failed to insert source: %w", err)
}
found := false
for _, s := range b.BlocklistURL {
if s.Name == name && s.URL == url {
found = true
break
}
}
if !found {
b.BlocklistURL = append(b.BlocklistURL, BlocklistSource{Name: name, URL: url})
}
return nil
}
func (b *Blacklist) RemoveSourceByNameAndURL(name, url string) bool {
for i := len(b.BlocklistURL) - 1; i >= 0; i-- {
if b.BlocklistURL[i].Name == name && b.BlocklistURL[i].URL == url {
b.BlocklistURL = append(b.BlocklistURL[:i], b.BlocklistURL[i+1:]...)
return true
}
}
return false
}

View File

@@ -1,85 +0,0 @@
package lists
import (
"fmt"
"goaway/backend/dns/database"
"gorm.io/gorm/clause"
)
type Whitelist struct {
DBManager *database.DatabaseManager
Cache map[string]bool
}
func InitializeWhitelist(dbManager *database.DatabaseManager) (*Whitelist, error) {
w := &Whitelist{
DBManager: dbManager,
Cache: map[string]bool{},
}
return w, w.refreshCache()
}
func (w *Whitelist) AddDomain(domain string) error {
result := w.DBManager.Conn.Clauses(clause.OnConflict{DoNothing: true}).Create(&database.Whitelist{Domain: domain})
if result.Error != nil {
return fmt.Errorf("failed to add domain to whitelist: %w", result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("%s is already whitelisted", domain)
}
w.Cache[domain] = true
return nil
}
func (w *Whitelist) RemoveDomain(domain string) error {
result := w.DBManager.Conn.Delete(&database.Whitelist{}, "domain = ?", domain)
if result.Error != nil {
return fmt.Errorf("failed to remove domain from whitelist: %w", result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("%s does not exist", domain)
}
delete(w.Cache, domain)
return nil
}
func (w *Whitelist) refreshCache() error {
for k := range w.Cache {
delete(w.Cache, k)
}
domains, err := w.GetDomains()
if err != nil {
return fmt.Errorf("could not get whitelisted domains while refreshing cache, %v", err)
}
for domain := range domains {
w.Cache[domain] = true
}
return nil
}
func (w *Whitelist) GetDomains() (map[string]bool, error) {
var records []database.Whitelist
if err := w.DBManager.Conn.Find(&records).Error; err != nil {
return nil, fmt.Errorf("failed to query whitelist: %w", err)
}
domains := make(map[string]bool, len(records))
for _, rec := range records {
domains[rec.Domain] = true
}
return domains, nil
}
func (w *Whitelist) IsWhitelisted(domain string) bool {
return w.Cache[domain]
}

View File

@@ -32,7 +32,7 @@ func (s *DNSServer) getCachedRecord(cached interface{}) ([]dns.RR, bool) {
if cachedRecord.Key != "" {
log.Debug("Cached entry has expired, removing %s from cache", cachedRecord.Key)
s.Cache.Delete(cachedRecord.Key)
s.DomainCache.Delete(cachedRecord.Key)
}
return nil, false
@@ -43,14 +43,14 @@ func (s *DNSServer) RemoveCachedDomain(domain string) {
return
}
s.Cache.Range(func(key, value interface{}) bool {
s.DomainCache.Range(func(key, value interface{}) bool {
cachedRecord, ok := value.(CachedRecord)
if !ok || cachedRecord.Domain != domain+"." {
return true
}
log.Debug("Removing cached record for domain %s", domain)
s.Cache.Delete(key)
s.DomainCache.Delete(key)
return true
})
}
@@ -69,7 +69,7 @@ func (s *DNSServer) CacheRecord(cacheKey, domain string, ipAddresses []dns.RR, t
}
now := time.Now()
s.Cache.Store(cacheKey, CachedRecord{
s.DomainCache.Store(cacheKey, CachedRecord{
IPAddresses: ipAddresses,
ExpiresAt: now.Add(cacheTTL),
CachedAt: now,

View File

@@ -16,11 +16,11 @@ import (
)
const (
MaxDoHRequestSize = 4096
DoHTimeout = 20 * time.Second
DoHReadTimeout = 8 * time.Second
DoHWriteTimeout = 8 * time.Second
MB = 1 << 20
maxDoHRequestSize = 4096
doHTimeout = 20 * time.Second
doHReadTimeout = 8 * time.Second
doHWriteTimeout = 8 * time.Second
megabyte = 1 << 20
)
func (s *DNSServer) InitDoH(cert tls.Certificate) (*http.Server, error) {
@@ -29,7 +29,7 @@ func (s *DNSServer) InitDoH(cert tls.Certificate) (*http.Server, error) {
mux.HandleFunc("/health", s.handleHealthCheck)
server := &http.Server{
Addr: fmt.Sprintf("%s:%d", s.Config.DNS.Address, s.Config.DNS.DoHPort),
Addr: fmt.Sprintf("%s:%d", s.Config.DNS.Address, s.Config.DNS.Ports.DoH),
Handler: mux,
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
@@ -38,17 +38,17 @@ func (s *DNSServer) InitDoH(cert tls.Certificate) (*http.Server, error) {
PreferServerCipherSuites: true,
NextProtos: []string{"h2", "http/1.1"},
},
ReadTimeout: DoHReadTimeout,
WriteTimeout: DoHWriteTimeout,
ReadTimeout: doHReadTimeout,
WriteTimeout: doHWriteTimeout,
ReadHeaderTimeout: 5 * time.Second,
IdleTimeout: 60 * time.Second,
MaxHeaderBytes: 1 * MB,
MaxHeaderBytes: 1 * megabyte,
}
return server, nil
}
func (s *DNSServer) handleHealthCheck(w http.ResponseWriter, r *http.Request) {
func (s *DNSServer) handleHealthCheck(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte(`{"status":"healthy"}`))
@@ -60,14 +60,14 @@ func (s *DNSServer) handleHealthCheck(w http.ResponseWriter, r *http.Request) {
}
func (s *DNSServer) handleDoHRequest(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), DoHTimeout)
ctx, cancel := context.WithTimeout(r.Context(), doHTimeout)
defer cancel()
r = r.WithContext(ctx)
log.Debug("DoH request received: %s %s from %s", r.Method, r.URL.String(), r.RemoteAddr)
if r.ContentLength > MaxDoHRequestSize {
if r.ContentLength > maxDoHRequestSize {
log.Warning("DoH request too large: %d bytes from %s", r.ContentLength, r.RemoteAddr)
http.Error(w, "Request too large", http.StatusRequestEntityTooLarge)
return
@@ -79,9 +79,9 @@ func (s *DNSServer) handleDoHRequest(w http.ResponseWriter, r *http.Request) {
client model.Client
)
if xRealIP != "" {
go s.WSCom(communicationMessage{true, false, false, xRealIP})
go s.WSCom(communicationMessage{IP: xRealIP, Client: true, Upstream: false, DNS: false})
} else {
go s.WSCom(communicationMessage{true, false, false, clientIP})
go s.WSCom(communicationMessage{IP: clientIP, Client: true, Upstream: false, DNS: false})
}
var (
@@ -90,9 +90,9 @@ func (s *DNSServer) handleDoHRequest(w http.ResponseWriter, r *http.Request) {
)
switch r.Method {
case "GET":
case http.MethodGet:
dnsQuery, err = s.handleDoHGet(r)
case "POST":
case http.MethodPost:
dnsQuery, err = s.handleDoHPost(r)
default:
log.Warning("DoH request invalid method: %s from %s", r.Method, r.RemoteAddr)
@@ -124,7 +124,7 @@ func (s *DNSServer) handleDoHRequest(w http.ResponseWriter, r *http.Request) {
responseWriter := &DoHResponseWriter{
httpWriter: w,
remoteAddr: r.RemoteAddr,
DoHPort: s.Config.DNS.DoHPort,
DoHPort: s.Config.DNS.Ports.DoH,
}
if xRealIP != "" {
@@ -147,7 +147,7 @@ func (s *DNSServer) handleDoHRequest(w http.ResponseWriter, r *http.Request) {
logEntry := s.processQuery(req)
go s.WSCom(communicationMessage{false, false, true, clientIP})
go s.WSCom(communicationMessage{IP: clientIP, Client: false, Upstream: false, DNS: true})
select {
case s.logEntryChannel <- logEntry:
@@ -162,7 +162,7 @@ func (s *DNSServer) handleDoHGet(r *http.Request) ([]byte, error) {
return nil, fmt.Errorf("missing dns parameter")
}
if len(dnsParam) > MaxDoHRequestSize {
if len(dnsParam) > maxDoHRequestSize {
return nil, fmt.Errorf("dns parameter too long")
}
@@ -184,7 +184,7 @@ func (s *DNSServer) handleDoHPost(r *http.Request) ([]byte, error) {
return nil, fmt.Errorf("invalid content type: %s", contentType)
}
limitedReader := io.LimitReader(r.Body, MaxDoHRequestSize)
limitedReader := io.LimitReader(r.Body, maxDoHRequestSize)
dnsQuery, err := io.ReadAll(limitedReader)
if err != nil {
return nil, fmt.Errorf("failed to read request body: %w", err)
@@ -266,4 +266,4 @@ func (w *DoHResponseWriter) Header() http.Header { return nil }
func (w *DoHResponseWriter) Network() string { return "tcp" }
func (w *DoHResponseWriter) WriteHeader(statusCode int) {}
func (w *DoHResponseWriter) WriteHeader(_ int) {}

View File

@@ -10,7 +10,7 @@ import (
func (s *DNSServer) InitDoT(cert tls.Certificate) (*dns.Server, error) {
notifyReady := func() {
log.Info("Started DoT (dns-over-tls) server on port %d", s.Config.DNS.DoTPort)
log.Info("Started DoT (dns-over-tls) server on port %d", s.Config.DNS.Ports.DoT)
}
tlsConfig := &tls.Config{
@@ -18,7 +18,7 @@ func (s *DNSServer) InitDoT(cert tls.Certificate) (*dns.Server, error) {
MinVersion: tls.VersionTLS12,
}
server := &dns.Server{
Addr: fmt.Sprintf("%s:%d", s.Config.DNS.Address, s.Config.DNS.DoTPort),
Addr: fmt.Sprintf("%s:%d", s.Config.DNS.Address, s.Config.DNS.Ports.DoT),
Net: "tcp-tls",
Handler: s,
TLSConfig: tlsConfig,

View File

@@ -5,9 +5,8 @@ import (
"context"
"fmt"
arp "goaway/backend/dns"
"goaway/backend/dns/database"
model "goaway/backend/dns/server/models"
notification "goaway/backend/notifications"
"goaway/backend/notification"
"net"
"os"
"os/exec"
@@ -24,8 +23,13 @@ var (
blackholeIPv6 = net.ParseIP("::")
)
const (
IPv4Loopback = "127.0.0.1"
unknownHostname = "unknown"
)
func trimDomainDot(name string) string {
if len(name) > 0 && name[len(name)-1] == '.' {
if name != "" && name[len(name)-1] == '.' {
return name[:len(name)-1]
}
return name
@@ -37,15 +41,15 @@ func isPTRQuery(request *Request, domainName string) bool {
func (s *DNSServer) checkAndUpdatePauseStatus() {
if s.Config.DNS.Status.Paused &&
time.Since(s.Config.DNS.Status.PausedAt).Seconds() >= float64(s.Config.DNS.Status.PauseTime) {
s.Config.DNS.Status.PausedAt.After(s.Config.DNS.Status.PauseTime) {
s.Config.DNS.Status.Paused = false
}
}
func (s *DNSServer) shouldBlockQuery(domainName, fullName string) bool {
return !s.Config.DNS.Status.Paused &&
s.Blacklist.IsBlacklisted(domainName) &&
!s.Whitelist.IsWhitelisted(fullName)
s.BlacklistService.IsBlacklisted(domainName) &&
!s.WhitelistService.IsWhitelisted(fullName)
}
func (s *DNSServer) processQuery(request *Request) model.RequestLogEntry {
@@ -77,20 +81,6 @@ func (s *DNSServer) processQuery(request *Request) model.RequestLogEntry {
return s.handleStandardQuery(request)
}
func (s *DNSServer) GetVendor(mac string) (string, error) {
s.DBManager.Mutex.Lock()
defer s.DBManager.Mutex.Unlock()
return database.FindVendor(s.DBManager.Conn, mac)
}
func (s *DNSServer) SaveMacVendor(clientIP, mac, vendor string) {
s.DBManager.Mutex.Lock()
defer s.DBManager.Mutex.Unlock()
log.Debug("Saving new MAC address: %s %s", mac, vendor)
database.SaveMacEntry(s.DBManager.Conn, clientIP, mac, vendor)
}
func (s *DNSServer) reverseHostnameLookup(requestedHostname string) (string, bool) {
trimmed := strings.TrimSuffix(requestedHostname, ".")
@@ -114,24 +104,24 @@ func (s *DNSServer) getClientInfo(remoteAddr string) *model.Client {
hostname := s.resolveHostname(clientIP)
resultIP := clientIP
vendor, err := s.GetVendor(macAddress)
if macAddress != "unknown" {
vendor, err := s.MACService.FindVendor(macAddress)
if macAddress != unknownHostname {
if err != nil || vendor == "" {
log.Debug("Lookup vendor for mac %s", macAddress)
vendor, err = arp.GetMacVendor(macAddress)
if err == nil {
s.SaveMacVendor(clientIP, macAddress, vendor)
s.MACService.SaveMac(clientIP, macAddress, vendor)
} else {
log.Warning("Error while lookup mac address vendor: %v", err)
log.Warning("Was not able to find vendor for addr '%s'. %v", remoteAddr, err)
}
}
}
if clientIP == "127.0.0.1" || clientIP == "::1" || clientIP == "[" {
if clientIP == IPv4Loopback || clientIP == "::1" || clientIP == "[" {
localIP, err := getLocalIP()
if err != nil {
log.Warning("Failed to get local IP: %v", err)
localIP = "127.0.0.1"
localIP = IPv4Loopback
}
resultIP = localIP
@@ -145,7 +135,7 @@ func (s *DNSServer) getClientInfo(remoteAddr string) *model.Client {
client := model.Client{IP: resultIP, Name: hostname, MAC: macAddress}
s.clientCache.Store(clientIP, &client)
if client.Name != "unknown" {
if client.Name != unknownHostname {
s.hostnameCache.Store(client.Name, client.IP)
}
@@ -153,19 +143,27 @@ func (s *DNSServer) getClientInfo(remoteAddr string) *model.Client {
}
func (s *DNSServer) resolveHostname(clientIP string) string {
if hostname := s.reverseDNSLookup(clientIP); hostname != "unknown" {
ip := net.ParseIP(clientIP)
if ip.IsLoopback() {
hostname, err := os.Hostname()
if err == nil {
return hostname
}
}
if hostname := s.reverseDNSLookup(clientIP); hostname != unknownHostname {
return hostname
}
if hostname := s.avahiLookup(clientIP); hostname != "unknown" {
if hostname := s.avahiLookup(clientIP); hostname != unknownHostname {
return hostname
}
if hostname := s.sshBannerLookup(clientIP); hostname != "unknown" {
if hostname := s.sshBannerLookup(clientIP); hostname != unknownHostname {
return hostname
}
return "unknown"
return unknownHostname
}
func (s *DNSServer) avahiLookup(clientIP string) string {
@@ -175,8 +173,8 @@ func (s *DNSServer) avahiLookup(clientIP string) string {
cmd := exec.CommandContext(ctx, "avahi-resolve-address", clientIP)
output, err := cmd.Output()
if err == nil {
lines := strings.Split(string(output), "\n")
for _, line := range lines {
lines := strings.SplitSeq(string(output), "\n")
for line := range lines {
if strings.Contains(line, clientIP) {
parts := strings.Fields(line)
if len(parts) >= 2 {
@@ -190,7 +188,7 @@ func (s *DNSServer) avahiLookup(clientIP string) string {
}
}
return "unknown"
return unknownHostname
}
func (s *DNSServer) reverseDNSLookup(clientIP string) string {
@@ -206,13 +204,13 @@ func (s *DNSServer) reverseDNSLookup(clientIP string) string {
return hostname
}
}
return "unknown"
return unknownHostname
}
func (s *DNSServer) sshBannerLookup(clientIP string) string {
conn, err := net.DialTimeout("tcp", clientIP+":22", 1*time.Second)
if err != nil {
return "unknown"
return unknownHostname
}
defer func() {
_ = conn.Close()
@@ -222,13 +220,13 @@ func (s *DNSServer) sshBannerLookup(clientIP string) string {
if err != nil {
log.Warning("Failed to set deadline for SSH banner lookup: %v", err)
_ = conn.Close()
return "unknown"
return unknownHostname
}
reader := bufio.NewReader(conn)
banner, err := reader.ReadString('\n')
if err != nil {
return "unknown"
return unknownHostname
}
patterns := []*regexp.Regexp{
@@ -248,7 +246,7 @@ func (s *DNSServer) sshBannerLookup(clientIP string) string {
}
}
return "unknown"
return unknownHostname
}
func getLocalIP() (string, error) {
@@ -265,7 +263,7 @@ func getLocalIP() (string, error) {
}
}
return "127.0.0.1", fmt.Errorf("no non-loopback IPv4 address found")
return IPv4Loopback, fmt.Errorf("no non-loopback IPv4 address found")
}
func (s *DNSServer) handlePTRQuery(request *Request) model.RequestLogEntry {
@@ -277,7 +275,7 @@ func (s *DNSServer) handlePTRQuery(request *Request) model.RequestLogEntry {
}
ipStr := strings.Join(parts, ".")
if ipStr == "127.0.0.1" {
if ipStr == IPv4Loopback {
return s.respondWithLocalhost(request)
}
@@ -285,12 +283,12 @@ func (s *DNSServer) handlePTRQuery(request *Request) model.RequestLogEntry {
return s.forwardPTRQueryUpstream(request)
}
hostname := database.GetClientNameFromRequestLog(s.DBManager.Conn, ipStr)
if hostname == "unknown" {
hostname := s.RequestService.GetClientNameFromIP(ipStr)
if hostname == unknownHostname {
hostname = s.resolveHostname(ipStr)
}
if hostname != "unknown" {
if hostname != unknownHostname {
return s.respondWithHostnamePTR(request, hostname)
}
@@ -540,11 +538,11 @@ func (s *DNSServer) handleStandardQuery(request *Request) model.RequestLogEntry
err := request.ResponseWriter.WriteMsg(request.Msg)
if err != nil {
log.Warning("Could not write query response. client: [%s] with query [%v], err: %v", request.Client.IP, request.Msg.Answer, err.Error())
s.Notifications.CreateNotification(&notification.Notification{
Severity: notification.SeverityWarning,
Category: notification.CategoryDNS,
Text: fmt.Sprintf("Could not write query response. Client: %s, err: %v", request.Client.IP, err.Error()),
})
s.NotificationService.SendNotification(
notification.SeverityWarning,
notification.CategoryDNS,
fmt.Sprintf("Could not write query response. Client: %s, err: %v", request.Client.IP, err.Error()),
)
}
return model.RequestLogEntry{
@@ -563,7 +561,7 @@ func (s *DNSServer) handleStandardQuery(request *Request) model.RequestLogEntry
func (s *DNSServer) Resolve(req *Request) ([]dns.RR, bool, string) {
cacheKey := req.Question.Name + ":" + strconv.Itoa(int(req.Question.Qtype))
if cached, found := s.Cache.Load(cacheKey); found {
if cached, found := s.DomainCache.Load(cacheKey); found {
if ipAddresses, valid := s.getCachedRecord(cached); valid {
return ipAddresses, true, dns.RcodeToString[dns.RcodeSuccess]
}
@@ -593,7 +591,7 @@ func (s *DNSServer) resolveResolution(domain string) ([]dns.RR, uint32, string)
status = dns.RcodeToString[dns.RcodeSuccess]
)
ipFound, err := database.FetchResolution(s.DBManager.Conn, domain)
ipFound, err := s.ResolutionService.GetResolution(domain)
if err != nil {
log.Error("Database lookup error for domain (%s): %v", domain, err)
return nil, 0, dns.RcodeToString[dns.RcodeServerFailure]
@@ -648,14 +646,14 @@ func (s *DNSServer) QueryUpstream(req *Request) ([]dns.RR, uint32, string) {
errCh := make(chan error, 1)
go func() {
go s.WSCom(communicationMessage{false, true, false, ""})
go s.WSCom(communicationMessage{IP: "", Client: false, Upstream: true, DNS: false})
upstreamMsg := &dns.Msg{}
upstreamMsg.SetQuestion(req.Question.Name, req.Question.Qtype)
upstreamMsg.RecursionDesired = true
upstreamMsg.Id = dns.Id()
upstream := s.Config.DNS.PreferredUpstream
upstream := s.Config.DNS.Upstream.Preferred
if s.dnsClient.Net == "tcp-tls" {
host, port, err := net.SplitHostPort(upstream)
if err != nil {
@@ -681,7 +679,7 @@ func (s *DNSServer) QueryUpstream(req *Request) ([]dns.RR, uint32, string) {
select {
case in := <-resultCh:
go s.WSCom(communicationMessage{false, false, true, ""})
go s.WSCom(communicationMessage{IP: "", Client: false, Upstream: false, DNS: true})
status := dns.RcodeToString[dns.RcodeServerFailure]
if statusStr, ok := dns.RcodeToString[in.Rcode]; ok {
@@ -713,11 +711,11 @@ func (s *DNSServer) QueryUpstream(req *Request) ([]dns.RR, uint32, string) {
case err := <-errCh:
log.Warning("Resolution error for domain (%s): %v", req.Question.Name, err)
s.Notifications.CreateNotification(&notification.Notification{
Severity: notification.SeverityWarning,
Category: notification.CategoryDNS,
Text: fmt.Sprintf("Resolution error for domain (%s)", req.Question.Name),
})
s.NotificationService.SendNotification(
notification.SeverityWarning,
notification.CategoryDNS,
fmt.Sprintf("Resolution error for domain (%s)", req.Question.Name),
)
return nil, 0, dns.RcodeToString[dns.RcodeServerFailure]
case <-time.After(5 * time.Second):
@@ -733,8 +731,13 @@ func (s *DNSServer) LocalForwardLookup(req *Request) (model.RequestLogEntry, err
hostname += "."
}
queryType := req.Question.Qtype
if queryType == 0 {
queryType = dns.TypeA
}
dnsMsg := new(dns.Msg)
dnsMsg.SetQuestion(hostname, dns.TypeA)
dnsMsg.SetQuestion(hostname, queryType)
client := &dns.Client{Net: "udp"}
start := time.Now()
@@ -759,7 +762,7 @@ func (s *DNSServer) LocalForwardLookup(req *Request) (model.RequestLogEntry, err
}
}
if len(ips) == 0 {
if len(ips) == 0 && queryType == dns.TypeA {
return model.RequestLogEntry{}, fmt.Errorf("no A records found for hostname: %s", hostname)
}
@@ -772,7 +775,7 @@ func (s *DNSServer) LocalForwardLookup(req *Request) (model.RequestLogEntry, err
entry := model.RequestLogEntry{
Domain: req.Question.Name,
Status: dns.RcodeToString[in.Rcode],
QueryType: dns.TypeToString[dns.TypeA],
QueryType: dns.TypeToString[queryType],
IP: ips,
ResponseSizeBytes: in.Len(),
Timestamp: start,

View File

@@ -2,14 +2,13 @@ package server
import (
"encoding/json"
"goaway/backend/dns/database"
model "goaway/backend/dns/server/models"
"time"
"github.com/gorilla/websocket"
)
const BatchSize = 1000
const batchSize = 1000
func (s *DNSServer) ProcessLogEntries() {
var batch []model.RequestLogEntry
@@ -26,7 +25,7 @@ func (s *DNSServer) ProcessLogEntries() {
}
batch = append(batch, entry)
if len(batch) >= BatchSize {
if len(batch) >= batchSize {
s.saveBatch(batch)
batch = nil
}
@@ -40,14 +39,13 @@ func (s *DNSServer) ProcessLogEntries() {
}
func (s *DNSServer) saveBatch(entries []model.RequestLogEntry) {
s.DBManager.Mutex.Lock()
err := database.SaveRequestLog(s.DBManager.Conn, entries)
s.DBManager.Mutex.Unlock()
err := s.RequestService.SaveRequestLog(entries)
if err != nil {
log.Warning("Error while saving logs, reason: %v", err)
}
}
// Removes old log entries based on the configured retention period.
func (s *DNSServer) ClearOldEntries() {
const (
maxRetries = 10
@@ -56,10 +54,10 @@ func (s *DNSServer) ClearOldEntries() {
)
for {
requestThreshold := ((60 * 60) * 24) * s.Config.StatisticsRetention
requestThreshold := ((60 * 60) * 24) * s.Config.Misc.StatisticsRetention
log.Debug("Next cleanup running at %s", time.Now().Add(cleanupInterval).Format(time.DateTime))
time.Sleep(cleanupInterval)
database.DeleteRequestLogsTimebased(s.Blacklist.Vacuum, s.DBManager.Conn, requestThreshold, maxRetries, retryDelay)
s.RequestService.DeleteRequestLogsTimebased(s.BlacklistService.Vacuum, requestThreshold, maxRetries, retryDelay)
}
}

View File

@@ -3,18 +3,18 @@ package model
import "time"
type RequestLogEntry struct {
ID int64 `json:"id"`
Timestamp time.Time `json:"timestamp"`
ClientInfo *Client `json:"client"`
Domain string `json:"domain"`
Status string `json:"status"`
QueryType string `json:"queryType"`
Protocol Protocol `json:"protocol"`
IP []ResolvedIP `json:"ip"`
ID uint `json:"id"`
ResponseSizeBytes int `json:"responseSizeBytes"`
Timestamp time.Time `json:"timestamp"`
ResponseTime time.Duration `json:"responseTimeNS"`
Blocked bool `json:"blocked"`
Cached bool `json:"cached"`
ClientInfo *Client `json:"client"`
Protocol Protocol `json:"protocol"`
}
type Protocol string
@@ -39,8 +39,8 @@ type RequestLogIntervalSummary struct {
}
type ResponseSizeSummary struct {
StartUnix int64 `json:"-"`
Start time.Time `json:"start"`
StartUnix int64 `json:"-"`
TotalSizeBytes int `json:"total_size_bytes"`
AvgResponseSizeBytes int `json:"avg_response_size_bytes"`
MinResponseSizeBytes int `json:"min_response_size_bytes"`

View File

@@ -1,204 +0,0 @@
package prefetch
import (
"fmt"
"goaway/backend/dns/database"
"goaway/backend/dns/server"
"goaway/backend/logging"
"strconv"
"time"
"github.com/miekg/dns"
)
var log = logging.GetLogger()
type Manager struct {
dbManager *database.DatabaseManager
DNS *server.DNSServer
Domains map[string]PrefetchedDomain
}
type PrefetchedDomain struct {
Domain string `json:"domain"`
Refresh int `json:"refresh"`
Qtype int `json:"qtype"`
}
func (manager *Manager) LoadPrefetchedDomains() {
var prefetched []database.Prefetch
if err := manager.dbManager.Conn.Find(&prefetched).Error; err != nil {
log.Warning("Failed to query prefetch table: %v", err)
return
}
for _, p := range prefetched {
manager.Domains[p.Domain] = PrefetchedDomain{
Domain: p.Domain,
Refresh: p.Refresh,
Qtype: p.QType,
}
}
if len(manager.Domains) > 0 {
log.Info("Loaded %d prefetched domain(s)", len(manager.Domains))
}
}
func (manager *Manager) AddPrefetchedDomain(domain string, refresh, qtype int) error {
prefetch := database.Prefetch{
Domain: domain,
Refresh: refresh,
QType: qtype,
}
result := manager.dbManager.Conn.FirstOrCreate(&prefetch, database.Prefetch{Domain: domain})
if result.Error != nil {
return fmt.Errorf("failed to add new domain to prefetch table: %w", result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("%s already exists", domain)
}
manager.Domains[domain] = PrefetchedDomain{
Domain: domain,
Refresh: refresh,
Qtype: qtype,
}
log.Info("%s was added as a prefetched domain", domain)
return nil
}
func (manager *Manager) RemovePrefetchedDomain(domain string) error {
result := manager.dbManager.Conn.Delete(&database.Prefetch{}, "domain = ?", domain)
if result.Error != nil {
return fmt.Errorf("failed to remove %s from prefetch table: %w", domain, result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("%s does not exist in the database", domain)
}
delete(manager.Domains, domain)
log.Info("%s was removed as a prefetched domain", domain)
return nil
}
func New(dnsServer *server.DNSServer) Manager {
manager := Manager{
dbManager: dnsServer.DBManager,
DNS: dnsServer,
Domains: make(map[string]PrefetchedDomain),
}
manager.LoadPrefetchedDomains()
return manager
}
func (manager *Manager) Run() {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for range ticker.C {
manager.checkNewDomains()
manager.processExpiredEntries()
}
}
func (manager *Manager) checkNewDomains() {
for domain, prefetchDomain := range manager.Domains {
cacheKey := manager.buildCacheKey(domain, dns.Type(prefetchDomain.Qtype))
if _, exists := manager.DNS.Cache.Load(cacheKey); !exists {
log.Debug("Prefetching new/missing domain: %s", domain)
manager.prefetchDomain(prefetchDomain)
}
}
}
func (manager *Manager) processExpiredEntries() {
now := time.Now()
var expiredKeys []interface{}
var removeFromDomains []string
manager.DNS.Cache.Range(func(key, value interface{}) bool {
cachedDomain, ok := value.(server.CachedRecord)
if !ok {
log.Debug("Cache entry type assertion failed for key: %v", key)
return true
}
if manager.isExpired(cachedDomain, now) {
expiredKeys = append(expiredKeys, key)
if _, isPrefetched := manager.Domains[cachedDomain.Domain]; !isPrefetched {
removeFromDomains = append(removeFromDomains, cachedDomain.Domain)
log.Debug("Non-prefetch entry '%v' expired and will be removed", key)
} else {
log.Debug("Prefetch entry '%v' expired and will be refreshed", key)
}
}
return true
})
manager.handleExpiredKeys(expiredKeys)
manager.removeNonPrefetchDomains(removeFromDomains)
}
func (manager *Manager) isExpired(record server.CachedRecord, now time.Time) bool {
return now.After(record.ExpiresAt) || now.Equal(record.ExpiresAt)
}
func (manager *Manager) handleExpiredKeys(expiredKeys []interface{}) {
for _, key := range expiredKeys {
if value, exists := manager.DNS.Cache.Load(key); exists {
if cachedDomain, ok := value.(server.CachedRecord); ok {
manager.DNS.Cache.Delete(key)
manager.handleExpiredEntry(cachedDomain)
}
}
}
}
func (manager *Manager) removeNonPrefetchDomains(domains []string) {
for _, domain := range domains {
delete(manager.Domains, domain)
}
}
func (manager *Manager) prefetchDomain(prefetchDomain PrefetchedDomain) {
question := dns.Question{
Name: prefetchDomain.Domain,
Qtype: uint16(prefetchDomain.Qtype),
Qclass: 1,
}
request := &server.Request{
Msg: &dns.Msg{Question: []dns.Question{question}},
Question: question,
Sent: time.Now(),
Prefetch: true,
}
answers, ttl, _ := manager.DNS.QueryUpstream(request)
cacheKey := manager.buildCacheKey(question.Name, dns.Type(question.Qtype))
manager.DNS.CacheRecord(cacheKey, prefetchDomain.Domain, answers, ttl)
}
func (manager *Manager) handleExpiredEntry(record server.CachedRecord) {
domain := record.IPAddresses[0].Header().Name
prefetchDomain, exists := manager.Domains[domain]
if !exists {
log.Debug("%s not set to be prefetched", domain)
return
}
log.Debug("Prefetching expired domain: %s", domain)
manager.prefetchDomain(prefetchDomain)
}
func (manager *Manager) buildCacheKey(domain string, qtype dns.Type) string {
return domain + ":" + strconv.Itoa(int(qtype))
}

View File

@@ -3,22 +3,25 @@ package server
import (
"crypto/tls"
"encoding/json"
"fmt"
"goaway/backend/alert"
"goaway/backend/audit"
"goaway/backend/dns/database"
"goaway/backend/dns/lists"
"goaway/backend/blacklist"
model "goaway/backend/dns/server/models"
"goaway/backend/logging"
notification "goaway/backend/notifications"
"goaway/backend/mac"
"goaway/backend/notification"
"goaway/backend/request"
"goaway/backend/resolution"
"goaway/backend/settings"
"goaway/backend/user"
"goaway/backend/whitelist"
"net"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/miekg/dns"
"gorm.io/gorm"
)
var (
@@ -26,74 +29,66 @@ var (
)
type DNSServer struct {
DBConn *gorm.DB
dnsClient *dns.Client
Config *settings.Config
Blacklist *lists.Blacklist
Whitelist *lists.Whitelist
DBManager *database.DatabaseManager
logIntervalSeconds int
lastLogTime time.Time
Cache sync.Map
clientCache sync.Map
hostnameCache sync.Map
WebServer *gin.Engine
logEntryChannel chan model.RequestLogEntry
WSQueries *websocket.Conn
WSCommunication *websocket.Conn
hostnameCache sync.Map
clientCache sync.Map
DomainCache sync.Map
WSCommunicationLock sync.Mutex
dnsClient *dns.Client
Notifications *notification.Manager
Alerts *alert.Manager
Audits *audit.Manager
RequestService *request.Service
AuditService *audit.Service
UserService *user.Service
AlertService *alert.Service
MACService *mac.Service
ResolutionService *resolution.Service
NotificationService *notification.Service
BlacklistService *blacklist.Service
WhitelistService *whitelist.Service
}
type CachedRecord struct {
IPAddresses []dns.RR
ExpiresAt time.Time
CachedAt time.Time
OriginalTTL uint32
Key string
Domain string
IPAddresses []dns.RR
OriginalTTL uint32
}
type Request struct {
Sent time.Time
ResponseWriter dns.ResponseWriter
Msg *dns.Msg
Question dns.Question
Sent time.Time
Client *model.Client
Prefetch bool
Protocol model.Protocol
Question dns.Question
Prefetch bool
}
type communicationMessage struct {
IP string `json:"ip"`
Client bool `json:"client"`
Upstream bool `json:"upstream"`
DNS bool `json:"dns"`
Ip string `json:"ip"`
}
func NewDNSServer(config *settings.Config, dbManager *database.DatabaseManager, notificationsManager *notification.Manager, alertManager *alert.Manager, auditManager *audit.Manager, cert tls.Certificate) (*DNSServer, error) {
whitelistEntry, err := lists.InitializeWhitelist(dbManager)
if err != nil {
return nil, fmt.Errorf("failed to initialize whitelist: %w", err)
}
func NewDNSServer(config *settings.Config, dbconn *gorm.DB, cert tls.Certificate) (*DNSServer, error) {
var client dns.Client
if cert.Certificate != nil {
client = dns.Client{Net: "tcp-tls"}
}
server := &DNSServer{
Config: config,
Whitelist: whitelistEntry,
DBManager: dbManager,
logIntervalSeconds: 1,
lastLogTime: time.Now(),
logEntryChannel: make(chan model.RequestLogEntry, 1000),
dnsClient: &client,
Notifications: notificationsManager,
Alerts: alertManager,
Audits: auditManager,
Config: config,
DBConn: dbconn,
logEntryChannel: make(chan model.RequestLogEntry, 1000),
dnsClient: &client,
DomainCache: sync.Map{},
}
return server, nil
@@ -101,7 +96,7 @@ func NewDNSServer(config *settings.Config, dbManager *database.DatabaseManager,
func (s *DNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) != 1 {
log.Warning("Query container more than one question, ignoring!")
log.Warning("Query contains more than one question, ignoring!")
r.SetRcode(r, dns.RcodeFormatError)
_ = w.WriteMsg(r)
return
@@ -114,7 +109,7 @@ func (s *DNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
Client: true,
Upstream: false,
DNS: false,
Ip: client.IP,
IP: client.IP,
})
entry := s.processQuery(&Request{
@@ -131,7 +126,7 @@ func (s *DNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
Client: false,
Upstream: false,
DNS: true,
Ip: client.IP,
IP: client.IP,
})
s.logEntryChannel <- entry
}
@@ -154,27 +149,12 @@ func (s *DNSServer) detectProtocol(w dns.ResponseWriter) model.Protocol {
}
func (s *DNSServer) PopulateHostnameCache() error {
type Result struct {
ClientIP string
ClientName string
}
var results []Result
if err := s.DBManager.Conn.
Model(&database.RequestLog{}).
Select("DISTINCT client_ip, client_name").
Where("client_name IS NOT NULL AND client_name != ?", "unknown").
Find(&results).Error; err != nil {
return fmt.Errorf("failed to fetch hostnames: %w", err)
}
for _, r := range results {
if _, exists := s.hostnameCache.Load(r.ClientName); !exists {
s.hostnameCache.Store(r.ClientName, r.ClientIP)
}
uniqueClients := s.RequestService.GetUniqueClientNameAndIP()
for _, client := range uniqueClients {
_, _ = s.hostnameCache.LoadOrStore(client.IP, client.Name)
}
log.Debug("Populated hostname cache with %d client(s)", len(uniqueClients))
return nil
}
@@ -186,12 +166,7 @@ func (s *DNSServer) WSCom(message communicationMessage) {
s.WSCommunicationLock.Lock()
defer s.WSCommunicationLock.Unlock()
if err := s.WSCommunication.WriteControl(
websocket.PingMessage,
nil,
time.Now().Add(2*time.Second),
); err != nil {
log.Debug("Websocket connection not alive, skipping message: %v", err)
if s.WSCommunication == nil {
return
}
@@ -201,11 +176,13 @@ func (s *DNSServer) WSCom(message communicationMessage) {
return
}
if err := s.WSCommunication.SetWriteDeadline(time.Now().Add(2 * time.Second)); err != nil {
if err := s.WSCommunication.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil {
log.Warning("Failed to set websocket write deadline: %v", err)
return
}
if err := s.WSCommunication.WriteMessage(websocket.TextMessage, entryWSJson); err != nil {
log.Debug("Failed to write websocket message: %v", err)
s.WSCommunication = nil
}
}

View File

@@ -0,0 +1,76 @@
package jobs
import (
arp "goaway/backend/dns"
"goaway/backend/logging"
"goaway/backend/services"
)
var log = logging.GetLogger()
type BackgroundJobs struct {
registry *services.ServiceRegistry
}
func NewBackgroundJobs(registry *services.ServiceRegistry) *BackgroundJobs {
return &BackgroundJobs{
registry: registry,
}
}
func (b *BackgroundJobs) Start(readyChan <-chan struct{}) {
b.startHostnameCachePopulation()
b.cleanVendorResponseCache(readyChan)
b.startARPProcessing(readyChan)
b.startScheduledUpdates(readyChan)
b.startCacheCleanup(readyChan)
b.startPrefetcher(readyChan)
}
func (b *BackgroundJobs) startHostnameCachePopulation() {
if err := b.registry.Context.DNSServer.PopulateHostnameCache(); err != nil {
log.Warning("Unable to populate hostname cache: %s", err)
}
}
func (b *BackgroundJobs) startARPProcessing(readyChan <-chan struct{}) {
go func() {
<-readyChan
log.Debug("Starting ARP table processing...")
arp.ProcessARPTable()
}()
}
func (b *BackgroundJobs) cleanVendorResponseCache(readyChan <-chan struct{}) {
go func() {
<-readyChan
log.Debug("Starting vendor response table processing...")
arp.CleanVendorResponseCache()
}()
}
func (b *BackgroundJobs) startScheduledUpdates(readyChan <-chan struct{}) {
go func() {
<-readyChan
if b.registry.Context.Config.Misc.ScheduledBlacklistUpdates {
log.Debug("Starting scheduler for automatic list updates...")
b.registry.BlacklistService.ScheduleAutomaticListUpdates()
}
}()
}
func (b *BackgroundJobs) startCacheCleanup(readyChan <-chan struct{}) {
go func() {
<-readyChan
log.Debug("Starting cache cleanup routine...")
b.registry.Context.DNSServer.ClearOldEntries()
}()
}
func (b *BackgroundJobs) startPrefetcher(readyChan <-chan struct{}) {
go func() {
<-readyChan
log.Debug("Starting prefetcher...")
b.registry.PrefetchService.Run()
}()
}

View File

@@ -0,0 +1,61 @@
package lifecycle
import (
"goaway/backend/jobs"
"goaway/backend/logging"
"goaway/backend/services"
"os"
"os/signal"
"syscall"
)
var log = logging.GetLogger()
// Coordinates startup, shutdown, and signal handling
type Manager struct {
services *services.ServiceRegistry
backgroundJobs *jobs.BackgroundJobs
signalChan chan os.Signal
}
func NewManager(registry *services.ServiceRegistry) *Manager {
return &Manager{
services: registry,
signalChan: make(chan os.Signal, 1),
}
}
func (m *Manager) Run() error {
if err := m.services.Initialize(); err != nil {
return err
}
m.backgroundJobs = jobs.NewBackgroundJobs(m.services)
signal.Notify(m.signalChan, syscall.SIGINT, syscall.SIGTERM)
m.services.StartAll()
m.backgroundJobs.Start(m.services.ReadyChannel())
go m.services.WaitGroup().Wait()
return m.waitForTermination()
}
func (m *Manager) waitForTermination() error {
select {
case err := <-m.services.ErrorChannel():
log.Error("%s server failed: %s", err.Service, err.Err)
log.Fatal("Server failure detected. Exiting.")
return err.Err
case <-m.signalChan:
log.Info("Received interrupt. Shutting down.")
m.shutdown()
return nil
}
}
func (m *Manager) shutdown() {
// TODO: Add graceful shutdown logic
os.Exit(0)
}

View File

@@ -9,11 +9,11 @@ import (
)
const (
ColorReset = "\033[0m"
ColorGray = "\033[90m"
ColorWhite = "\033[97m"
ColorYellow = "\033[33m"
ColorRed = "\033[31m"
colorReset = "\033[0m"
colorGray = "\033[90m"
colorWhite = "\033[97m"
colorYellow = "\033[33m"
colorRed = "\033[31m"
)
type LogLevel int
@@ -27,12 +27,12 @@ const (
)
type Logger struct {
logger *log.Logger
JSONLoggerInstance *slog.Logger
logLevel LogLevel
LoggingEnabled bool
Ansi bool
logger *log.Logger
JSON bool
JSONLoggerInstance *slog.Logger
}
var (
@@ -66,7 +66,7 @@ func (l *Logger) SetLevel(level LogLevel) {
}
}
func (l *Logger) SetJson(json bool) {
func (l *Logger) SetJSON(json bool) {
l.JSON = json
}
@@ -74,10 +74,10 @@ func (l *Logger) SetAnsi(ansi bool) {
l.Ansi = ansi
}
func (l *Logger) log(level string, color string, message string, msgLevel LogLevel) {
func (l *Logger) log(level, color, message string, msgLevel LogLevel) {
if !l.JSON {
if l.Ansi {
l.logger.Printf("%s%s%s%s", color, level, message, ColorReset)
l.logger.Printf("%s%s%s%s", color, level, message, colorReset)
} else {
l.logger.Printf("%s%s", level, message)
}
@@ -91,6 +91,8 @@ func (l *Logger) log(level string, color string, message string, msgLevel LogLev
l.JSONLoggerInstance.Warn(message)
case ERROR:
l.JSONLoggerInstance.Error(message)
default:
l.JSONLoggerInstance.Info(message)
}
}
}
@@ -105,9 +107,9 @@ func (l *Logger) Debug(format string, args ...interface{}) {
}
if len(args) > 0 {
message := fmt.Sprintf(format, args...)
l.log("[DEBUG] ", ColorGray, message, DEBUG)
l.log("[DEBUG] ", colorGray, message, DEBUG)
} else {
l.log("[DEBUG] ", ColorGray, format, DEBUG)
l.log("[DEBUG] ", colorGray, format, DEBUG)
}
}
@@ -117,9 +119,9 @@ func (l *Logger) Info(format string, args ...interface{}) {
}
if len(args) > 0 {
message := fmt.Sprintf(format, args...)
l.log("[INFO] ", ColorWhite, message, INFO)
l.log("[INFO] ", colorWhite, message, INFO)
} else {
l.log("[INFO] ", ColorWhite, format, INFO)
l.log("[INFO] ", colorWhite, format, INFO)
}
}
@@ -129,9 +131,9 @@ func (l *Logger) Warning(format string, args ...interface{}) {
}
if len(args) > 0 {
message := fmt.Sprintf(format, args...)
l.log("[WARN] ", ColorYellow, message, WARNING)
l.log("[WARN] ", colorYellow, message, WARNING)
} else {
l.log("[WARN] ", ColorYellow, format, WARNING)
l.log("[WARN] ", colorYellow, format, WARNING)
}
}
@@ -141,9 +143,9 @@ func (l *Logger) Error(format string, args ...interface{}) {
}
if len(args) > 0 {
message := fmt.Sprintf(format, args...)
l.log("[ERROR] ", ColorRed, message, ERROR)
l.log("[ERROR] ", colorRed, message, ERROR)
} else {
l.log("[ERROR] ", ColorRed, format, ERROR)
l.log("[ERROR] ", colorRed, format, ERROR)
}
}
@@ -153,39 +155,9 @@ func (l *Logger) Fatal(format string, args ...interface{}) {
}
if len(args) > 0 {
message := fmt.Sprintf(format, args...)
l.log("[FATAL] ", ColorRed, message, FATAL)
l.log("[FATAL] ", colorRed, message, FATAL)
} else {
l.log("[FATAL] ", ColorRed, format, FATAL)
l.log("[FATAL] ", colorRed, format, FATAL)
}
os.Exit(1)
}
func FromString(logLevel string) LogLevel {
switch logLevel {
case "DEBUG":
return 0
case "INFO":
return 1
case "WARNING":
return 2
case "ERROR":
return 3
default:
return 1
}
}
func (l LogLevel) String() string {
switch l {
case DEBUG:
return "DEBUG"
case INFO:
return "INFO"
case WARNING:
return "WARNING"
case ERROR:
return "ERROR"
default:
return "UNKNOWN"
}
}

51
backend/mac/repository.go Normal file
View File

@@ -0,0 +1,51 @@
package mac
import (
"errors"
"fmt"
"goaway/backend/database"
"gorm.io/gorm"
)
type Repository interface {
FindVendor(mac string) (string, error)
SaveMac(clientIP, mac, vendor string) error
}
type repository struct {
db *gorm.DB
}
func NewRepository(db *gorm.DB) Repository {
return &repository{db: db}
}
func (r *repository) FindVendor(mac string) (string, error) {
var query database.MacAddress
tx := r.db.Find(&query, "mac = ?", mac)
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
return "", nil
}
if tx.Error != nil {
return "", tx.Error
}
return query.Vendor, nil
}
func (r *repository) SaveMac(clientIP, mac, vendor string) error {
entry := database.MacAddress{
MAC: mac,
IP: clientIP,
Vendor: vendor,
}
tx := r.db.Save(&entry)
if tx.Error != nil {
return fmt.Errorf("unable to save new MAC entry %v", tx.Error)
}
return nil
}

24
backend/mac/service.go Normal file
View File

@@ -0,0 +1,24 @@
package mac
import "goaway/backend/logging"
type Service struct {
repository Repository
}
var log = logging.GetLogger()
func NewService(repo Repository) *Service {
return &Service{repository: repo}
}
func (s *Service) FindVendor(mac string) (string, error) {
return s.repository.FindVendor(mac)
}
func (s *Service) SaveMac(clientIP, mac, vendor string) {
err := s.repository.SaveMac(clientIP, mac, vendor)
if err != nil {
log.Warning("Could not save MAC address, %v", err)
}
}

View File

@@ -0,0 +1,53 @@
package notification
import (
"goaway/backend/database"
"gorm.io/gorm"
)
type Repository interface {
CreateNotification(newNotification *database.Notification) error
GetNotifications() ([]database.Notification, error)
MarkNotificationsAsRead(notificationIDs []int) error
}
type repository struct {
db *gorm.DB
}
func NewRepository(db *gorm.DB) Repository {
return &repository{db: db}
}
func (r *repository) CreateNotification(newNotification *database.Notification) error {
tx := r.db.Create(&newNotification)
return tx.Error
}
func (r *repository) GetNotifications() ([]database.Notification, error) {
var notifications []database.Notification
result := r.db.Where("read = ?", false).Find(&notifications)
if result.Error != nil {
return nil, result.Error
}
return notifications, nil
}
func (r *repository) MarkNotificationsAsRead(notificationIDs []int) error {
if len(notificationIDs) == 0 {
return nil
}
result := r.db.Model(&database.Notification{}).
Where("id IN ?", notificationIDs).
Update("read", true)
if result.Error != nil {
return result.Error
}
return nil
}

View File

@@ -0,0 +1,62 @@
package notification
import (
"goaway/backend/database"
"goaway/backend/logging"
)
type Service struct {
repository Repository
}
type Severity string
type Category string
// Severity level of notification
// SeverityInfo: Server was upgraded, password changed...
// SeverityWarning: An error occurred on startup, database lock...
// SeverityError: Server cant start, requests cant be handled...
const (
SeverityInfo Severity = "info"
SeverityWarning Severity = "warning"
SeverityError Severity = "error"
)
// Categories to describe what area the notification covers
const (
CategoryServer Category = "server"
CategoryDNS Category = "dns"
CategoryAPI Category = "api"
)
var log = logging.GetLogger()
func NewService(repo Repository) *Service {
return &Service{repository: repo}
}
func (s *Service) SendNotification(severity Severity, category Category, text string) {
notification := &database.Notification{
Severity: string(severity),
Category: string(category),
Text: text,
Read: false,
}
err := s.repository.CreateNotification(notification)
if err != nil {
log.Warning("Could not send notification, %v", err)
return
}
log.Info("New notification created, severity: %s", severity)
}
func (s *Service) GetNotifications() ([]database.Notification, error) {
return s.repository.GetNotifications()
}
func (s *Service) MarkNotificationsAsRead(notificationIDs []int) error {
log.Info("Notifications have been marked as read")
return s.repository.MarkNotificationsAsRead(notificationIDs)
}

View File

@@ -1,89 +0,0 @@
package notification
import (
"goaway/backend/dns/database"
"goaway/backend/logging"
"time"
)
type Manager struct {
dbManager *database.DatabaseManager
}
type Severity string
type Category string
// Severity level of notification
// SeverityInfo: Server was upgraded, password changed...
// SeverityWarning: An error occurred on startup, database lock...
// SeverityError: Server cant start, requests cant be handled...
const (
SeverityInfo Severity = "info"
SeverityWarning Severity = "warning"
SeverityError Severity = "error"
)
// Categories to describe what area the notification covers
const (
CategoryServer Category = "server"
CategoryDNS Category = "dns"
CategoryAPI Category = "api"
)
type Notification struct {
Id int `json:"id"`
Severity Severity `json:"severity"`
Category Category `json:"category"`
Text string `json:"text"`
Read bool `json:"read"`
CreatedAt time.Time `json:"createdAt"`
}
var logger = logging.GetLogger()
func NewNotificationManager(dbManager *database.DatabaseManager) *Manager {
return &Manager{dbManager: dbManager}
}
func (nm *Manager) CreateNotification(newNotification *Notification) {
tx := nm.dbManager.Conn.Create(&database.Notification{
Severity: string(newNotification.Severity),
Category: string(newNotification.Category),
Text: newNotification.Text,
Read: false,
CreatedAt: time.Now(),
})
if tx.Error != nil {
logger.Warning("Unable to create new notification, error: %v", tx.Error)
return
}
logger.Debug("Created new notification, %+v", newNotification)
}
func (nm *Manager) ReadNotifications() ([]database.Notification, error) {
var notifications []database.Notification
result := nm.dbManager.Conn.Where("read = ?", true).Find(&notifications)
if result.Error != nil {
return nil, result.Error
}
return notifications, nil
}
func (nm *Manager) MarkNotificationsAsRead(notificationIDs []int) error {
if len(notificationIDs) == 0 {
return nil
}
result := nm.dbManager.Conn.Model(&database.Notification{}).
Where("id IN ?", notificationIDs).
Update("read", true)
if result.Error != nil {
return result.Error
}
return nil
}

View File

@@ -0,0 +1,53 @@
package prefetch
import (
"fmt"
"goaway/backend/database"
"gorm.io/gorm"
)
type Repository interface {
GetAll() ([]database.Prefetch, error)
Create(prefetch *database.Prefetch) error
Delete(domain string) error
}
type repository struct {
db *gorm.DB
}
func NewRepository(db *gorm.DB) Repository {
return &repository{db: db}
}
func (r *repository) Create(prefetch *database.Prefetch) error {
result := r.db.Create(prefetch)
if result.Error != nil {
return result.Error
}
return nil
}
func (r *repository) GetAll() ([]database.Prefetch, error) {
var prefetched []database.Prefetch
result := r.db.Model(&database.Prefetch{}).Find(&prefetched)
if result.Error != nil {
return nil, result.Error
}
return prefetched, nil
}
func (r *repository) Delete(domain string) error {
result := r.db.Delete(&database.Prefetch{}, "domain = ?", domain)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("%s does not exist in the database", domain)
}
return nil
}

182
backend/prefetch/service.go Normal file
View File

@@ -0,0 +1,182 @@
package prefetch
import (
"fmt"
"goaway/backend/database"
"goaway/backend/dns/server"
"goaway/backend/logging"
"strconv"
"time"
"github.com/miekg/dns"
)
type Service struct {
repository Repository
DNS *server.DNSServer
Domains map[string]database.Prefetch
}
var log = logging.GetLogger()
func NewService(repo Repository, dnsServer *server.DNSServer) *Service {
service := &Service{
repository: repo,
DNS: dnsServer,
Domains: make(map[string]database.Prefetch),
}
service.LoadPrefetchedDomains()
return service
}
func (s *Service) Run() {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for range ticker.C {
s.checkNewDomains()
s.processExpiredEntries()
}
}
func (s *Service) checkNewDomains() {
for domain, prefetchDomain := range s.Domains {
cacheKey := s.buildCacheKey(domain, dns.Type(prefetchDomain.QueryType))
if _, exists := s.DNS.DomainCache.Load(cacheKey); !exists {
log.Debug("Prefetching new/missing domain: %s", domain)
s.prefetchDomain(prefetchDomain)
}
}
}
func (s *Service) processExpiredEntries() {
now := time.Now()
var expiredKeys []interface{}
var removeFromDomains []string
s.DNS.DomainCache.Range(func(key, value interface{}) bool {
cachedDomain, ok := value.(server.CachedRecord)
if !ok {
log.Debug("Cache entry type assertion failed for key: %v", key)
return true
}
if s.isExpired(cachedDomain, now) {
expiredKeys = append(expiredKeys, key)
if _, isPrefetched := s.Domains[cachedDomain.Domain]; !isPrefetched {
removeFromDomains = append(removeFromDomains, cachedDomain.Domain)
log.Debug("Non-prefetch entry '%v' expired and will be removed", key)
} else {
log.Debug("Prefetch entry '%v' expired and will be refreshed", key)
}
}
return true
})
s.handleExpiredKeys(expiredKeys)
s.removeNonPrefetchDomains(removeFromDomains)
}
func (s *Service) isExpired(record server.CachedRecord, now time.Time) bool {
return now.After(record.ExpiresAt) || now.Equal(record.ExpiresAt)
}
func (s *Service) handleExpiredKeys(expiredKeys []interface{}) {
for _, key := range expiredKeys {
if value, exists := s.DNS.DomainCache.Load(key); exists {
if cachedDomain, ok := value.(server.CachedRecord); ok {
s.DNS.DomainCache.Delete(key)
s.handleExpiredEntry(cachedDomain)
}
}
}
}
func (s *Service) removeNonPrefetchDomains(domains []string) {
for _, domain := range domains {
delete(s.Domains, domain)
}
}
func (s *Service) prefetchDomain(prefetchDomain database.Prefetch) {
question := dns.Question{
Name: prefetchDomain.Domain,
Qtype: uint16(prefetchDomain.QueryType),
Qclass: 1,
}
request := &server.Request{
Msg: &dns.Msg{Question: []dns.Question{question}},
Question: question,
Sent: time.Now(),
Prefetch: true,
}
answers, ttl, _ := s.DNS.QueryUpstream(request)
cacheKey := s.buildCacheKey(question.Name, dns.Type(question.Qtype))
s.DNS.CacheRecord(cacheKey, prefetchDomain.Domain, answers, ttl)
}
func (s *Service) buildCacheKey(domain string, qtype dns.Type) string {
return domain + ":" + strconv.Itoa(int(qtype))
}
func (s *Service) handleExpiredEntry(record server.CachedRecord) {
domain := record.IPAddresses[0].Header().Name
prefetchDomain, exists := s.Domains[domain]
if !exists {
log.Debug("%s not set to be prefetched", domain)
return
}
log.Debug("Prefetching expired domain: %s", domain)
s.prefetchDomain(prefetchDomain)
}
func (s *Service) LoadPrefetchedDomains() {
prefetched, err := s.repository.GetAll()
if err != nil {
log.Error("failed to load prefetched domains: %v", err)
return
}
for _, p := range prefetched {
s.Domains[p.Domain] = p
}
if len(s.Domains) > 0 {
log.Info("Loaded %d prefetched domain(s)", len(s.Domains))
}
}
func (s *Service) AddPrefetchedDomain(domain string, refresh, qtype int) error {
prefetch := database.Prefetch{
Domain: domain,
Refresh: refresh,
QueryType: qtype,
}
err := s.repository.Create(&prefetch)
if err != nil {
return fmt.Errorf("failed to add new domain to prefetch table: %w", err)
}
s.Domains[domain] = prefetch
log.Info("%s was added as a prefetched domain", domain)
return nil
}
func (s *Service) RemovePrefetchedDomain(domain string) error {
err := s.repository.Delete(domain)
if err != nil {
return fmt.Errorf("failed to remove %s from prefetch table: %w", domain, err)
}
delete(s.Domains, domain)
log.Info("%s was removed as a prefetched domain", domain)
return nil
}

15
backend/request/model.go Normal file
View File

@@ -0,0 +1,15 @@
package request
import "time"
type Client struct {
LastSeen time.Time
Name string
Mac string
Vendor string
}
type ClientNameAndIP struct {
Name string
IP string
}

View File

@@ -1,24 +1,58 @@
package database
package request
import (
"context"
"database/sql"
"fmt"
"goaway/backend/api/models"
dbModel "goaway/backend/dns/database/models"
"goaway/backend/database"
model "goaway/backend/dns/server/models"
"goaway/backend/logging"
"strings"
"time"
"gorm.io/gorm"
)
var log = logging.GetLogger()
type Repository interface {
GetClientName(ip string) string
GetDistinctRequestIP() int
GetRequestSummaryByInterval(interval int) ([]model.RequestLogIntervalSummary, error)
GetResponseSizeSummaryByInterval(intervalMinutes int) ([]model.ResponseSizeSummary, error)
GetUniqueQueryTypes() ([]interface{}, error)
FetchQueries(q models.QueryParams) ([]model.RequestLogEntry, error)
GetUniqueClientNameAndIP() []database.RequestLog
FetchAllClients() (map[string]Client, error)
GetClientDetailsWithDomains(clientIP string) (ClientRequestDetails, string, map[string]int, error)
GetTopBlockedDomains(blockedRequests int) ([]map[string]interface{}, error)
GetTopClients() ([]map[string]interface{}, error)
CountQueries(search string) (int, error)
func GetClientNameFromRequestLog(db *gorm.DB, ip string) string {
SaveRequestLog(entries []model.RequestLogEntry) error
DeleteRequestLogsTimebased(vacuum vacuumFunc, requestThreshold, maxRetries int, retryDelay time.Duration) error
}
type ClientRequestDetails struct {
LastSeen string
MostQueriedDomain string
TotalRequests int
UniqueDomains int
BlockedRequests int
CachedRequests int
AvgResponseTimeMs float64
}
type repository struct {
db *gorm.DB
}
func NewRepository(db *gorm.DB) *repository {
return &repository{db: db}
}
func (r *repository) GetClientName(ip string) string {
var hostname string
err := db.Model(&RequestLog{}).
err := r.db.Model(&database.RequestLog{}).
Select("client_name").
Where("client_ip = ? AND client_name != ?", ip, "unknown").
Limit(1).
@@ -31,10 +65,10 @@ func GetClientNameFromRequestLog(db *gorm.DB, ip string) string {
return strings.TrimSuffix(hostname, ".")
}
func GetDistinctRequestIP(db *gorm.DB) int {
func (r *repository) GetDistinctRequestIP() int {
var count int64
err := db.Model(&RequestLog{}).
err := r.db.Model(&database.RequestLog{}).
Select("COUNT(DISTINCT client_ip)").
Scan(&count).Error
if err != nil {
@@ -44,33 +78,39 @@ func GetDistinctRequestIP(db *gorm.DB) int {
return int(count)
}
func GetRequestSummaryByInterval(interval int, db *gorm.DB) ([]model.RequestLogIntervalSummary, error) {
func (r *repository) GetRequestSummaryByInterval(interval int) ([]model.RequestLogIntervalSummary, error) {
minutes := interval * 60
var rawSummaries []model.RequestLogIntervalSummary
err := db.Table("request_logs").
Select(`
DATETIME((STRFTIME('%s', timestamp) / ?) * ?, 'unixepoch') AS interval_start,
type tempSummary struct {
IntervalStartUnix int64 `gorm:"column:interval_start_unix"`
BlockedCount int `gorm:"column:blocked_count"`
CachedCount int `gorm:"column:cached_count"`
AllowedCount int `gorm:"column:allowed_count"`
}
var rawSummaries []tempSummary
query := fmt.Sprintf(`
SELECT
((CAST(STRFTIME('%%s', timestamp) AS INTEGER) / %d) * %d) AS interval_start_unix,
SUM(blocked) AS blocked_count,
SUM(cached) AS cached_count,
SUM(NOT blocked AND NOT cached) AS allowed_count
`, minutes, minutes, minutes).
Where("timestamp >= DATETIME('now', '-1 day')").
Group("(STRFTIME('%s', timestamp) / ?)").
Order("interval_start").
Scan(&rawSummaries).Error
FROM request_logs
WHERE timestamp >= DATETIME('now', '-1 day')
GROUP BY (CAST(STRFTIME('%%s', timestamp) AS INTEGER) / %d)
ORDER BY interval_start_unix ASC
`, minutes, minutes, minutes)
err := r.db.Raw(query).Scan(&rawSummaries).Error
if err != nil {
return nil, err
}
summaries := make([]model.RequestLogIntervalSummary, len(rawSummaries))
for i := range rawSummaries {
t, err := time.Parse("2006-01-02 15:04:05", rawSummaries[i].IntervalStart)
if err != nil {
return nil, err
}
summaries[i] = model.RequestLogIntervalSummary{
IntervalStart: t.String(),
IntervalStart: time.Unix(rawSummaries[i].IntervalStartUnix, 0).Format("2006-01-02 15:04:05"),
BlockedCount: rawSummaries[i].BlockedCount,
CachedCount: rawSummaries[i].CachedCount,
AllowedCount: rawSummaries[i].AllowedCount,
@@ -80,27 +120,38 @@ func GetRequestSummaryByInterval(interval int, db *gorm.DB) ([]model.RequestLogI
return summaries, nil
}
func GetResponseSizeSummaryByInterval(intervalMinutes int, db *gorm.DB) ([]model.ResponseSizeSummary, error) {
func (r *repository) GetResponseSizeSummaryByInterval(intervalMinutes int) ([]model.ResponseSizeSummary, error) {
intervalSeconds := int64(intervalMinutes * 60)
twentyFourHoursAgo := time.Now().Add(-24 * time.Hour).Unix()
twentyFourHoursAgo := time.Now().Add(-24 * time.Hour)
var summaries []model.ResponseSizeSummary
query := `
WITH logs_unix AS (
SELECT
CAST(strftime('%s', timestamp) AS INTEGER) AS ts_unix,
response_size_bytes
FROM request_logs
WHERE timestamp >= ? AND response_size_bytes IS NOT NULL
)
SELECT
((strftime('%s', timestamp) / ?) * ?) AS start_unix,
(ts_unix / ?) * ? AS start_unix,
SUM(response_size_bytes) AS total_size_bytes,
ROUND(AVG(response_size_bytes)) AS avg_response_size_bytes,
MIN(response_size_bytes) AS min_response_size_bytes,
MAX(response_size_bytes) AS max_response_size_bytes
FROM request_logs
WHERE strftime('%s', timestamp) >= ? AND response_size_bytes IS NOT NULL
GROUP BY (strftime('%s', timestamp) / ?)
FROM logs_unix
GROUP BY ts_unix / ?
ORDER BY start_unix ASC
`
err := db.Raw(query, intervalSeconds, intervalSeconds, twentyFourHoursAgo, intervalSeconds).
Scan(&summaries).Error
err := r.db.Raw(query,
twentyFourHoursAgo,
intervalSeconds,
intervalSeconds,
intervalSeconds,
).Scan(&summaries).Error
if err != nil {
return nil, fmt.Errorf("failed to query response size summary: %w", err)
}
@@ -112,7 +163,7 @@ func GetResponseSizeSummaryByInterval(intervalMinutes int, db *gorm.DB) ([]model
return summaries, nil
}
func GetUniqueQueryTypes(db *gorm.DB) ([]interface{}, error) {
func (r *repository) GetUniqueQueryTypes() ([]interface{}, error) {
query := `
SELECT COUNT(*) AS count, query_type
FROM request_logs
@@ -120,7 +171,7 @@ func GetUniqueQueryTypes(db *gorm.DB) ([]interface{}, error) {
GROUP BY query_type
ORDER BY count DESC`
rows, err := db.Raw(query).Rows()
rows, err := r.db.Raw(query).Rows()
if err != nil {
return nil, err
}
@@ -134,8 +185,8 @@ func GetUniqueQueryTypes(db *gorm.DB) ([]interface{}, error) {
var queries []any
for rows.Next() {
query := struct {
Count int `json:"count"`
QueryType string `json:"queryType"`
Count int `json:"count"`
}{}
if err := rows.Scan(&query.Count, &query.QueryType); err != nil {
@@ -152,9 +203,9 @@ func GetUniqueQueryTypes(db *gorm.DB) ([]interface{}, error) {
return queries, nil
}
func FetchQueries(db *gorm.DB, q models.QueryParams) ([]model.RequestLogEntry, error) {
var logs []RequestLog
query := db.Model(&RequestLog{})
func (r *repository) FetchQueries(q models.QueryParams) ([]model.RequestLogEntry, error) {
var logs []database.RequestLog
query := r.db.Model(&database.RequestLog{})
if q.Column == "ip" {
query = query.Joins("LEFT JOIN request_log_ips ri ON request_logs.id = ri.request_log_id")
@@ -191,7 +242,7 @@ func FetchQueries(db *gorm.DB, q models.QueryParams) ([]model.RequestLogEntry, e
results := make([]model.RequestLogEntry, len(logs))
for i, log := range logs {
results[i] = model.RequestLogEntry{
ID: int64(log.ID),
ID: log.ID,
Timestamp: log.Timestamp,
Domain: log.Domain,
Blocked: log.Blocked,
@@ -206,14 +257,14 @@ func FetchQueries(db *gorm.DB, q models.QueryParams) ([]model.RequestLogEntry, e
}
for j, ip := range log.IPs {
results[i].IP[j] = model.ResolvedIP{IP: ip.IP, RType: ip.RType}
results[i].IP[j] = model.ResolvedIP{IP: ip.IP, RType: ip.RecordType}
}
}
return results, nil
}
func FetchAllClients(db *gorm.DB) (map[string]dbModel.Client, error) {
func (r *repository) FetchAllClients() (map[string]Client, error) {
var rows []struct {
ClientIP string `gorm:"column:client_ip"`
ClientName string `gorm:"column:client_name"`
@@ -222,11 +273,11 @@ func FetchAllClients(db *gorm.DB) (map[string]dbModel.Client, error) {
Vendor sql.NullString `gorm:"column:vendor"`
}
subquery := db.Table("request_logs").
subquery := r.db.Table("request_logs").
Select("client_ip, MAX(timestamp) as max_timestamp").
Group("client_ip")
if err := db.Table("request_logs r").
if err := r.db.Table("request_logs r").
Select("r.client_ip, r.client_name, r.timestamp, m.mac, m.vendor").
Joins("INNER JOIN (?) latest ON r.client_ip = latest.client_ip AND r.timestamp = latest.max_timestamp", subquery).
Joins("LEFT JOIN mac_addresses m ON r.client_ip = m.ip").
@@ -234,7 +285,7 @@ func FetchAllClients(db *gorm.DB) (map[string]dbModel.Client, error) {
return nil, err
}
uniqueClients := make(map[string]dbModel.Client, len(rows))
uniqueClients := make(map[string]Client, len(rows))
for _, row := range rows {
macStr := ""
vendorStr := ""
@@ -245,7 +296,7 @@ func FetchAllClients(db *gorm.DB) (map[string]dbModel.Client, error) {
vendorStr = row.Vendor.String
}
uniqueClients[row.ClientIP] = dbModel.Client{
uniqueClients[row.ClientIP] = Client{
Name: row.ClientName,
LastSeen: row.Timestamp,
Mac: macStr,
@@ -256,9 +307,20 @@ func FetchAllClients(db *gorm.DB) (map[string]dbModel.Client, error) {
return uniqueClients, nil
}
func GetClientDetailsWithDomains(db *gorm.DB, clientIP string) (dbModel.ClientRequestDetails, string, map[string]int, error) {
var crd dbModel.ClientRequestDetails
err := db.Table("request_logs").
func (r *repository) GetUniqueClientNameAndIP() []database.RequestLog {
var results []database.RequestLog
r.db.Model(&database.RequestLog{}).
Select("DISTINCT client_ip, client_name").
Where("client_name IS NOT NULL AND client_name != ?", "unknown").
Find(&results)
return results
}
func (r *repository) GetClientDetailsWithDomains(clientIP string) (ClientRequestDetails, string, map[string]int, error) {
var crd ClientRequestDetails
err := r.db.Table("request_logs").
Select(`
COUNT(*) as total_requests,
COUNT(DISTINCT domain) as unique_domains,
@@ -278,7 +340,7 @@ func GetClientDetailsWithDomains(db *gorm.DB, clientIP string) (dbModel.ClientRe
Count int `gorm:"column:query_count"`
}
err = db.Table("request_logs").
err = r.db.Table("request_logs").
Select("domain, COUNT(*) as query_count").
Where("client_ip = ?", clientIP).
Group("domain").
@@ -306,13 +368,13 @@ func GetClientDetailsWithDomains(db *gorm.DB, clientIP string) (dbModel.ClientRe
return crd, mostQueriedDomain, domainQueryCounts, nil
}
func GetTopBlockedDomains(db *gorm.DB, blockedRequests int) ([]map[string]interface{}, error) {
func (r *repository) GetTopBlockedDomains(blockedRequests int) ([]map[string]interface{}, error) {
var rows []struct {
Domain string `gorm:"column:domain"`
Hits int `gorm:"column:hits"`
}
if err := db.Table("request_logs").
if err := r.db.Table("request_logs").
Select("domain, COUNT(*) as hits").
Where("blocked = ?", true).
Group("domain").
@@ -337,9 +399,9 @@ func GetTopBlockedDomains(db *gorm.DB, blockedRequests int) ([]map[string]interf
return topBlockedDomains, nil
}
func GetTopClients(db *gorm.DB) ([]map[string]interface{}, error) {
func (r *repository) GetTopClients() ([]map[string]interface{}, error) {
var total int64
if err := db.Table("request_logs").Count(&total).Error; err != nil {
if err := r.db.Table("request_logs").Count(&total).Error; err != nil {
return nil, err
}
@@ -349,7 +411,7 @@ func GetTopClients(db *gorm.DB) ([]map[string]interface{}, error) {
Frequency float32 `gorm:"column:frequency"`
}
if err := db.Table("request_logs").
if err := r.db.Table("request_logs").
Select("? as frequency, client_ip, COUNT(*) as request_count", 0).
Group("client_ip").
Order("request_count DESC").
@@ -370,18 +432,18 @@ func GetTopClients(db *gorm.DB) ([]map[string]interface{}, error) {
return clients, nil
}
func CountQueries(db *gorm.DB, search string) (int, error) {
func (r *repository) CountQueries(search string) (int, error) {
var total int64
err := db.Table("request_logs").
err := r.db.Table("request_logs").
Where("domain LIKE ?", "%"+search+"%").
Count(&total).Error
return int(total), err
}
func SaveRequestLog(db *gorm.DB, entries []model.RequestLogEntry) error {
return db.Transaction(func(tx *gorm.DB) error {
func (r *repository) SaveRequestLog(entries []model.RequestLogEntry) error {
return r.db.Transaction(func(tx *gorm.DB) error {
for _, entry := range entries {
rl := RequestLog{
rl := database.RequestLog{
Timestamp: entry.Timestamp,
Domain: entry.Domain,
Blocked: entry.Blocked,
@@ -396,41 +458,40 @@ func SaveRequestLog(db *gorm.DB, entries []model.RequestLogEntry) error {
}
for _, resolvedIP := range entry.IP {
rl.IPs = append(rl.IPs, RequestLogIP{
IP: resolvedIP.IP,
RType: resolvedIP.RType,
rl.IPs = append(rl.IPs, database.RequestLogIP{
IP: resolvedIP.IP,
RecordType: resolvedIP.RType,
})
}
if err := tx.Create(&rl).Error; err != nil {
return fmt.Errorf("could not save request log: %v", err)
return fmt.Errorf("could not save request log: %w", err)
}
}
return nil
})
}
type vacuumFunc func()
func DeleteRequestLogsTimebased(vacuum vacuumFunc, db *gorm.DB, requestThreshold, maxRetries int, retryDelay time.Duration) {
func (r *repository) DeleteRequestLogsTimebased(vacuum vacuumFunc, requestThreshold, maxRetries int, retryDelay time.Duration) error {
cutoffTime := time.Now().Add(-time.Duration(requestThreshold) * time.Second)
for retryCount := range maxRetries {
result := db.Where("timestamp < ?", cutoffTime).Delete(&RequestLog{})
result := r.db.Where("timestamp < ?", cutoffTime).Delete(&database.RequestLog{})
if result.Error != nil {
if result.Error.Error() == "database is locked" {
log.Warning("Database is locked; retrying (%d/%d)", retryCount+1, maxRetries)
time.Sleep(retryDelay)
continue
}
log.Error("Failed to clear old entries: %s", result.Error)
break
return fmt.Errorf("failed to clear old entries: %w", result.Error)
}
if affected := result.RowsAffected; affected > 0 {
vacuum()
vacuum(context.Background())
log.Debug("Cleared %d old entries", affected)
}
break
return nil // Success
}
return fmt.Errorf("failed to delete after %d retries", maxRetries)
}

View File

@@ -0,0 +1,89 @@
package request
import (
"context"
"goaway/backend/api/models"
model "goaway/backend/dns/server/models"
"goaway/backend/logging"
"time"
)
type Service struct {
repository Repository
}
var log = logging.GetLogger()
func NewService(repo Repository) *Service {
return &Service{repository: repo}
}
func (s *Service) GetClientNameFromIP(ip string) string {
return s.repository.GetClientName(ip)
}
func (s *Service) GetDistinctRequestIP() int {
return s.repository.GetDistinctRequestIP()
}
func (s *Service) GetRequestSummaryByInterval(interval int) ([]model.RequestLogIntervalSummary, error) {
return s.repository.GetRequestSummaryByInterval(interval)
}
func (s *Service) GetResponseSizeSummaryByInterval(intervalMinutes int) ([]model.ResponseSizeSummary, error) {
return s.repository.GetResponseSizeSummaryByInterval(intervalMinutes)
}
func (s *Service) GetUniqueQueryTypes() ([]interface{}, error) {
return s.repository.GetUniqueQueryTypes()
}
func (s *Service) FetchQueries(q models.QueryParams) ([]model.RequestLogEntry, error) {
return s.repository.FetchQueries(q)
}
func (s *Service) FetchAllClients() (map[string]Client, error) {
return s.repository.FetchAllClients()
}
func (s *Service) GetUniqueClientNameAndIP() []ClientNameAndIP {
queryResult := s.repository.GetUniqueClientNameAndIP()
var uniqueClients []ClientNameAndIP
for _, client := range queryResult {
uniqueClients = append(uniqueClients, ClientNameAndIP{
Name: client.ClientName,
IP: client.ClientIP,
})
}
return uniqueClients
}
func (s *Service) GetClientDetailsWithDomains(clientIP string) (ClientRequestDetails, string, map[string]int, error) {
return s.repository.GetClientDetailsWithDomains(clientIP)
}
func (s *Service) GetTopBlockedDomains(blockedRequests int) ([]map[string]interface{}, error) {
return s.repository.GetTopBlockedDomains(blockedRequests)
}
func (s *Service) GetTopClients() ([]map[string]interface{}, error) {
return s.repository.GetTopClients()
}
func (s *Service) CountQueries(search string) (int, error) {
return s.repository.CountQueries(search)
}
func (s *Service) SaveRequestLog(entries []model.RequestLogEntry) error {
return s.repository.SaveRequestLog(entries)
}
type vacuumFunc func(ctx context.Context)
func (s *Service) DeleteRequestLogsTimebased(vacuum vacuumFunc, requestThreshold, maxRetries int, retryDelay time.Duration) {
if err := s.repository.DeleteRequestLogsTimebased(vacuum, requestThreshold, maxRetries, retryDelay); err != nil {
log.Warning("Error while deleting old request logs: %v", err)
}
}

View File

@@ -0,0 +1,77 @@
package resolution
import (
"errors"
"fmt"
"goaway/backend/database"
"strings"
"gorm.io/gorm"
)
type Repository interface {
CreateResolution(ip, domain string) error
FindResolution(domain string) (string, error)
FindResolutions() ([]database.Resolution, error)
DeleteResolution(ip, domain string) (int, error)
}
type repository struct {
db *gorm.DB
}
func NewRepository(db *gorm.DB) Repository {
return &repository{db: db}
}
func (r *repository) CreateResolution(ip, domain string) error {
res := database.Resolution{
Domain: domain,
IP: ip,
}
if err := r.db.Create(&res).Error; err != nil {
if strings.Contains(err.Error(), "UNIQUE") {
return errors.New("domain already exists, must be unique")
}
return fmt.Errorf("could not create new resolution: %w", err)
}
return nil
}
func (r *repository) FindResolution(domain string) (string, error) {
var res database.Resolution
r.db.Where("domain = ?", domain).Find(&res)
if res.IP != "" {
return res.IP, nil
}
parts := strings.Split(domain, ".")
for i := 1; i < len(parts); i++ {
wildcardDomain := "*." + strings.Join(parts[i:], ".")
if err := r.db.Where("domain = ?", wildcardDomain).Find(&res).Error; err == nil {
return res.IP, nil
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
return "", err
}
}
return "", nil
}
func (r *repository) FindResolutions() ([]database.Resolution, error) {
var resolutions []database.Resolution
if err := r.db.Find(&resolutions).Error; err != nil {
return nil, err
}
return resolutions, nil
}
func (r *repository) DeleteResolution(ip, domain string) (int, error) {
result := r.db.Where("domain = ? AND ip = ?", domain, ip).Delete(&database.Resolution{})
if result.Error != nil {
return 0, result.Error
}
return int(result.RowsAffected), nil
}

View File

@@ -0,0 +1,34 @@
package resolution
import (
"goaway/backend/database"
"goaway/backend/logging"
)
type Service struct {
repository Repository
}
var log = logging.GetLogger()
func NewService(repo Repository) *Service {
return &Service{repository: repo}
}
func (s *Service) CreateResolution(ip, domain string) error {
log.Debug("Creating new resolution '%s' -> '%s'", domain, ip)
return s.repository.CreateResolution(ip, domain)
}
func (s *Service) GetResolution(domain string) (string, error) {
log.Debug("Finding resolution for domain: %s", domain)
return s.repository.FindResolution(domain)
}
func (s *Service) GetResolutions() ([]database.Resolution, error) {
return s.repository.FindResolutions()
}
func (s *Service) DeleteResolution(ip, domain string) (int, error) {
return s.repository.DeleteResolution(ip, domain)
}

View File

@@ -0,0 +1,51 @@
package services
import (
"crypto/tls"
"fmt"
"goaway/backend/database"
"goaway/backend/dns/server"
"goaway/backend/settings"
"gorm.io/gorm"
)
type AppContext struct {
Config *settings.Config
DBConn *gorm.DB
Certificate tls.Certificate
DNSServer *server.DNSServer
}
func NewAppContext(config *settings.Config) (*AppContext, error) {
ctx := &AppContext{Config: config}
if err := ctx.initialize(); err != nil {
return nil, err
}
return ctx, nil
}
func (ctx *AppContext) initialize() error {
ctx.DBConn = database.Initialize()
cert, err := ctx.Config.GetCertificate()
if err != nil {
return fmt.Errorf("failed to get certificate: %w", err)
}
ctx.Certificate = cert
dnsServer, err := server.NewDNSServer(
ctx.Config,
ctx.DBConn,
cert,
)
if err != nil {
return fmt.Errorf("failed to initialize DNS server: %w", err)
}
ctx.DNSServer = dnsServer
go dnsServer.ProcessLogEntries()
return nil
}

View File

@@ -0,0 +1,235 @@
package services
import (
"embed"
"fmt"
"goaway/backend/api"
"goaway/backend/api/key"
"goaway/backend/blacklist"
"goaway/backend/logging"
"goaway/backend/notification"
"goaway/backend/prefetch"
"goaway/backend/request"
"goaway/backend/resolution"
"goaway/backend/user"
"goaway/backend/whitelist"
"net/http"
"sync"
"github.com/miekg/dns"
)
var log = logging.GetLogger()
// Manages all servers and services
type ServiceRegistry struct {
APIServer *api.API
errorChan chan ServiceError
dotServer *dns.Server
readyChan chan struct{}
content embed.FS
udpServer *dns.Server
dohServer *http.Server
tcpServer *dns.Server
Context *AppContext
version string
date string
commit string
wg sync.WaitGroup
ResolutionService *resolution.Service
RequestService *request.Service
PrefetchService *prefetch.Service
UserService *user.Service
KeyService *key.Service
NotificationService *notification.Service
BlacklistService *blacklist.Service
WhitelistService *whitelist.Service
}
type ServiceError struct {
Err error
Service string
}
func NewServiceRegistry(ctx *AppContext, version, commit, date string, content embed.FS) *ServiceRegistry {
return &ServiceRegistry{
Context: ctx,
version: version,
commit: commit,
date: date,
content: content,
readyChan: make(chan struct{}),
errorChan: make(chan ServiceError, 10),
}
}
func (r *ServiceRegistry) Initialize() error {
r.setupDNSServers()
if r.Context.Certificate.Certificate != nil {
if err := r.setupSecureServers(); err != nil {
return err
}
}
r.setupAPIServer()
return nil
}
func (r *ServiceRegistry) setupDNSServers() {
config := r.Context.Config
notifyReady := func() {
log.Info("Started DNS server on: %s:%d", config.DNS.Address, config.DNS.Ports.TCPUDP)
close(r.readyChan)
}
r.udpServer = &dns.Server{
Addr: fmt.Sprintf("%s:%d", config.DNS.Address, config.DNS.Ports.TCPUDP),
Net: "udp",
Handler: r.Context.DNSServer,
ReusePort: true,
UDPSize: config.DNS.UDPSize,
}
r.tcpServer = &dns.Server{
Addr: fmt.Sprintf("%s:%d", config.DNS.Address, config.DNS.Ports.TCPUDP),
Net: "tcp",
Handler: r.Context.DNSServer,
ReusePort: true,
UDPSize: config.DNS.UDPSize,
NotifyStartedFunc: notifyReady,
}
}
func (r *ServiceRegistry) setupSecureServers() error {
dotServer, err := r.Context.DNSServer.InitDoT(r.Context.Certificate)
if err != nil {
return fmt.Errorf("failed to initialize DoT server: %w", err)
}
r.dotServer = dotServer
dohServer, err := r.Context.DNSServer.InitDoH(r.Context.Certificate)
if err != nil {
return fmt.Errorf("failed to initialize DoH server: %w", err)
}
r.dohServer = dohServer
return nil
}
func (r *ServiceRegistry) setupAPIServer() {
r.APIServer = &api.API{
DNS: r.Context.DNSServer,
Authentication: r.Context.Config.API.Authentication,
Config: r.Context.Config,
DNSPort: r.Context.Config.DNS.Ports.TCPUDP,
Version: r.version,
Commit: r.commit,
Date: r.date,
DNSServer: r.Context.DNSServer,
DBConn: r.Context.DBConn,
WSQueries: r.Context.DNSServer.WSQueries,
WSCommunication: r.Context.DNSServer.WSCommunication,
ResolutionService: r.ResolutionService,
RequestService: r.RequestService,
PrefetchService: r.PrefetchService,
NotificationService: r.NotificationService,
UserService: r.UserService,
KeyService: r.KeyService,
BlacklistService: r.BlacklistService,
WhitelistService: r.WhitelistService,
}
}
func (r *ServiceRegistry) StartAll() {
r.startDNSServers()
if r.Context.Certificate.Certificate != nil {
r.startSecureServers()
}
r.startAPIServer()
}
func (r *ServiceRegistry) startDNSServers() {
r.wg.Add(1)
go func() {
defer r.wg.Done()
if err := r.udpServer.ListenAndServe(); err != nil {
r.errorChan <- ServiceError{Service: "UDP", Err: err}
}
}()
r.wg.Add(1)
go func() {
defer r.wg.Done()
if err := r.tcpServer.ListenAndServe(); err != nil {
r.errorChan <- ServiceError{Service: "TCP", Err: err}
}
}()
}
func (r *ServiceRegistry) startSecureServers() {
r.wg.Add(1)
go func() {
defer r.wg.Done()
if err := r.dotServer.ListenAndServe(); err != nil {
r.errorChan <- ServiceError{Service: "DoT", Err: err}
}
}()
r.wg.Add(1)
go func() {
defer r.wg.Done()
if serverIP, err := api.GetServerIP(); err == nil {
log.Info("DoH (dns-over-https) server running at https://%s:%d/dns-query",
serverIP, r.Context.Config.DNS.Ports.DoH)
} else {
log.Info("DoH (dns-over-https) server running on port :%d", r.Context.Config.DNS.Ports.DoH)
}
if err := r.dohServer.ListenAndServeTLS(
r.Context.Config.DNS.TLS.Cert,
r.Context.Config.DNS.TLS.Key,
); err != nil {
r.errorChan <- ServiceError{Service: "DoH", Err: err}
}
}()
}
func (r *ServiceRegistry) startAPIServer() {
r.wg.Add(1)
go func() {
defer r.wg.Done()
<-r.readyChan
errorChan := make(chan struct{}, 1)
go func() {
<-errorChan
r.errorChan <- ServiceError{Service: "API", Err: fmt.Errorf("API server stopped")}
}()
r.APIServer.Start(r.content, errorChan)
}()
}
func (r *ServiceRegistry) WaitGroup() *sync.WaitGroup {
return &r.wg
}
func (r *ServiceRegistry) ReadyChannel() <-chan struct{} {
return r.readyChan
}
func (r *ServiceRegistry) ErrorChannel() <-chan ServiceError {
return r.errorChan
}
func (r *ServiceRegistry) GetPrefetcher() *prefetch.Service {
return r.APIServer.PrefetchService
}

69
backend/settings/model.go Normal file
View File

@@ -0,0 +1,69 @@
package settings
import "time"
type Status struct {
PausedAt time.Time `json:"pausedAt"`
PauseTime time.Time `json:"pauseTime"`
Paused bool `json:"paused"`
}
type TLSConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"`
Cert string `yaml:"cert" json:"cert"`
Key string `yaml:"key" json:"key"`
}
type UpstreamConfig struct {
Preferred string `yaml:"preferred" json:"preferred"`
Fallback []string `yaml:"fallback" json:"fallback"`
}
type PortsConfig struct {
TCPUDP int `yaml:"udptcp" json:"udptcp"`
DoT int `yaml:"dot" json:"dot"`
DoH int `yaml:"doh" json:"doh"`
}
type DNSConfig struct {
Status Status `yaml:"-" json:"status"`
Address string `yaml:"address" json:"address"`
Gateway string `yaml:"gateway" json:"gateway"`
CacheTTL int `yaml:"cacheTTL" json:"cacheTTL"`
UDPSize int `yaml:"udpSize" json:"udpSize"`
TLS TLSConfig `yaml:"tls" json:"tls"`
Upstream UpstreamConfig `yaml:"upstream" json:"upstream"`
Ports PortsConfig `yaml:"ports" json:"ports"`
}
type RateLimitConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"`
MaxTries int `yaml:"maxTries" json:"maxTries"`
Window int `yaml:"window" json:"window"`
}
type APIConfig struct {
Port int `yaml:"port" json:"port"`
Authentication bool `yaml:"authentication" json:"authentication"`
RateLimit RateLimitConfig `yaml:"rateLimit" json:"rateLimit"`
}
type LoggingConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"`
Level int `yaml:"level" json:"level"`
}
type MiscConfig struct {
InAppUpdate bool `yaml:"inAppUpdate" json:"inAppUpdate"`
StatisticsRetention int `yaml:"statisticsRetention" json:"statisticsRetention"`
Dashboard bool `yaml:"dashboard" json:"dashboard"`
ScheduledBlacklistUpdates bool `yaml:"scheduledBlacklistUpdates" json:"scheduledBlacklistUpdates"`
}
type Config struct {
BinaryPath string `yaml:"-" json:"-"`
DNS DNSConfig `yaml:"dns" json:"dns"`
API APIConfig `yaml:"api" json:"api"`
Logging LoggingConfig `yaml:"logging" json:"logging"`
Misc MiscConfig `yaml:"misc" json:"misc"`
}

View File

@@ -3,62 +3,17 @@ package settings
import (
"crypto/tls"
"fmt"
"goaway/backend/api/ratelimit"
"goaway/backend/logging"
"net"
"os"
"path/filepath"
"strconv"
"time"
"gopkg.in/yaml.v3"
)
var log = logging.GetLogger()
type Status struct {
Paused bool `json:"paused"`
PausedAt time.Time `json:"pausedAt"`
PauseTime int `json:"pauseTime"`
}
type DNSConfig struct {
Address string `yaml:"address" json:"address"`
Port int `yaml:"port" json:"port"`
DoTPort int `yaml:"dotPort" json:"dotPort"`
DoHPort int `yaml:"dohPort" json:"dohPort"`
CacheTTL int `yaml:"cacheTTL" json:"cacheTTL"`
PreferredUpstream string `yaml:"preferredUpstream" json:"preferredUpstream"`
Gateway string `yaml:"gateway" json:"gateway"`
UpstreamDNS []string `yaml:"upstreamDNS" json:"upstreamDNS"`
UDPSize int `yaml:"udpSize" json:"udpSize"`
Status Status `yaml:"-" json:"status"`
TLSCertFile string `yaml:"tlsCertFile" json:"tlsCertFile"`
TLSKeyFile string `yaml:"tlsKeyFile" json:"tlsKeyFile"`
}
type APIConfig struct {
Port int `yaml:"port" json:"port"`
Authentication bool `yaml:"authentication" json:"authentication"`
RateLimiterConfig ratelimit.RateLimiterConfig `yaml:"rateLimit" json:"-"`
}
type Config struct {
DNS DNSConfig `yaml:"dns" json:"dns"`
API APIConfig `yaml:"api" json:"api"`
Dashboard bool `yaml:"dashboard" json:"-"`
ScheduledBlacklistUpdates bool `yaml:"scheduledBlacklistUpdates" json:"scheduledBlacklistUpdates"`
StatisticsRetention int `yaml:"statisticsRetention" json:"statisticsRetention"`
LoggingEnabled bool `yaml:"loggingEnabled" json:"loggingEnabled"`
LogLevel logging.LogLevel `yaml:"logLevel" json:"logLevel"`
InAppUpdate bool `yaml:"inAppUpdate" json:"inAppUpdate"`
// settings not visible in config file
BinaryPath string `yaml:"-" json:"-"`
}
func LoadSettings() (Config, error) {
var config Config
@@ -74,6 +29,7 @@ func LoadSettings() (Config, error) {
if err != nil {
return Config{}, err
}
return config, nil
}
data, err := os.ReadFile(path)
@@ -81,43 +37,10 @@ func LoadSettings() (Config, error) {
return Config{}, fmt.Errorf("could not read settings file: %w", err)
}
type configWithPtr struct {
DNS DNSConfig `yaml:"dns" json:"dns"`
API APIConfig `yaml:"api" json:"api"`
Dashboard *bool `yaml:"dashboard" json:"-"`
ScheduledBlacklistUpdates *bool `yaml:"scheduledBlacklistUpdates" json:"scheduledBlacklistUpdates"`
StatisticsRetention int `yaml:"statisticsRetention" json:"statisticsRetention"`
LoggingEnabled bool `yaml:"loggingEnabled" json:"loggingEnabled"`
LogLevel logging.LogLevel `yaml:"logLevel" json:"logLevel"`
InAppUpdate bool `yaml:"inAppUpdate" json:"inAppUpdate"`
}
var temp configWithPtr
if err := yaml.Unmarshal(data, &temp); err != nil {
if err := yaml.Unmarshal(data, &config); err != nil {
return Config{}, fmt.Errorf("invalid settings format: %w", err)
}
config.DNS = temp.DNS
config.API = temp.API
config.StatisticsRetention = temp.StatisticsRetention
config.LoggingEnabled = temp.LoggingEnabled
config.LogLevel = temp.LogLevel
config.InAppUpdate = temp.InAppUpdate
if temp.Dashboard == nil {
// true by default if the Dashboard field was not found in settings.yaml
config.Dashboard = true
} else {
config.Dashboard = *temp.Dashboard
}
if temp.ScheduledBlacklistUpdates == nil {
// false by default if the ScheduledBlacklistUpdates field was not found in settings.yaml
config.ScheduledBlacklistUpdates = false
} else {
config.ScheduledBlacklistUpdates = *temp.ScheduledBlacklistUpdates
}
binaryPath, err := os.Executable()
if err != nil {
log.Warning("Unable to find installed binary path, err: %v", err)
@@ -127,35 +50,84 @@ func LoadSettings() (Config, error) {
return config, nil
}
func (config *Config) Save() {
data, err := yaml.Marshal(config)
if err != nil {
log.Error("Could not parse settings %v", err)
return
}
if err := os.WriteFile("./config/settings.yaml", data, 0644); err != nil {
log.Error("Could not save settings %v", err)
}
}
func (config *Config) Update(updatedSettings Config) {
config.API.Port = updatedSettings.API.Port
config.API.Authentication = updatedSettings.API.Authentication
config.API.RateLimit = updatedSettings.API.RateLimit
config.DNS.Address = updatedSettings.DNS.Address
config.DNS.Ports = updatedSettings.DNS.Ports
config.DNS.UDPSize = updatedSettings.DNS.UDPSize
config.DNS.CacheTTL = updatedSettings.DNS.CacheTTL
config.DNS.TLS = updatedSettings.DNS.TLS
config.DNS.Upstream = updatedSettings.DNS.Upstream
config.Logging = updatedSettings.Logging
config.Misc = updatedSettings.Misc
log.ToggleLogging(config.Logging.Enabled)
log.SetLevel(logging.LogLevel(config.Logging.Level))
config.Save()
}
func createDefaultSettings(filePath string) (Config, error) {
defaultConfig := Config{
StatisticsRetention: 7,
LoggingEnabled: true,
LogLevel: logging.INFO,
InAppUpdate: false,
DNS: DNSConfig{
Address: "0.0.0.0",
Gateway: getDefaultGateway(),
CacheTTL: 3600,
UDPSize: 512,
TLS: TLSConfig{
Enabled: false,
Cert: "",
Key: "",
},
Upstream: UpstreamConfig{
Preferred: "8.8.8.8:53",
Fallback: []string{
"1.1.1.1:53",
},
},
Ports: PortsConfig{
TCPUDP: getEnvAsIntWithDefault("DNS_PORT", 53),
DoT: getEnvAsIntWithDefault("DOT_PORT", 853),
DoH: getEnvAsIntWithDefault("DOH_PORT", 443),
},
},
API: APIConfig{
Port: getEnvAsIntWithDefault("WEBSITE_PORT", 8080),
Authentication: true,
RateLimit: RateLimitConfig{
Enabled: true,
MaxTries: 5,
Window: 5,
},
},
Logging: LoggingConfig{
Enabled: true,
Level: int(logging.INFO),
},
Misc: MiscConfig{
InAppUpdate: false,
StatisticsRetention: 7,
Dashboard: true,
ScheduledBlacklistUpdates: true,
},
}
defaultConfig.DNS.Address = "0.0.0.0"
defaultConfig.DNS.Port = GetEnvAsIntWithDefault("DNS_PORT", 53)
defaultConfig.DNS.DoTPort = GetEnvAsIntWithDefault("DOT_PORT", 853)
defaultConfig.DNS.DoHPort = GetEnvAsIntWithDefault("DOH_PORT", 443)
defaultConfig.DNS.CacheTTL = 3600
defaultConfig.DNS.PreferredUpstream = "8.8.8.8:53"
defaultConfig.DNS.Gateway = GetDefaultGateway()
defaultConfig.DNS.UpstreamDNS = []string{
"1.1.1.1:53",
"8.8.8.8:53",
}
defaultConfig.DNS.UDPSize = 512
defaultConfig.DNS.TLSCertFile = ""
defaultConfig.DNS.TLSKeyFile = ""
defaultConfig.Dashboard = true
defaultConfig.ScheduledBlacklistUpdates = true
defaultConfig.API.Port = GetEnvAsIntWithDefault("WEBSITE_PORT", 8080)
defaultConfig.API.Authentication = true
defaultConfig.API.RateLimiterConfig = ratelimit.RateLimiterConfig{Enabled: true, MaxTries: 5, Window: 5}
data, err := yaml.Marshal(&defaultConfig)
if err != nil {
return Config{}, fmt.Errorf("failed to marshal default config: %w", err)
@@ -174,45 +146,7 @@ func createDefaultSettings(filePath string) (Config, error) {
return defaultConfig, nil
}
func (config *Config) Save() {
data, err := yaml.Marshal(config)
if err != nil {
log.Error("Could not parse settings %v", err)
return
}
if err := os.WriteFile("./config/settings.yaml", data, 0644); err != nil {
log.Error("Could not save settings %v", err)
}
}
func (config *Config) UpdateSettings(updatedSettings Config) {
config.API.Port = updatedSettings.API.Port
config.API.Authentication = updatedSettings.API.Authentication
config.DNS.Address = updatedSettings.DNS.Address
config.DNS.Port = updatedSettings.DNS.Port
config.DNS.DoTPort = updatedSettings.DNS.DoTPort
config.DNS.DoHPort = updatedSettings.DNS.DoHPort
config.DNS.UDPSize = updatedSettings.DNS.UDPSize
config.DNS.CacheTTL = updatedSettings.DNS.CacheTTL
config.DNS.TLSCertFile = updatedSettings.DNS.TLSCertFile
config.DNS.TLSKeyFile = updatedSettings.DNS.TLSKeyFile
config.LogLevel = updatedSettings.LogLevel
config.StatisticsRetention = updatedSettings.StatisticsRetention
config.LoggingEnabled = updatedSettings.LoggingEnabled
config.ScheduledBlacklistUpdates = updatedSettings.ScheduledBlacklistUpdates
config.InAppUpdate = updatedSettings.InAppUpdate
log.ToggleLogging(config.LoggingEnabled)
log.SetLevel(config.LogLevel)
config.Save()
}
func GetEnvAsIntWithDefault(envVariable string, defaultValue int) int {
func getEnvAsIntWithDefault(envVariable string, defaultValue int) int {
val, found := os.LookupEnv(envVariable)
if !found {
return defaultValue
@@ -227,10 +161,10 @@ func GetEnvAsIntWithDefault(envVariable string, defaultValue int) int {
}
func (config *Config) GetCertificate() (tls.Certificate, error) {
if config.DNS.TLSCertFile != "" && config.DNS.TLSKeyFile != "" {
cert, err := tls.LoadX509KeyPair(config.DNS.TLSCertFile, config.DNS.TLSKeyFile)
if config.DNS.TLS.Enabled && config.DNS.TLS.Cert != "" && config.DNS.TLS.Key != "" {
cert, err := tls.LoadX509KeyPair(config.DNS.TLS.Cert, config.DNS.TLS.Key)
if err != nil {
return tls.Certificate{}, fmt.Errorf("failed to load TLS certificate: %s", err)
return tls.Certificate{}, fmt.Errorf("failed to load TLS certificate: %w", err)
}
return cert, nil
@@ -239,7 +173,7 @@ func (config *Config) GetCertificate() (tls.Certificate, error) {
return tls.Certificate{}, nil
}
func GetDefaultGateway() string {
func getDefaultGateway() string {
conn, err := net.Dial("udp", "8.8.8.8:80")
if err != nil {
return "192.168.0.1:53"

View File

@@ -2,18 +2,19 @@ package setup
import (
"fmt"
"goaway/backend/logging"
"goaway/backend/settings"
"os"
"strconv"
"goaway/backend/logging"
"goaway/backend/settings"
"github.com/Masterminds/semver"
)
var log = logging.GetLogger()
type SetFlags struct {
DnsPort *int
DNSPort *int
DoTPort *int
DoHPort *int
WebserverPort *int
@@ -37,15 +38,15 @@ func UpdateConfig(config *settings.Config, flags *SetFlags) {
fmt.Println("Flag --log-level can't be greater than 3 or below 0.")
os.Exit(1)
}
if flags.DnsPort != nil || os.Getenv("DNS_PORT") != "" {
if flags.DNSPort != nil || os.Getenv("DNS_PORT") != "" {
if port, found := os.LookupEnv("DNS_PORT"); found {
dnsPort, err := strconv.Atoi(port)
if err != nil {
log.Fatal("Could not parse DNS_PORT environment variable")
}
config.DNS.Port = dnsPort
config.DNS.Ports.TCPUDP = dnsPort
} else {
config.DNS.Port = *flags.DnsPort
config.DNS.Ports.TCPUDP = *flags.DNSPort
}
}
if flags.DoTPort != nil || os.Getenv("DOT_PORT") != "" {
@@ -54,9 +55,9 @@ func UpdateConfig(config *settings.Config, flags *SetFlags) {
if err != nil {
log.Fatal("Could not parse DOT_PORT environment variable")
}
config.DNS.DoTPort = dotPort
config.DNS.Ports.DoT = dotPort
} else {
config.DNS.DoTPort = *flags.DoTPort
config.DNS.Ports.DoT = *flags.DoTPort
}
}
if flags.DoHPort != nil || os.Getenv("DOH_PORT") != "" {
@@ -65,9 +66,9 @@ func UpdateConfig(config *settings.Config, flags *SetFlags) {
if err != nil {
log.Fatal("Could not parse DOH_PORT environment variable")
}
config.DNS.DoHPort = dohPort
config.DNS.Ports.DoH = dohPort
} else {
config.DNS.DoHPort = *flags.DoHPort
config.DNS.Ports.DoH = *flags.DoHPort
}
}
if flags.WebserverPort != nil || os.Getenv("WEBSITE_PORT") != "" {
@@ -82,27 +83,27 @@ func UpdateConfig(config *settings.Config, flags *SetFlags) {
}
}
if flags.StatisticsRetention != nil {
config.StatisticsRetention = *flags.StatisticsRetention
config.Misc.StatisticsRetention = *flags.StatisticsRetention
}
if flags.Authentication != nil {
config.API.Authentication = *flags.Authentication
}
if flags.Dashboard != nil {
config.Dashboard = *flags.Dashboard
config.Misc.Dashboard = *flags.Dashboard
}
if flags.LoggingEnabled != nil {
config.LoggingEnabled = *flags.LoggingEnabled
config.Logging.Enabled = *flags.LoggingEnabled
}
if flags.LogLevel != nil {
config.LogLevel = logging.LogLevel(*flags.LogLevel)
config.Logging.Level = *flags.LogLevel
}
if flags.InAppUpdate != nil {
config.InAppUpdate = *flags.InAppUpdate
config.Misc.InAppUpdate = *flags.InAppUpdate
}
if flags.JSON != nil {
log.JSON = *flags.JSON
log.SetJson(log.JSON)
log.SetJSON(log.JSON)
} else {
log.Ansi = flags.Ansi == nil || *flags.Ansi
log.SetAnsi(log.Ansi)

View File

@@ -18,15 +18,15 @@ func SelfUpdate(sse sendSSE, binaryPath string) error {
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("failed to create stdout pipe: %v", err)
return fmt.Errorf("failed to create stdout pipe: %w", err)
}
stderrPipe, err := cmd.StderrPipe()
if err != nil {
return fmt.Errorf("failed to create stderr pipe: %v", err)
return fmt.Errorf("failed to create stderr pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to start command: %v", err)
return fmt.Errorf("failed to start command: %w", err)
}
done := make(chan struct{})
@@ -55,7 +55,7 @@ func SelfUpdate(sse sendSSE, binaryPath string) error {
return nil
case err := <-waitCmd(cmd):
if err != nil {
return fmt.Errorf("update failed: %v", err)
return fmt.Errorf("update failed: %w", err)
}
}

6
backend/user/models.go Normal file
View File

@@ -0,0 +1,6 @@
package user
type User struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
}

View File

@@ -0,0 +1,66 @@
package user
import (
"context"
"errors"
"goaway/backend/database"
"gorm.io/gorm"
)
type Repository interface {
Create(user *database.User) error
FindByUsername(username string) (*User, error)
UpdatePassword(username string, hashedPassword string) error
}
type repository struct {
db *gorm.DB
}
func NewRepository(db *gorm.DB) Repository {
return &repository{db: db}
}
func (r *repository) Create(user *database.User) error {
result := r.db.Create(user)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return errors.New("user creation failed: no rows affected")
}
return nil
}
func (r *repository) FindByUsername(username string) (*User, error) {
var user database.User
err := r.db.Where("username = ?", username).Find(&user).Error
if err != nil {
return nil, err
}
if user.Username == "" {
return nil, errors.New("user not found")
}
return &User{
Username: user.Username,
Password: user.Password,
}, nil
}
func (r *repository) UpdatePassword(username, hashedPassword string) error {
affected, err := gorm.G[database.User](r.db).Where("username = ?", username).Update(context.Background(), "password", hashedPassword)
if err != nil {
return err
}
if affected == 0 {
return errors.New("password update failed: no rows affected")
}
return nil
}

114
backend/user/service.go Normal file
View File

@@ -0,0 +1,114 @@
package user
import (
"errors"
"goaway/backend/database"
"goaway/backend/logging"
"strings"
"golang.org/x/crypto/bcrypt"
)
type Service struct {
repository Repository
}
var log = logging.GetLogger()
func NewService(repo Repository) *Service {
return &Service{repository: repo}
}
func (s *Service) CreateUser(username, password string) error {
log.Info("Creating a new user with name '%s'", username)
hashedPassword, err := hashPassword(password)
if err != nil {
log.Error("Failed to hash password: %v", err)
return err
}
newUser := &database.User{Username: username, Password: hashedPassword}
if err := s.repository.Create(newUser); err != nil {
log.Error("Failed to create user: %v", err)
return err
}
log.Debug("User created successfully")
return nil
}
func (s *Service) Exists(username string) bool {
user, err := s.repository.FindByUsername(username)
if err != nil {
return false
}
if user != nil {
return true
}
return false
}
func (s *Service) Authenticate(username, password string) bool {
user, err := s.repository.FindByUsername(username)
if err != nil {
log.Error("Authentication failed for user '%s': %v", username, err)
return false
}
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
log.Debug("Invalid password for user '%s'", username)
return false
}
log.Debug("User '%s' authenticated successfully", username)
return true
}
func (s *Service) UpdatePassword(username, newPassword string) error {
hashedPassword, err := hashPassword(newPassword)
if err != nil {
log.Error("Failed to hash new password: %v", err)
return err
}
if err := s.repository.UpdatePassword(username, hashedPassword); err != nil {
log.Error("Failed to update password: %v", err)
return err
}
log.Debug("Password updated successfully for user '%s'", username)
return nil
}
func (s *Service) ValidateCredentials(user User) error {
user.Username = strings.TrimSpace(user.Username)
user.Password = strings.TrimSpace(user.Password)
if user.Username == "" || user.Password == "" {
return errors.New("username and password cannot be empty")
}
if len(user.Username) > 60 {
return errors.New("username too long")
}
if len(user.Password) > 120 {
return errors.New("password too long")
}
for _, r := range user.Username {
if r < 32 || r == 127 {
return errors.New("username contains invalid characters")
}
}
return nil
}
func hashPassword(password string) (string, error) {
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
return string(hashed), err
}

View File

@@ -0,0 +1,61 @@
package whitelist
import (
"fmt"
"goaway/backend/database"
"gorm.io/gorm"
)
type Repository interface {
AddDomain(domain string) error
GetDomains() (map[string]bool, error)
RemoveDomain(domain string) error
}
type repository struct {
db *gorm.DB
}
func NewRepository(db *gorm.DB) Repository {
return &repository{db: db}
}
func (r *repository) AddDomain(domain string) error {
result := r.db.Create(&database.Whitelist{Domain: domain})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("%s is already whitelisted", domain)
}
return nil
}
func (r *repository) GetDomains() (map[string]bool, error) {
var records []database.Whitelist
result := r.db.Find(&records)
if result.Error != nil {
return nil, result.Error
}
domainMap := make(map[string]bool)
for _, record := range records {
domainMap[record.Domain] = true
}
return domainMap, nil
}
func (r *repository) RemoveDomain(domain string) error {
result := r.db.Delete(&database.Whitelist{}, "domain = ?", domain)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("%s does not exist", domain)
}
return nil
}

View File

@@ -0,0 +1,59 @@
package whitelist
import "goaway/backend/logging"
type Service struct {
repository Repository
Cache map[string]bool
}
var log = logging.GetLogger()
func NewService(repo Repository) *Service {
service := &Service{
repository: repo,
Cache: map[string]bool{},
}
_, err := service.GetDomains() // Preload cache
if err != nil {
log.Warning("Could not preload domains cache, %v", err)
}
return service
}
func (s *Service) AddDomain(domain string) error {
err := s.repository.AddDomain(domain)
if err != nil {
return err
}
s.Cache[domain] = true
return nil
}
func (s *Service) GetDomains() (map[string]bool, error) {
domains, err := s.repository.GetDomains()
if err != nil {
return nil, err
}
s.Cache = domains
return domains, nil
}
func (s *Service) RemoveDomain(domain string) error {
err := s.repository.RemoveDomain(domain)
if err != nil {
return err
}
delete(s.Cache, domain)
return nil
}
func (s *Service) IsWhitelisted(domain string) bool {
_, exists := s.Cache[domain]
return exists
}

View File

@@ -17,5 +17,8 @@
"lib": "@/lib",
"hooks": "@/hooks"
},
"iconLibrary": "phosphor-icons/react"
"iconLibrary": "phosphor-icons/react",
"registries": {
"@magicui": "https://magicui.design/r/{name}.json"
}
}

View File

@@ -12,22 +12,8 @@
},
"dependencies": {
"@phosphor-icons/react": "^2.1.10",
"@radix-ui/react-checkbox": "^1.3.3",
"@radix-ui/react-collapsible": "^1.1.12",
"@radix-ui/react-dialog": "^1.1.15",
"@radix-ui/react-dropdown-menu": "^2.1.16",
"@radix-ui/react-label": "^2.1.7",
"@radix-ui/react-popover": "^1.1.15",
"@radix-ui/react-scroll-area": "^1.2.10",
"@radix-ui/react-select": "^2.2.6",
"@radix-ui/react-separator": "^1.1.7",
"@radix-ui/react-slider": "^1.3.6",
"@radix-ui/react-slot": "^1.2.3",
"@radix-ui/react-switch": "^1.2.6",
"@radix-ui/react-tabs": "^1.1.13",
"@radix-ui/react-toggle": "^1.1.10",
"@radix-ui/react-toggle-group": "^1.1.11",
"@radix-ui/react-tooltip": "^1.2.8",
"@tanstack/react-table": "^8.21.3",
"@tsparticles/engine": "^3.9.1",
"@tsparticles/preset-stars": "^3.2.0",
@@ -37,35 +23,41 @@
"clsx": "^2.1.1",
"cmdk": "1.1.1",
"compare-versions": "^6.1.1",
"motion": "^12.23.12",
"lucide-react": "^0.546.0",
"motion": "^12.23.24",
"next-themes": "^0.4.6",
"react": "^19.1.1",
"react-dom": "^19.1.1",
"react-force-graph-2d": "^1.28.0",
"react-router-dom": "^7.8.2",
"radix-ui": "^1.4.3",
"react": "^19.2.0",
"react-dom": "^19.2.0",
"react-force-graph-2d": "^1.29.0",
"react-router-dom": "^7.9.4",
"react-timeago": "^8.3.0",
"recharts": "^3.1.2",
"recharts": "^3.3.0",
"sonner": "^2.0.7",
"tailwind-merge": "^3.3.1",
"tw-animate-css": "^1.3.8"
"tw-animate-css": "^1.4.0"
},
"devDependencies": {
"@commitlint/cli": "^19.8.1",
"@commitlint/config-conventional": "^19.8.1",
"babel-plugin-react-compiler": "^1.0.0",
"@commitlint/cli": "^20.1.0",
"@commitlint/config-conventional": "^20.0.0",
"@commitlint/types": "^20.0.0",
"@eslint/js": "^9.34.0",
"@tailwindcss/vite": "^4.1.12",
"@types/react": "^19.1.12",
"@types/react-dom": "^19.1.9",
"@vitejs/plugin-react": "^5.0.2",
"eslint": "^9.34.0",
"eslint-plugin-react-hooks": "^5.2.0",
"eslint-plugin-react-refresh": "^0.4.20",
"globals": "^16.3.0",
"tailwindcss": "^4.1.12",
"typescript": "~5.9.2",
"typescript-eslint": "^8.42.0",
"vite": "^7.1.4"
"@eslint/js": "^9.38.0",
"@radix-ui/react-dialog": "^1.1.15",
"@radix-ui/react-select": "^2.2.6",
"@tailwindcss/vite": "^4.1.15",
"@types/canvas-confetti": "^1.9.0",
"@types/react": "^19.2.2",
"@types/react-dom": "^19.2.2",
"@vitejs/plugin-react": "^5.0.4",
"eslint": "^9.38.0",
"eslint-plugin-react-hooks": "^7.0.0",
"eslint-plugin-react-refresh": "^0.4.24",
"globals": "^16.4.0",
"tailwindcss": "^4.1.15",
"typescript": "~5.9.3",
"typescript-eslint": "^8.46.2",
"vite": "^7.1.11"
},
"pnpm": {
"onlyBuiltDependencies": [

2975
client/pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

After

Width:  |  Height:  |  Size: 7.5 KiB

BIN
client/public/gray-icon.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.3 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 27 KiB

After

Width:  |  Height:  |  Size: 3.5 MiB

View File

@@ -1,18 +1,18 @@
import { Button } from "@/components/ui/button";
import { Dialog, DialogContent, DialogTitle } from "@/components/ui/dialog";
import { Tabs, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { ClientEntry } from "@/pages/clients";
import { GetRequest } from "@/util";
import {
BirdIcon,
CaretDownIcon,
CaretRightIcon,
EyeIcon,
ClockCounterClockwiseIcon,
EyeglassesIcon,
IdentificationBadgeIcon,
LightningIcon,
LineSegmentsIcon,
PlusMinusIcon,
RowsIcon,
ShieldIcon,
SparkleIcon
SparkleIcon,
TargetIcon
} from "@phosphor-icons/react";
import { useEffect, useState } from "react";
import { toast } from "sonner";
@@ -79,75 +79,76 @@ export function CardDetails({
return (
<Dialog open onOpenChange={onClose}>
<DialogContent className="border-none bg-accent rounded-lg w-full max-w-5xl mx-auto p-0 overflow-hidden max-h-[90vh] flex flex-col">
<div className="p-4 sm:p-6 border-b">
<DialogTitle>
<div className="flex items-center justify-between">
<div>
<h2 className="text-xl sm:text-2xl font-bold mb-1">
{clientEntry.name || "Unnamed Device"}
</h2>
<div className="flex items-center text-sm">
<span className="bg-blue-400 px-2 py-0.5 rounded-full font-medium">
{clientEntry.ip}
<DialogContent className="border-none bg-accent rounded-lg w-full max-w-6xl overflow-hidden">
<DialogTitle>
<div className="flex items-center justify-between">
<div>
<h2 className="text-xl sm:text-2xl font-bold mb-1">
{clientEntry.name || "unknown"}
</h2>
<div className="flex items-center text-sm gap-2">
<span className="bg-muted-foreground/20 px-2 py-0.5 rounded-md font-mono text-xs">
ip: {clientEntry.ip}
</span>
{clientEntry.mac && (
<span className="bg-muted-foreground/20 px-2 py-0.5 rounded-md font-mono text-xs">
mac: {clientEntry.mac}
</span>
{clientEntry.mac && (
<span className="ml-2 flex items-center">
<IdentificationBadgeIcon size={14} className="mr-1" />
{clientEntry.mac}
</span>
)}
{clientEntry.vendor && (
<span className="ml-2 opacity-75">
{clientEntry.vendor}
</span>
)}
</div>
</div>
<div className="text-right hidden sm:block">
<span className="text-xs">Last Activity</span>
<div className="text-muted-foreground">
{formatTimeAgo(clientEntry.lastSeen)}
</div>
)}
{clientEntry.vendor && (
<span className="bg-muted-foreground/20 px-2 py-0.5 rounded-md font-mono text-xs">
vendor: {clientEntry.vendor}
</span>
)}
</div>
</div>
</DialogTitle>
</div>
<div className="text-right hidden sm:block">
<span className="text-xs">Last Activity</span>
<div className="text-muted-foreground">
{formatTimeAgo(clientEntry.lastSeen)}
</div>
</div>
</div>
</DialogTitle>
<div className="flex bg-background border-b">
<button
className={`px-4 py-2 text-sm font-medium flex items-center ${
activeTab === "overview"
? "text-blue-400 border-b-2 border-blue-500"
: "text-muted-foreground hover:font-bold hover:border-b-2 border-stone-500 cursor-pointer"
}`}
onClick={() => setActiveTab("overview")}
>
<BirdIcon size={16} className="mr-2" />
Overview
</button>
<button
className={`px-4 py-2 text-sm font-medium flex items-center ${
activeTab === "domains"
? "text-blue-400 border-b-2 border-blue-500"
: "text-muted-foreground hover:font-bold hover:border-b-2 border-stone-500 cursor-pointer"
}`}
onClick={() => setActiveTab("domains")}
>
<LineSegmentsIcon size={16} className="mr-2" />
Domains
</button>
</div>
<Tabs defaultValue="overview">
<TabsList className="bg-transparent space-x-2">
<TabsTrigger
value="overview"
className="border-l-0 !bg-transparent border-t-0 border-r-0 cursor-pointer data-[state=active]:border-b-2 data-[state=active]:!border-b-primary rounded-none"
onClick={() => setActiveTab("overview")}
>
<TargetIcon />
Overview
</TabsTrigger>
<TabsTrigger
value="domains"
className="border-l-0 !bg-transparent border-t-0 border-r-0 cursor-pointer data-[state=active]:border-b-2 data-[state=active]:!border-b-primary rounded-none"
onClick={() => setActiveTab("domains")}
>
<RowsIcon />
Domains
</TabsTrigger>
<TabsTrigger
value="history"
className="border-l-0 !bg-transparent border-t-0 border-r-0 cursor-pointer data-[state=active]:border-b-2 data-[state=active]:!border-b-primary rounded-none"
onClick={() => setActiveTab("history")}
>
<ClockCounterClockwiseIcon />
History
</TabsTrigger>
</TabsList>
</Tabs>
{isLoading ? (
<div className="flex justify-center items-center p-16">
<div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-blue-500" />
</div>
) : clientDetails ? (
<div className="overflow-y-auto p-4 sm:p-6 flex-grow">
<div className="overflow-y-auto sm:p-2 grow">
{activeTab === "overview" && (
<>
<div className="grid grid-cols-2 sm:grid-cols-3 lg:grid-cols-6 gap-3 mb-6">
<div className="grid grid-cols-2 sm:grid-cols-3 lg:grid-cols-6 gap-2 mb-4">
<StatCard
icon={<EyeglassesIcon size={18} />}
label="Total Requests"
@@ -181,17 +182,14 @@ export function CardDetails({
<StatCard
icon={<CaretDownIcon size={18} />}
label="Most Queried"
value={clientDetails.mostQueriedDomain.split(".")[0]}
value={clientDetails.mostQueriedDomain}
color="bg-indigo-600"
/>
</div>
<div className="mb-6">
<h3 className="text-lg font-semibold mb-3 flex items-center">
<EyeIcon size={18} className="mr-2 text-blue-400" />
Top Queried Domains
</h3>
<div className="grid gap-3">
<p className="mb-2">Top Queried Domains</p>
<div className="grid gap-2">
{Object.entries(clientDetails.allDomains)
.sort((a, b) => b[1] - a[1])
.slice(0, 5)
@@ -208,10 +206,10 @@ export function CardDetails({
<div className="w-12 text-center font-mono py-1 rounded text-xs font-medium">
{count}
</div>
<div className="ml-3 flex-grow font-medium truncate">
<div className="ml-3 grow font-medium truncate">
{domain}
</div>
<div className="w-24 flex-shrink-0">
<div className="w-24 shrink-0">
<div className="h-2 bg-accent rounded-full w-full">
<div
className={`h-2 rounded-full ${getProgressColor(
@@ -227,49 +225,37 @@ export function CardDetails({
);
})}
</div>
<div className="mt-2 text-center">
<button
className="text-blue-400 hover:text-blue-300 text-sm flex items-center mx-auto"
onClick={() => setActiveTab("domains")}
>
View all domains <CaretRightIcon size={16} />
</button>
</div>
</div>
<div className="grid grid-cols-1 sm:grid-cols-3 gap-3 mt-4">
<ActionButton
label="[WIP] View Details"
bgClass="bg-blue-500 text-white"
/>
<ActionButton
label="[WIP] Block Device"
bgClass="bg-red-500 text-white"
/>
<ActionButton
label="[WIP] Device Settings"
bgClass="bg-stone-500 text-white"
/>
<Button variant="outline" disabled={true}>
[WIP] View Details
</Button>
<Button variant="outline" disabled={true}>
[WIP] Block Device
</Button>
<Button variant="outline" disabled={true}>
[WIP] Device Settings
</Button>
</div>
</>
)}
{activeTab === "domains" && (
<div>
<h3 className="text-lg font-semibold mb-3 flex items-center">
<EyeIcon size={18} className="mr-2 text-blue-400" />
<p className="mb-2">
All Queried Domains
<span className="ml-2 text-xs bg-accent px-2 py-0.5 rounded-full text-muted-foreground">
{Object.keys(clientDetails.allDomains).length} domains
</span>
</h3>
</p>
<div className="shadow-md border rounded-md overflow-hidden">
<div className="flex justify-between items-center py-2 px-3">
<div className="w-16 text-xs text-muted-foreground font-medium">
Count
</div>
<div className="flex-grow text-xs text-muted-foreground font-medium">
<div className="grow text-xs text-muted-foreground font-medium">
Domain
</div>
<div className="w-24 text-xs text-muted-foreground font-medium">
@@ -294,7 +280,7 @@ export function CardDetails({
<div className="w-16 font-mono bg-accent py-1 rounded text-center text-xs font-medium">
{count}
</div>
<div className="ml-3 flex-grow font-medium truncate">
<div className="ml-3 grow font-medium truncate">
{domain}
</div>
<div className="w-24 text-right text-muted-foreground text-sm">
@@ -309,7 +295,7 @@ export function CardDetails({
)}
</div>
) : (
<div className="text-center py-16 flex-grow flex flex-col items-center justify-center">
<div className="text-center py-16 grow flex flex-col items-center justify-center">
<ShieldIcon size={48} className="mb-4" />
<div className="text-lg">No data available for this client</div>
<div className="text-sm mt-2 text-muted-foreground">
@@ -326,23 +312,12 @@ function StatCard({ icon, label, value, color }) {
return (
<div className="rounded-sm shadow-md bg-background">
<div className={`${color} h-1`}></div>
<div className="p-3">
<div className="flex items-center text-xs text-muted-foreground mb-1">
<span className="mr-1">{icon}</span>
{label}
<div className="p-2">
<div className="flex items-center text-xs text-muted-foreground mb-1 gap-1">
{icon} {label}
</div>
<div className="font-bold text-lg truncate">{value}</div>
<div className="font-bold text-sm truncate">{value}</div>
</div>
</div>
);
}
function ActionButton({ label, bgClass }) {
return (
<button
className={`${bgClass} px-4 py-3 rounded-md text-sm font-medium transition-all shadow-md hover:shadow-lg`}
>
{label}
</button>
);
}

View File

@@ -7,6 +7,7 @@ import { XIcon } from "@phosphor-icons/react";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import ForceGraph2D, { ForceGraphMethods } from "react-force-graph-2d";
import { CardDetails } from "./details";
import { toast } from "sonner";
interface ClientEntry {
ip: string;
@@ -129,8 +130,8 @@ export default function DNSServerVisualizer() {
const [viewSettings, setViewSettings] = useState<ViewSettings>({
clusterBySubnet: false,
hideInactiveClients: false,
minNodeSize: 3,
maxNodeSize: 12,
minNodeSize: 2,
maxNodeSize: 10,
showLabels: true,
activityThresholdMinutes: 60
});
@@ -235,7 +236,7 @@ export default function DNSServerVisualizer() {
setError(
err instanceof Error ? err.message : "Failed to fetch clients"
);
console.error("Error fetching clients:", err);
toast.warning("Error fetching clients", { description: `${err}` });
}
};
@@ -289,7 +290,9 @@ export default function DNSServerVisualizer() {
});
}
} catch (error) {
console.error("Error handling WebSocket message:", error);
toast.warning("Error handling WebSocket message", {
description: `${error}`
});
}
};
@@ -319,14 +322,14 @@ export default function DNSServerVisualizer() {
id: "dns-server",
name: "DNS Server",
type: "server",
color: "#3b82f6",
color: "cornflowerblue",
size: viewSettings.maxNodeSize
},
{
id: "upstream",
name: "Upstream",
type: "server",
color: "#008000",
color: "teal",
size: viewSettings.maxNodeSize
}
];
@@ -557,7 +560,7 @@ export default function DNSServerVisualizer() {
<div
ref={containerRef}
className="rounded-xl shadow-md p-4 w-full border-1 dark:bg-accent"
className="rounded-xl shadow-md p-4 w-full border dark:bg-accent"
>
<div className="grid grid-cols-4 gap-2 text-sm mb-4">
{[
@@ -583,7 +586,7 @@ export default function DNSServerVisualizer() {
).length
}
].map(({ label, plural, value }) => (
<div key={label} className="rounded-lg py-0.5 text-center border-1">
<div key={label} className="rounded-lg py-0.5 text-center border">
<p className="text-sm font-medium">{value}</p>
<p className="text-xs text-muted-foreground">
{value === 1 ? label : plural}

View File

@@ -1,8 +1,9 @@
"use client";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { NoContent } from "@/shared";
import { GetRequest } from "@/util";
import { ArticleIcon, WarningIcon } from "@phosphor-icons/react";
import { ArticleIcon } from "@phosphor-icons/react";
import { useEffect, useState } from "react";
import { toast } from "sonner";
@@ -87,16 +88,9 @@ export default function Audit() {
))}
</div>
) : (
<EmptyState />
<NoContent text={"No audits created"} />
)}
</CardContent>
</Card>
);
}
const EmptyState = () => (
<div className="grid place-items-center py-8">
<WarningIcon size={32} className="text-destructive mb-2" />
<p className="text-sm text-muted-foreground">No audit entries found</p>
</div>
);

View File

@@ -2,7 +2,7 @@
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { GetRequest } from "@/util";
import { NetworkSlashIcon, WarningIcon } from "@phosphor-icons/react";
import { NetworkSlashIcon } from "@phosphor-icons/react";
import { useEffect, useState, useRef } from "react";
import {
Bar,
@@ -20,6 +20,7 @@ import {
NameType,
ValueType
} from "recharts/types/component/DefaultTooltipContent";
import { NoContent } from "@/shared";
type TopBlockedDomains = {
frequency: number;
@@ -58,17 +59,6 @@ const CustomTooltip = ({
return null;
};
const EmptyState = () => (
<div className="flex flex-col items-center justify-center h-full w-full py-10">
<div className="mb-4">
<WarningIcon size={36} className="text-destructive" />
</div>
<p className="text-muted-foreground text-sm text-center">
Blocked domains will appear here when detected
</p>
</div>
);
const isNewData = (a: TopBlockedDomains[], b: TopBlockedDomains[]): boolean => {
if (a.length !== b.length) return false;
@@ -211,7 +201,7 @@ export default function FrequencyChartBlockedDomains() {
</BarChart>
</ResponsiveContainer>
) : (
<EmptyState />
<NoContent text={"No domain has been blocked"} />
)}
</CardContent>
</Card>

View File

@@ -2,8 +2,9 @@
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Tabs, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { NoContent } from "@/shared";
import { GetRequest } from "@/util";
import { UsersIcon, WarningIcon } from "@phosphor-icons/react";
import { UsersIcon } from "@phosphor-icons/react";
import { useEffect, useState, useRef } from "react";
import {
Bar,
@@ -58,17 +59,6 @@ const CustomTooltip = ({
return null;
};
const EmptyState = () => (
<div className="flex flex-col items-center justify-center h-full w-full py-10">
<div className="mb-4">
<WarningIcon size={36} className="text-destructive" />
</div>
<p className="text-muted-foreground text-sm text-center">
Client data will appear here when requests are detected
</p>
</div>
);
const isNewData = (a: TopBlockedClients[], b: TopBlockedClients[]): boolean => {
if (a.length !== b.length) return false;
@@ -214,7 +204,7 @@ export default function FrequencyChartTopBlockedClients() {
</BarChart>
</ResponsiveContainer>
) : (
<EmptyState />
<NoContent text={"No client data to show"} />
)}
</CardContent>
</Card>

View File

@@ -6,6 +6,7 @@ import {
ChartLegendContent,
ChartTooltip
} from "@/components/ui/chart";
import { NoContent } from "@/shared";
import { GetRequest } from "@/util";
import {
ArrowsClockwiseIcon,
@@ -93,9 +94,15 @@ export default function ResponseSizeTimeline() {
}, [timelineInterval]);
useEffect(() => {
fetchData();
const timer = setTimeout(() => {
fetchData();
}, 0);
const interval = setInterval(fetchData, 10000);
return () => clearInterval(interval);
return () => {
clearTimeout(timer);
clearInterval(interval);
};
}, [fetchData]);
const getFilteredData = () => {
@@ -207,7 +214,7 @@ export default function ResponseSizeTimeline() {
<div className="flex gap-2">
{isZoomed && (
<Button
className="bg-transparent border-1 text-white hover:bg-stone-800"
className="bg-transparent border text-white hover:bg-stone-800"
onClick={handleZoomOut}
>
<MagnifyingGlassMinusIcon weight="bold" className="mr-1" />
@@ -444,13 +451,8 @@ export default function ResponseSizeTimeline() {
</CardContent>
</>
) : (
<CardContent className="flex h-[300px] items-center justify-center">
<div className="text-center">
<p className="text-lg font-medium">No data available</p>
<p className="text-sm text-muted-foreground">
No response size data recorded yet
</p>
</div>
<CardContent className="flex h-[220px] items-center justify-center">
<NoContent text={"No requests recorded"} />
</CardContent>
)}
</Card>

View File

@@ -27,11 +27,11 @@ import {
ArrowsClockwiseIcon,
ChartLineIcon,
MagnifyingGlassMinusIcon,
MagnifyingGlassPlusIcon,
WarningIcon
MagnifyingGlassPlusIcon
} from "@phosphor-icons/react";
import { useCallback, useEffect, useState } from "react";
import { Button } from "../../components/ui/button";
import { NoContent } from "@/shared";
const chartConfig = {
blocked: {
@@ -90,9 +90,14 @@ export default function RequestTimeline() {
}, [timelineInterval]);
useEffect(() => {
fetchData();
const interval = setInterval(fetchData, 10000);
return () => clearInterval(interval);
const timeout = window.setTimeout(() => {
fetchData();
}, 0);
const interval = window.setInterval(fetchData, 10000);
return () => {
window.clearTimeout(timeout);
window.clearInterval(interval);
};
}, [fetchData]);
const getFilteredData = () => {
@@ -196,7 +201,7 @@ export default function RequestTimeline() {
<div className="flex gap-2">
{isZoomed && (
<Button
className="bg-transparent border-1 text-white hover:bg-stone-800"
className="bg-transparent border text-white hover:bg-stone-800"
onClick={handleZoomOut}
>
<MagnifyingGlassMinusIcon weight="bold" className="mr-1" />
@@ -411,15 +416,8 @@ export default function RequestTimeline() {
</CardContent>
</>
) : (
<CardContent className="flex h-[300px] items-center justify-center">
<div className="flex flex-col items-center justify-center">
<div className="mb-4">
<WarningIcon size={36} className="text-destructive" />
</div>
<p className="text-sm text-muted-foreground">
No requests recorded yet
</p>
</div>
<CardContent className="flex h-[220px] items-center justify-center">
<NoContent text={"No requests recorded"} />
</CardContent>
)}
</Card>

View File

@@ -13,6 +13,7 @@ import {
SelectTrigger,
SelectValue
} from "@/components/ui/select";
import { NoContent } from "@/shared";
import { GetRequest } from "@/util";
import { SetStateAction, useEffect, useState } from "react";
import {
@@ -144,12 +145,7 @@ export default function RequestTypeChart() {
</CardContent>
) : (
<CardContent className="flex h-[200px] items-center justify-center">
<div className="text-center">
<p className="text-sm font-medium">No data available</p>
<p className="text-xs text-muted-foreground">
No query types have yet been identified
</p>
</div>
<NoContent text={"No query types has been identified"} />
</CardContent>
)}
</Card>

View File

@@ -20,8 +20,8 @@ async function removeDomain(domain: string) {
}
}
export default function BlockedDomainsList({ listName }) {
const [domains, setDomains] = useState([]);
export default function BlockedDomainsList({ listName }: { listName: string }) {
const [domains, setDomains] = useState<string[]>([]);
const [searchTerm, setSearchTerm] = useState("");
const [loading, setLoading] = useState(true);

View File

@@ -32,7 +32,7 @@ export function ListCard(
...listEntry
} = props;
const formattedDate = new Date(listEntry.lastUpdated * 1000).toLocaleString(
const formattedDate = new Date(listEntry.lastUpdated).toLocaleString(
"en-US",
{
month: "short",

Some files were not shown because too many files have changed in this diff Show More