mirror of
https://github.com/pommee/goaway.git
synced 2026-01-11 00:20:31 -06:00
fix: restructure codebase, make setup and flow easier
This commit is contained in:
@@ -5,7 +5,7 @@ tmp_dir = "tmp"
|
||||
bin = "./tmp/goaway"
|
||||
args_bin = [ "--dns-port=6121", "--dot-port=6122", "--doh-port=9012", "--webserver-port=8080", "--log-level=0", "--logging=true", "--auth=false", "--statistics-retention=7", "--dashboard=false", "--ansi=true" ]
|
||||
cmd = 'go build -o goaway -ldflags="-X main.version=0.0.0 -X main.commit=ead2d7830add26d53ecab3c907a290f0cdc1e078 -X main.date=2025-04-11T13:37:56Z" -o ./tmp/goaway .'
|
||||
exclude_dir = [ "assets", "tmp", "vendor", "client", "test", "resources" ]
|
||||
exclude_dir = [ "assets", "tmp", "vendor", "client", "test", "resources", "config", "data" ]
|
||||
include_ext = [ "go", "tpl", "tmpl", "html", "css", "js", "jsx", "ts", "tsx" ]
|
||||
|
||||
[color]
|
||||
|
||||
34
.github/workflows/docs.yml
vendored
Normal file
34
.github/workflows/docs.yml
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
name: Build & Deploy Homepage
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Configure Git Credentials
|
||||
run: |
|
||||
git config user.name github-actions[bot]
|
||||
git config user.email 41898282+github-actions[bot]@users.noreply.github.com
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.x
|
||||
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
|
||||
- uses: actions/cache@v4
|
||||
with:
|
||||
key: mkdocs-material-${{ env.cache_id }}
|
||||
path: ~/.cache
|
||||
restore-keys: |
|
||||
mkdocs-material-
|
||||
- run: pip install mkdocs-material mkdocs-git-revision-date-localized-plugin
|
||||
- run: |
|
||||
mkdocs gh-deploy --force \
|
||||
--config-file docs/mkdocs.yml \
|
||||
-m "docs: update website"
|
||||
2
.github/workflows/pull-request.yml
vendored
2
.github/workflows/pull-request.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: "1.25.1"
|
||||
go-version: "1.25.3"
|
||||
|
||||
- name: Golangci-lint
|
||||
uses: golangci/golangci-lint-action@v8.0.0
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -2,7 +2,7 @@ tmp
|
||||
requests.json
|
||||
counters.json
|
||||
main
|
||||
goaway*
|
||||
goaway
|
||||
*.db**
|
||||
.vite
|
||||
dist
|
||||
@@ -38,3 +38,5 @@ benchmark.prof
|
||||
|
||||
*.crt
|
||||
*.key
|
||||
|
||||
docs/site/**
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.25.1-alpine
|
||||
FROM golang:1.25.3-alpine
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
22
README.md
22
README.md
@@ -1,5 +1,3 @@
|
||||
# GoAway - DNS Sinkhole
|
||||
|
||||

|
||||

|
||||

|
||||
@@ -7,11 +5,19 @@
|
||||
|
||||
A lightweight DNS sinkhole for blocking unwanted domains at the network level. Block ads, trackers, and malicious domains before they reach your devices.
|
||||
|
||||

|
||||

|
||||
|
||||
**[View more screenshots](./resources/PREVIEW.md)**
|
||||
|
||||
<<<<<<< Updated upstream
|
||||
## 🌟 Features
|
||||
=======
|
||||
## Getting started
|
||||
|
||||
Instructions for installation, configuration and more can be found on the homepage: https://pommee.github.io/goaway
|
||||
|
||||
## Features
|
||||
>>>>>>> Stashed changes
|
||||
|
||||
- DNS-level domain blocking
|
||||
- Web-based admin dashboard
|
||||
@@ -219,7 +225,7 @@ Contributions are welcomed! Here's how you can help:
|
||||
3. **Submit PRs:** Before any work is started, create a new issue explaining what is wanted, why it would fit, how it can be done, so on and so forth...
|
||||
Once the topic has been discussed with a maintainer then either you or a maintainer starts with the implementation. This is done to prevent any collisions, save time and confusion. [Read more here](./CONTRIBUTING.md)
|
||||
|
||||
## ⚠️ Platform Support
|
||||
## Platform Support
|
||||
|
||||
| Platform | Architecture | Support Level |
|
||||
| -------- | ------------ | ------------- |
|
||||
@@ -233,7 +239,7 @@ Contributions are welcomed! Here's how you can help:
|
||||
|
||||
> **Note**: Primary testing is conducted on Linux (amd64). While the aim is to support all listed platforms, functionality on macOS and Windows may vary.
|
||||
|
||||
## 🔍 Troubleshooting
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
@@ -252,7 +258,7 @@ Contributions are welcomed! Here's how you can help:
|
||||
- Check device DNS settings point to GoAway's IP address
|
||||
- Test with `nslookup google.com <goaway-ip>` or `dig @<goaway-ip> google.com.`
|
||||
|
||||
## 📈 Performance
|
||||
## Performance
|
||||
|
||||
GoAway is designed to be lightweight and efficient:
|
||||
|
||||
@@ -261,10 +267,10 @@ GoAway is designed to be lightweight and efficient:
|
||||
- **Network:** Low latency DNS resolution
|
||||
- **Storage:** Logs and statistics use minimal disk space
|
||||
|
||||
## 📜 License
|
||||
## License
|
||||
|
||||
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
||||
|
||||
## 🙏 Acknowledgments
|
||||
## Acknowledgments
|
||||
|
||||
This project is heavily inspired by [Pi-hole](https://github.com/pi-hole/pi-hole). Thanks to all people involved for their work.
|
||||
|
||||
@@ -16,8 +16,8 @@ type DiscordConfig struct {
|
||||
}
|
||||
|
||||
type DiscordService struct {
|
||||
config DiscordConfig
|
||||
client *http.Client
|
||||
config DiscordConfig
|
||||
}
|
||||
|
||||
type DiscordWebhookPayload struct {
|
||||
@@ -28,12 +28,12 @@ type DiscordWebhookPayload struct {
|
||||
}
|
||||
|
||||
type DiscordEmbed struct {
|
||||
Author *DiscordEmbedAuthor `json:"author,omitempty"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Color int `json:"color,omitempty"`
|
||||
Fields []DiscordEmbedField `json:"fields,omitempty"`
|
||||
Timestamp string `json:"timestamp,omitempty"`
|
||||
Author *DiscordEmbedAuthor `json:"author,omitempty"`
|
||||
Fields []DiscordEmbedField `json:"fields,omitempty"`
|
||||
Color int `json:"color,omitempty"`
|
||||
}
|
||||
|
||||
type DiscordEmbedField struct {
|
||||
@@ -81,7 +81,7 @@ func (d *DiscordService) SendMessage(ctx context.Context, msg Message) error {
|
||||
return fmt.Errorf("failed to marshal Discord payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", d.config.WebhookURL, bytes.NewBuffer(jsonData))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, d.config.WebhookURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Discord request: %w", err)
|
||||
}
|
||||
|
||||
61
backend/alert/repository.go
Normal file
61
backend/alert/repository.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package alert
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"goaway/backend/database"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
SaveAlert(alert database.Alert) error
|
||||
GetAllAlerts() ([]database.Alert, error)
|
||||
RemoveAlert(alertType string) error
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewRepository(db *gorm.DB) Repository {
|
||||
return &repository{db: db}
|
||||
}
|
||||
|
||||
func (r *repository) SaveAlert(alert database.Alert) error {
|
||||
result := r.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "type"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"enabled", "name", "webhook"}),
|
||||
}).Create(&alert)
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to save alert: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) GetAllAlerts() ([]database.Alert, error) {
|
||||
var alerts []database.Alert
|
||||
result := r.db.Find(&alerts)
|
||||
return alerts, result.Error
|
||||
}
|
||||
|
||||
func (r *repository) RemoveAlert(alertType string) error {
|
||||
if alertType == "" {
|
||||
return fmt.Errorf("alert type cannot be empty")
|
||||
}
|
||||
|
||||
result := r.db.Where("type = ?", alertType).Delete(&database.Alert{})
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to remove alert: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
log.Warning("No alert found with type: %s", alertType)
|
||||
return fmt.Errorf("no alert found with type: %s", alertType)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -3,14 +3,14 @@ package alert
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"goaway/backend/dns/database"
|
||||
"goaway/backend/database"
|
||||
"goaway/backend/logging"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
var log = logging.GetLogger()
|
||||
type Service struct {
|
||||
repository Repository
|
||||
services []messageSender
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Content string
|
||||
@@ -19,32 +19,50 @@ type Message struct {
|
||||
Severity string
|
||||
}
|
||||
|
||||
type MessageSender interface {
|
||||
type messageSender interface {
|
||||
SendMessage(ctx context.Context, msg Message) error
|
||||
IsEnabled() bool
|
||||
GetServiceName() string
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
DB *gorm.DB
|
||||
services []MessageSender
|
||||
}
|
||||
var log = logging.GetLogger()
|
||||
|
||||
func NewManager(db *gorm.DB) *Manager {
|
||||
return &Manager{
|
||||
DB: db,
|
||||
services: make([]MessageSender, 0),
|
||||
func NewService(repo Repository) *Service {
|
||||
return &Service{
|
||||
repository: repo,
|
||||
services: make([]messageSender, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Reset() {
|
||||
m.services = make([]MessageSender, 0)
|
||||
func (s *Service) reset() {
|
||||
s.services = make([]messageSender, 0)
|
||||
}
|
||||
|
||||
func (m *Manager) Load() {
|
||||
func (s *Service) registerService(service messageSender) {
|
||||
log.Debug("Registering alert service: %s", service.GetServiceName())
|
||||
s.services = append(s.services, service)
|
||||
}
|
||||
|
||||
func (s *Service) SaveAlert(alert database.Alert) error {
|
||||
err := s.repository.SaveAlert(alert)
|
||||
if err != nil {
|
||||
log.Error("Failed to save alert settings: %v", err)
|
||||
return err
|
||||
}
|
||||
s.Load()
|
||||
log.Info("Alert settings saved for type: %s", alert.Type)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) GetAllAlerts() ([]database.Alert, error) {
|
||||
return s.repository.GetAllAlerts()
|
||||
}
|
||||
|
||||
func (s *Service) Load() {
|
||||
discordService := NewDiscordService(DiscordConfig{})
|
||||
|
||||
alerts, err := m.GetAllAlerts()
|
||||
alerts, err := s.GetAllAlerts()
|
||||
|
||||
if err != nil {
|
||||
log.Warning("Failed to load alerts from database: %v", err)
|
||||
@@ -62,61 +80,15 @@ func (m *Manager) Load() {
|
||||
}
|
||||
}
|
||||
|
||||
m.Reset()
|
||||
m.RegisterService(discordService)
|
||||
log.Debug("Alert Manager loaded with %d services", len(m.services))
|
||||
s.reset()
|
||||
s.registerService(discordService)
|
||||
log.Debug("Alert Manager loaded with %d services", len(s.services))
|
||||
}
|
||||
|
||||
func (m *Manager) SaveAlert(alert database.Alert) error {
|
||||
result := m.DB.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "type"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"enabled", "name", "webhook"}),
|
||||
}).Create(&alert)
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to save alert: %w", result.Error)
|
||||
}
|
||||
|
||||
log.Info("Alert settings saved for type: %s", alert.Type)
|
||||
|
||||
m.Load()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveAlert(alertType string) error {
|
||||
if alertType == "" {
|
||||
return fmt.Errorf("alert type cannot be empty")
|
||||
}
|
||||
|
||||
result := m.DB.Where("type = ?", alertType).Delete(&database.Alert{})
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to remove alert: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
log.Warning("No alert found with type: %s", alertType)
|
||||
return fmt.Errorf("no alert found with type: %s", alertType)
|
||||
}
|
||||
|
||||
m.Load()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetAllAlerts() ([]database.Alert, error) {
|
||||
ctx := context.Background()
|
||||
return gorm.G[database.Alert](m.DB).Find(ctx)
|
||||
}
|
||||
|
||||
func (m *Manager) RegisterService(service MessageSender) {
|
||||
log.Debug("Registering alert service: %s", service.GetServiceName())
|
||||
m.services = append(m.services, service)
|
||||
}
|
||||
|
||||
func (m *Manager) SendToAll(ctx context.Context, msg Message) error {
|
||||
func (s *Service) SendToAll(ctx context.Context, msg Message) error {
|
||||
var errors []error
|
||||
|
||||
for _, service := range m.services {
|
||||
for _, service := range s.services {
|
||||
if !service.IsEnabled() {
|
||||
continue
|
||||
}
|
||||
@@ -135,8 +107,20 @@ func (m *Manager) SendToAll(ctx context.Context, msg Message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) SendTest(ctx context.Context, alertType, name, webhook string) error {
|
||||
err := m.SaveAlert(database.Alert{
|
||||
func (s *Service) RemoveAlert(alertType string) error {
|
||||
err := s.repository.RemoveAlert(alertType)
|
||||
if err != nil {
|
||||
log.Error("Failed to remove alert: %v", err)
|
||||
return err
|
||||
}
|
||||
s.Load()
|
||||
log.Info("Alert removed for type: %s", alertType)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) SendTest(ctx context.Context, alertType, name, webhook string) error {
|
||||
err := s.SaveAlert(database.Alert{
|
||||
Type: alertType,
|
||||
Enabled: true,
|
||||
Name: name,
|
||||
@@ -147,7 +131,7 @@ func (m *Manager) SendTest(ctx context.Context, alertType, name, webhook string)
|
||||
return err
|
||||
}
|
||||
|
||||
for _, service := range m.services {
|
||||
for _, service := range s.services {
|
||||
if service.GetServiceName() == alertType {
|
||||
log.Info("Sending test alert via %s", service.GetServiceName())
|
||||
err = service.SendMessage(ctx, Message{
|
||||
@@ -162,7 +146,7 @@ func (m *Manager) SendTest(ctx context.Context, alertType, name, webhook string)
|
||||
}
|
||||
}
|
||||
|
||||
err = m.RemoveAlert(alertType)
|
||||
err = s.RemoveAlert(alertType)
|
||||
if err != nil {
|
||||
log.Error("Failed to remove test alert settings: %v", err)
|
||||
return err
|
||||
@@ -2,7 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"goaway/backend/dns/database"
|
||||
"goaway/backend/database"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -15,14 +15,14 @@ const (
|
||||
SeverityError = "error"
|
||||
)
|
||||
|
||||
type DiscordSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
type discordSettings struct {
|
||||
Name string `json:"name"`
|
||||
Webhook string `json:"webhook"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
type AlertSettings struct {
|
||||
Discord DiscordSettings `json:"discord"`
|
||||
type alertSettings struct {
|
||||
Discord discordSettings `json:"discord"`
|
||||
}
|
||||
|
||||
func (api *API) registerAlertRoutes() {
|
||||
@@ -33,7 +33,7 @@ func (api *API) registerAlertRoutes() {
|
||||
}
|
||||
|
||||
func (api *API) setAlert(c *gin.Context) {
|
||||
var request AlertSettings
|
||||
var request alertSettings
|
||||
err := c.Bind(&request)
|
||||
if err != nil {
|
||||
log.Error("Failed to parse alert settings: %v", err)
|
||||
@@ -41,7 +41,7 @@ func (api *API) setAlert(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err = api.DNSServer.Alerts.SaveAlert(database.Alert{
|
||||
err = api.DNSServer.AlertService.SaveAlert(database.Alert{
|
||||
Type: "discord",
|
||||
Enabled: request.Discord.Enabled,
|
||||
Name: request.Discord.Name,
|
||||
@@ -57,17 +57,17 @@ func (api *API) setAlert(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (api *API) getAlert(c *gin.Context) {
|
||||
alerts, err := api.DNSServer.Alerts.GetAllAlerts()
|
||||
alerts, err := api.DNSServer.AlertService.GetAllAlerts()
|
||||
if err != nil {
|
||||
log.Error("Failed to retrieve alert settings: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve alert settings"})
|
||||
return
|
||||
}
|
||||
|
||||
var response AlertSettings
|
||||
var response alertSettings
|
||||
for _, alert := range alerts {
|
||||
if alert.Type == "discord" {
|
||||
response.Discord = DiscordSettings{
|
||||
response.Discord = discordSettings{
|
||||
Enabled: alert.Enabled,
|
||||
Name: alert.Name,
|
||||
Webhook: alert.Webhook,
|
||||
@@ -79,7 +79,7 @@ func (api *API) getAlert(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (api *API) testAlert(c *gin.Context) {
|
||||
var request AlertSettings
|
||||
var request alertSettings
|
||||
err := c.Bind(&request)
|
||||
if err != nil {
|
||||
log.Error("Failed to parse alert settings: %v", err)
|
||||
@@ -87,7 +87,7 @@ func (api *API) testAlert(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err = api.DNSServer.Alerts.SendTest(
|
||||
err = api.DNSServer.AlertService.SendTest(
|
||||
context.Background(),
|
||||
"discord",
|
||||
request.Discord.Name,
|
||||
|
||||
@@ -5,6 +5,18 @@ import (
|
||||
"embed"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"goaway/backend/api/key"
|
||||
"goaway/backend/api/ratelimit"
|
||||
"goaway/backend/blacklist"
|
||||
"goaway/backend/dns/server"
|
||||
"goaway/backend/logging"
|
||||
"goaway/backend/notification"
|
||||
"goaway/backend/prefetch"
|
||||
"goaway/backend/request"
|
||||
"goaway/backend/resolution"
|
||||
"goaway/backend/settings"
|
||||
"goaway/backend/user"
|
||||
"goaway/backend/whitelist"
|
||||
"io/fs"
|
||||
"mime"
|
||||
"net"
|
||||
@@ -18,65 +30,53 @@ import (
|
||||
"github.com/gin-contrib/gzip"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"goaway/backend/api/key"
|
||||
"goaway/backend/api/ratelimit"
|
||||
"goaway/backend/api/user"
|
||||
"goaway/backend/dns/database"
|
||||
"goaway/backend/dns/lists"
|
||||
"goaway/backend/dns/server"
|
||||
"goaway/backend/dns/server/prefetch"
|
||||
"goaway/backend/logging"
|
||||
notification "goaway/backend/notifications"
|
||||
"goaway/backend/settings"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
const (
|
||||
maxRetries = 10
|
||||
retryDelay = 10 * time.Second
|
||||
)
|
||||
|
||||
type API struct {
|
||||
Authentication bool
|
||||
Config *settings.Config
|
||||
DNSPort int
|
||||
|
||||
Version string
|
||||
Commit string
|
||||
Date string
|
||||
|
||||
router *gin.Engine
|
||||
routes *gin.RouterGroup
|
||||
|
||||
DNSServer *server.DNSServer
|
||||
DBManager *database.DatabaseManager
|
||||
Blacklist *lists.Blacklist
|
||||
Whitelist *lists.Whitelist
|
||||
KeyManager *key.ApiKeyManager
|
||||
PrefetchedDomainsManager *prefetch.Manager
|
||||
Notifications *notification.Manager
|
||||
|
||||
WSQueries *websocket.Conn
|
||||
DNS *server.DNSServer
|
||||
RateLimiter *ratelimit.RateLimiter
|
||||
DBConn *gorm.DB
|
||||
WSCommunication *websocket.Conn
|
||||
WSQueries *websocket.Conn
|
||||
router *gin.Engine
|
||||
routes *gin.RouterGroup
|
||||
Config *settings.Config
|
||||
DNSServer *server.DNSServer
|
||||
Version string
|
||||
Date string
|
||||
Commit string
|
||||
DNSPort int
|
||||
Authentication bool
|
||||
|
||||
RateLimiter *ratelimit.RateLimiter
|
||||
RequestService *request.Service
|
||||
UserService *user.Service
|
||||
KeyService *key.Service
|
||||
PrefetchService *prefetch.Service
|
||||
ResolutionService *resolution.Service
|
||||
NotificationService *notification.Service
|
||||
BlacklistService *blacklist.Service
|
||||
WhitelistService *whitelist.Service
|
||||
}
|
||||
|
||||
func (api *API) Start(content embed.FS, errorChannel chan struct{}) {
|
||||
api.initializeRouter()
|
||||
api.configureCORS()
|
||||
api.KeyManager = key.NewApiKeyManager(api.DBManager)
|
||||
api.RateLimiter = ratelimit.NewRateLimiter(
|
||||
api.Config.API.RateLimiterConfig.Enabled,
|
||||
api.Config.API.RateLimiterConfig.MaxTries,
|
||||
api.Config.API.RateLimiterConfig.Window,
|
||||
)
|
||||
api.setupRoutes()
|
||||
api.RateLimiter = ratelimit.NewRateLimiter(
|
||||
api.Config.API.RateLimit.Enabled,
|
||||
api.Config.API.RateLimit.MaxTries,
|
||||
api.Config.API.RateLimit.Window,
|
||||
)
|
||||
|
||||
if api.Config.Dashboard {
|
||||
api.ServeEmbeddedContent(content)
|
||||
if api.Config.Misc.Dashboard {
|
||||
api.serveEmbeddedContent(content)
|
||||
}
|
||||
|
||||
api.startServer(errorChannel)
|
||||
@@ -104,7 +104,7 @@ func (api *API) configureCORS() {
|
||||
}
|
||||
)
|
||||
|
||||
if api.Config.Dashboard {
|
||||
if api.Config.Misc.Dashboard {
|
||||
corsConfig.AllowOrigins = append(corsConfig.AllowOrigins, "*")
|
||||
} else {
|
||||
log.Warning("Dashboard UI is disabled")
|
||||
@@ -134,23 +134,19 @@ func (api *API) setupRoutes() {
|
||||
|
||||
func (api *API) setupAuthAndMiddleware() {
|
||||
if api.Authentication {
|
||||
api.SetupAuth()
|
||||
api.setupAuth()
|
||||
api.routes.Use(api.authMiddleware())
|
||||
} else {
|
||||
log.Warning("Authentication is disabled.")
|
||||
log.Warning("Dashboard authentication is disabled.")
|
||||
}
|
||||
}
|
||||
|
||||
func (api *API) SetupAuth() {
|
||||
newUser := &user.User{Username: "admin"}
|
||||
if newUser.Exists(api.DBManager.Conn) {
|
||||
func (api *API) setupAuth() {
|
||||
if api.UserService.Exists("admin") {
|
||||
return
|
||||
}
|
||||
|
||||
password := api.getOrGeneratePassword()
|
||||
newUser.Password = password
|
||||
|
||||
if err := newUser.Create(api.DBManager.Conn); err != nil {
|
||||
if err := api.UserService.CreateUser("admin", api.getOrGeneratePassword()); err != nil {
|
||||
log.Error("Unable to create new user: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -204,7 +200,7 @@ func (api *API) startServer(errorChannel chan struct{}) {
|
||||
}
|
||||
}
|
||||
|
||||
func (api *API) ServeEmbeddedContent(content embed.FS) {
|
||||
func (api *API) serveEmbeddedContent(content embed.FS) {
|
||||
ipAddress, err := GetServerIP()
|
||||
if err != nil {
|
||||
log.Error("Error getting IP address: %v", err)
|
||||
@@ -291,6 +287,7 @@ func injectServerConfig(htmlContent, serverIP string, port int) string {
|
||||
)
|
||||
}
|
||||
|
||||
// GetServerIP retrieves the first non-loopback IPv4 address of the server.
|
||||
func GetServerIP() (string, error) {
|
||||
addrs, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
|
||||
@@ -11,7 +11,7 @@ func (api *API) registerAuditRoutes() {
|
||||
}
|
||||
|
||||
func (api *API) getAudits(c *gin.Context) {
|
||||
audits, err := api.DNSServer.Audits.ReadAudits()
|
||||
audits, err := api.DNSServer.AuditService.ReadAudits()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
|
||||
@@ -3,18 +3,14 @@ package api
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"goaway/backend/alert"
|
||||
"goaway/backend/api/user"
|
||||
"goaway/backend/audit"
|
||||
"goaway/backend/dns/database"
|
||||
"goaway/backend/user"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func (api *API) registerAuthRoutes() {
|
||||
@@ -27,11 +23,6 @@ func (api *API) registerAuthRoutes() {
|
||||
api.routes.GET("/deleteApiKey", api.deleteAPIKey)
|
||||
}
|
||||
|
||||
func (api *API) validateCredentials(username, password string) bool {
|
||||
existingUser := &user.User{Username: username, Password: password}
|
||||
return existingUser.Authenticate(api.DBManager.Conn)
|
||||
}
|
||||
|
||||
func (api *API) handleLogin(c *gin.Context) {
|
||||
allowed, timeUntilReset := api.RateLimiter.CheckLimit(c.ClientIP())
|
||||
if !allowed {
|
||||
@@ -42,21 +33,21 @@ func (api *API) handleLogin(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var creds user.Credentials
|
||||
if err := c.BindJSON(&creds); err != nil {
|
||||
var loginUser user.User
|
||||
if err := c.BindJSON(&loginUser); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := creds.Validate(); err != nil {
|
||||
if err := api.UserService.ValidateCredentials(loginUser); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid input"})
|
||||
return
|
||||
}
|
||||
|
||||
if api.authenticateUser(creds.Username, creds.Password) {
|
||||
token, err := generateToken(creds.Username)
|
||||
if api.UserService.Authenticate(loginUser.Username, loginUser.Password) {
|
||||
token, err := generateToken(loginUser.Username)
|
||||
if err != nil {
|
||||
log.Info("Token generation failed for user %s: %v", creds.Username, err)
|
||||
log.Info("Token generation failed for user %s: %v", loginUser.Username, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "Authentication service temporarily unavailable",
|
||||
})
|
||||
@@ -75,27 +66,6 @@ func (api *API) handleLogin(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (api *API) authenticateUser(username, password string) bool {
|
||||
var user database.User
|
||||
|
||||
if err := api.DBManager.Conn.Where("username = ?", username).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Info("Authentication attempt for non-existent or invalid credentials")
|
||||
} else {
|
||||
log.Warning("Database error during authentication: %v", err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
|
||||
log.Info("Password comparison failed for user: %s", username)
|
||||
return false
|
||||
}
|
||||
|
||||
log.Info("Successful authentication for user: %s", username)
|
||||
return true
|
||||
}
|
||||
|
||||
func (api *API) getAuthentication(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"enabled": api.Authentication})
|
||||
}
|
||||
@@ -112,24 +82,23 @@ func (api *API) updatePassword(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !api.validateCredentials("admin", newCredentials.CurrentPassword) {
|
||||
if !api.UserService.Authenticate("admin", newCredentials.CurrentPassword) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Current password is not valid"})
|
||||
return
|
||||
}
|
||||
|
||||
existingUser := user.User{Username: "admin", Password: newCredentials.NewPassword}
|
||||
if err := existingUser.UpdatePassword(api.DBManager.Conn); err != nil {
|
||||
if err := api.UserService.UpdatePassword("admin", newCredentials.NewPassword); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Unable to update password"})
|
||||
return
|
||||
}
|
||||
|
||||
logMsg := fmt.Sprintf("Password changed for user '%s'", existingUser.Username)
|
||||
api.DNSServer.Audits.CreateAudit(&audit.Entry{
|
||||
logMsg := "Password changed for user 'admin'"
|
||||
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
|
||||
Topic: audit.TopicUser,
|
||||
Message: logMsg,
|
||||
})
|
||||
go func() {
|
||||
_ = api.DNSServer.Alerts.SendToAll(context.Background(), alert.Message{
|
||||
_ = api.DNSServer.AlertService.SendToAll(context.Background(), alert.Message{
|
||||
Title: "System",
|
||||
Content: logMsg,
|
||||
Severity: SeverityWarning,
|
||||
@@ -141,7 +110,7 @@ func (api *API) updatePassword(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (api *API) createAPIKey(c *gin.Context) {
|
||||
type NewApiKeyName struct {
|
||||
type NewAPIKeyName struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
@@ -152,21 +121,21 @@ func (api *API) createAPIKey(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var request NewApiKeyName
|
||||
var request NewAPIKeyName
|
||||
if err := json.Unmarshal(body, &request); err != nil {
|
||||
log.Error("Failed to parse JSON: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid JSON format"})
|
||||
return
|
||||
}
|
||||
|
||||
apiKey, err := api.KeyManager.CreateKey(request.Name)
|
||||
apiKey, err := api.KeyService.CreateKey(request.Name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = api.DNSServer.Alerts.SendToAll(context.Background(), alert.Message{
|
||||
_ = api.DNSServer.AlertService.SendToAll(context.Background(), alert.Message{
|
||||
Title: "System",
|
||||
Content: fmt.Sprintf("New API key created with the name '%s'", request.Name),
|
||||
Severity: SeverityWarning,
|
||||
@@ -177,7 +146,7 @@ func (api *API) createAPIKey(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (api *API) getAPIKeys(c *gin.Context) {
|
||||
apiKeys, err := api.KeyManager.GetAllKeys()
|
||||
apiKeys, err := api.KeyService.GetAllKeys()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -188,7 +157,7 @@ func (api *API) getAPIKeys(c *gin.Context) {
|
||||
func (api *API) deleteAPIKey(c *gin.Context) {
|
||||
keyName := c.Query("name")
|
||||
|
||||
err := api.KeyManager.DeleteKey(keyName)
|
||||
err := api.KeyService.DeleteKey(keyName)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"goaway/backend/audit"
|
||||
"goaway/backend/dns/database"
|
||||
"goaway/backend/dns/server/prefetch"
|
||||
"goaway/backend/database"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -44,13 +44,13 @@ func (api *API) createPrefetchedDomain(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err = api.PrefetchedDomainsManager.AddPrefetchedDomain(prefetchedDomain.Domain, prefetchedDomain.Refresh, prefetchedDomain.QType)
|
||||
err = api.PrefetchService.AddPrefetchedDomain(prefetchedDomain.Domain, prefetchedDomain.Refresh, prefetchedDomain.QType)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
api.DNSServer.Audits.CreateAudit(&audit.Entry{
|
||||
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
|
||||
Topic: audit.TopicPrefetch,
|
||||
Message: fmt.Sprintf("Added new prefetch '%s'", prefetchedDomain.Domain),
|
||||
})
|
||||
@@ -58,8 +58,8 @@ func (api *API) createPrefetchedDomain(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (api *API) fetchPrefetchedDomains(c *gin.Context) {
|
||||
prefetchedDomains := make([]prefetch.PrefetchedDomain, 0)
|
||||
for _, b := range api.PrefetchedDomainsManager.Domains {
|
||||
prefetchedDomains := make([]database.Prefetch, 0)
|
||||
for _, b := range api.PrefetchService.Domains {
|
||||
prefetchedDomains = append(prefetchedDomains, b)
|
||||
}
|
||||
c.JSON(http.StatusOK, prefetchedDomains)
|
||||
@@ -72,9 +72,9 @@ func (api *API) removeDomainFromCustom(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Empty domain name"})
|
||||
}
|
||||
|
||||
err := api.Blacklist.RemoveCustomDomain(domain)
|
||||
err := api.BlacklistService.RemoveCustomDomain(context.Background(), domain)
|
||||
if err != nil {
|
||||
log.Debug("Error occured while removing domain from custom list: %v", err)
|
||||
log.Debug("Error occurred while removing domain from custom list: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to update custom blocklist."})
|
||||
return
|
||||
}
|
||||
@@ -98,7 +98,7 @@ func (api *API) getBlacklistedDomains(c *gin.Context) {
|
||||
pageSizeInt = 10
|
||||
}
|
||||
|
||||
domains, total, err := api.Blacklist.LoadPaginatedBlacklist(pageInt, pageSizeInt, search)
|
||||
domains, total, err := api.BlacklistService.LoadPaginatedBlacklist(context.Background(), pageInt, pageSizeInt, search)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -113,8 +113,8 @@ func (api *API) getBlacklistedDomains(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (api *API) getTopBlockedDomains(c *gin.Context) {
|
||||
_, blocked, _ := api.Blacklist.GetAllowedAndBlocked()
|
||||
topBlockedDomains, err := database.GetTopBlockedDomains(api.DBManager.Conn, blocked)
|
||||
_, blocked, _ := api.BlacklistService.GetAllowedAndBlocked(context.Background())
|
||||
topBlockedDomains, err := api.RequestService.GetTopBlockedDomains(blocked)
|
||||
if err != nil {
|
||||
log.Error("%v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -131,7 +131,7 @@ func (api *API) getDomainsForList(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
domains, err := api.Blacklist.GetDomainsForList(list)
|
||||
domains, _, err := api.BlacklistService.FetchDBHostsList(context.Background(), list)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -143,19 +143,19 @@ func (api *API) getDomainsForList(c *gin.Context) {
|
||||
func (api *API) deletePrefetchedDomain(c *gin.Context) {
|
||||
domainPrefetchToDelete := c.Query("domain")
|
||||
|
||||
domain := api.PrefetchedDomainsManager.Domains[domainPrefetchToDelete]
|
||||
if (domain == prefetch.PrefetchedDomain{}) {
|
||||
domain := api.PrefetchService.Domains[domainPrefetchToDelete]
|
||||
if (domain == database.Prefetch{}) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not exist", domainPrefetchToDelete)})
|
||||
return
|
||||
}
|
||||
|
||||
err := api.PrefetchedDomainsManager.RemovePrefetchedDomain(domainPrefetchToDelete)
|
||||
err := api.PrefetchService.RemovePrefetchedDomain(domainPrefetchToDelete)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
api.DNSServer.Audits.CreateAudit(&audit.Entry{
|
||||
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
|
||||
Topic: audit.TopicPrefetch,
|
||||
Message: fmt.Sprintf("Removed prefetched domain '%s'", domainPrefetchToDelete),
|
||||
})
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"goaway/backend/dns/database"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -14,7 +13,7 @@ func (api *API) registerClientRoutes() {
|
||||
}
|
||||
|
||||
func (api *API) getClients(c *gin.Context) {
|
||||
uniqueClients, err := database.FetchAllClients(api.DBManager.Conn)
|
||||
uniqueClients, err := api.RequestService.FetchAllClients()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -37,7 +36,7 @@ func (api *API) getClients(c *gin.Context) {
|
||||
func (api *API) getClientDetails(c *gin.Context) {
|
||||
clientIP := c.DefaultQuery("clientIP", "")
|
||||
|
||||
clientRequestDetails, mostQueriedDomain, domainQueryCounts, err := database.GetClientDetailsWithDomains(api.DBManager.Conn, clientIP)
|
||||
clientRequestDetails, mostQueriedDomain, domainQueryCounts, err := api.RequestService.GetClientDetailsWithDomains(clientIP)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -57,7 +56,7 @@ func (api *API) getClientDetails(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (api *API) getTopClients(c *gin.Context) {
|
||||
topClients, err := database.GetTopClients(api.DBManager.Conn)
|
||||
topClients, err := api.RequestService.GetTopClients()
|
||||
if err != nil {
|
||||
log.Error("%v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"goaway/backend/api/models"
|
||||
"goaway/backend/audit"
|
||||
"goaway/backend/dns/database"
|
||||
"goaway/backend/database"
|
||||
"goaway/backend/dns/server"
|
||||
model "goaway/backend/dns/server/models"
|
||||
"goaway/backend/settings"
|
||||
@@ -18,7 +19,7 @@ import (
|
||||
)
|
||||
|
||||
func (api *API) registerDNSRoutes() {
|
||||
api.setupWSLiveQueries(api.PrefetchedDomainsManager.DNS)
|
||||
api.setupWSLiveQueries(api.DNS)
|
||||
|
||||
api.routes.POST("/pause", api.pauseBlocking)
|
||||
api.routes.GET("/pause", api.getBlocking)
|
||||
@@ -44,52 +45,63 @@ func (api *API) pauseBlocking(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
api.Config.DNS.Status = settings.Status{
|
||||
Paused: true,
|
||||
PausedAt: time.Now(),
|
||||
PauseTime: blockTime.Time,
|
||||
now := time.Now()
|
||||
if blockTime.Time <= 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "Time must be greater than 0",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
duration := time.Duration(blockTime.Time) * time.Second
|
||||
pauseTime := now.Add(duration)
|
||||
|
||||
api.Config.DNS.Status.Paused = true
|
||||
api.Config.DNS.Status.PausedAt = now
|
||||
api.Config.DNS.Status.PauseTime = pauseTime
|
||||
|
||||
log.Info("DNS blocking paused for %d seconds", blockTime.Time)
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
func (api *API) getBlocking(c *gin.Context) {
|
||||
if api.Config.DNS.Status.Paused {
|
||||
elapsed := time.Since(api.Config.DNS.Status.PausedAt).Seconds()
|
||||
remainingTime := api.Config.DNS.Status.PauseTime - int(elapsed)
|
||||
now := time.Now()
|
||||
remainingTime := api.Config.DNS.Status.PauseTime.Sub(now)
|
||||
|
||||
if remainingTime <= 0 {
|
||||
api.Config.DNS.Status.Paused = false
|
||||
c.JSON(http.StatusOK, gin.H{"paused": false})
|
||||
return
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{"paused": true, "timeLeft": remainingTime})
|
||||
secondsLeft := int(remainingTime.Seconds())
|
||||
c.JSON(http.StatusOK, gin.H{"paused": true, "timeLeft": secondsLeft})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !api.Config.DNS.Status.Paused {
|
||||
c.JSON(http.StatusOK, gin.H{"paused": false})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"paused": false})
|
||||
}
|
||||
|
||||
func (api *API) getQueries(c *gin.Context) {
|
||||
query := parseQueryParams(c)
|
||||
|
||||
type result struct {
|
||||
err error
|
||||
queries []model.RequestLogEntry
|
||||
total int
|
||||
err error
|
||||
}
|
||||
|
||||
queryCh := make(chan result, 1)
|
||||
countCh := make(chan result, 1)
|
||||
|
||||
go func() {
|
||||
queries, err := database.FetchQueries(api.DBManager.Conn, query)
|
||||
queries, err := api.RequestService.FetchQueries(query)
|
||||
queryCh <- result{queries: queries, err: err}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
total, err := database.CountQueries(api.DBManager.Conn, query.Search)
|
||||
total, err := api.RequestService.CountQueries(query.Search)
|
||||
countCh <- result{total: total, err: err}
|
||||
}()
|
||||
|
||||
@@ -169,7 +181,7 @@ func (api *API) getQueryTimestamps(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
timestamps, err := database.GetRequestSummaryByInterval(interval, api.DBManager.Conn)
|
||||
timestamps, err := api.RequestService.GetRequestSummaryByInterval(interval)
|
||||
if err != nil {
|
||||
log.Error("Failed to get request summary: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"})
|
||||
@@ -187,7 +199,7 @@ func (api *API) getResponseSizeTimestamps(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
timestamps, err := database.GetResponseSizeSummaryByInterval(interval, api.DBManager.Conn)
|
||||
timestamps, err := api.RequestService.GetResponseSizeSummaryByInterval(interval)
|
||||
if err != nil {
|
||||
log.Error("Error fetching response size timestamps: %v", err.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -198,7 +210,7 @@ func (api *API) getResponseSizeTimestamps(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (api *API) getQueryTypes(c *gin.Context) {
|
||||
queries, err := database.GetUniqueQueryTypes(api.DBManager.Conn)
|
||||
queries, err := api.RequestService.GetUniqueQueryTypes()
|
||||
if err != nil {
|
||||
log.Error("%v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -209,19 +221,19 @@ func (api *API) getQueryTypes(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (api *API) clearQueries(c *gin.Context) {
|
||||
if err := api.DBManager.Conn.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&database.RequestLog{}).Error; err != nil {
|
||||
if err := api.DBConn.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&database.RequestLog{}).Error; err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Could not clear query logs", "reason": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := api.DBManager.Conn.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&database.RequestLogIP{}).Error; err != nil {
|
||||
if err := api.DBConn.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&database.RequestLogIP{}).Error; err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Could not clear IP logs", "reason": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
api.Blacklist.Vacuum()
|
||||
api.BlacklistService.Vacuum(context.Background())
|
||||
|
||||
api.DNSServer.Audits.CreateAudit(&audit.Entry{
|
||||
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
|
||||
Topic: audit.TopicLogs,
|
||||
Message: "Logs were cleared",
|
||||
})
|
||||
|
||||
@@ -1,177 +0,0 @@
|
||||
package key
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"goaway/backend/dns/database"
|
||||
"goaway/backend/logging"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ApiKey struct {
|
||||
Name string `json:"name"`
|
||||
Key string `json:"key"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
type ApiKeyManager struct {
|
||||
dbManager *database.DatabaseManager
|
||||
keyCache map[string]ApiKey
|
||||
cacheMu sync.RWMutex
|
||||
cacheTime time.Time
|
||||
cacheTTL time.Duration
|
||||
}
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
func NewApiKeyManager(dbManager *database.DatabaseManager) *ApiKeyManager {
|
||||
return &ApiKeyManager{
|
||||
dbManager: dbManager,
|
||||
keyCache: make(map[string]ApiKey),
|
||||
cacheTTL: 1 * time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ApiKeyManager) refreshCache() error {
|
||||
m.cacheMu.RLock()
|
||||
if time.Since(m.cacheTime) < m.cacheTTL && len(m.keyCache) > 0 {
|
||||
m.cacheMu.RUnlock()
|
||||
return nil
|
||||
}
|
||||
m.cacheMu.RUnlock()
|
||||
|
||||
m.cacheMu.Lock()
|
||||
defer m.cacheMu.Unlock()
|
||||
|
||||
if time.Since(m.cacheTime) < m.cacheTTL && len(m.keyCache) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var apiKeys []database.APIKey
|
||||
result := m.dbManager.Conn.Find(&apiKeys)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
newCache := make(map[string]ApiKey)
|
||||
for _, apiKey := range apiKeys {
|
||||
newCache[apiKey.Key] = ApiKey{
|
||||
Name: apiKey.Name,
|
||||
Key: apiKey.Key,
|
||||
CreatedAt: apiKey.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
m.keyCache = newCache
|
||||
m.cacheTime = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *ApiKeyManager) VerifyKey(apiKey string) bool {
|
||||
if err := m.refreshCache(); err != nil {
|
||||
log.Warning("Failed to refresh API key cache: %v", err)
|
||||
|
||||
var count int64
|
||||
result := m.dbManager.Conn.Model(&database.APIKey{}).Where("key = ?", apiKey).Count(&count)
|
||||
if result.Error != nil {
|
||||
log.Warning("Failed to verify API key in database: %v", result.Error)
|
||||
return false
|
||||
}
|
||||
return count > 0
|
||||
}
|
||||
|
||||
m.cacheMu.RLock()
|
||||
defer m.cacheMu.RUnlock()
|
||||
for _, value := range m.keyCache {
|
||||
if value.Key == apiKey {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func generateKey() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
func (m *ApiKeyManager) CreateKey(name string) (string, error) {
|
||||
apiKey, err := generateKey()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
newAPIKey := database.APIKey{
|
||||
Name: name,
|
||||
Key: apiKey,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
result := m.dbManager.Conn.Create(&newAPIKey)
|
||||
if result.Error != nil {
|
||||
return "", fmt.Errorf("key with name '%s' already exists", name)
|
||||
}
|
||||
|
||||
m.cacheMu.Lock()
|
||||
m.keyCache[apiKey] = ApiKey{
|
||||
Name: name,
|
||||
Key: apiKey,
|
||||
CreatedAt: newAPIKey.CreatedAt,
|
||||
}
|
||||
m.cacheMu.Unlock()
|
||||
|
||||
log.Info("Created new API key with name: %s", name)
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
func (m *ApiKeyManager) DeleteKey(keyName string) error {
|
||||
result := m.dbManager.Conn.Where("name = ?", keyName).Delete(&database.APIKey{})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
m.cacheMu.Lock()
|
||||
for key, value := range m.keyCache {
|
||||
if value.Name == keyName {
|
||||
delete(m.keyCache, key)
|
||||
break
|
||||
}
|
||||
}
|
||||
m.cacheMu.Unlock()
|
||||
|
||||
if err := m.refreshCache(); err != nil {
|
||||
log.Warning("%v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *ApiKeyManager) GetAllKeys() ([]ApiKey, error) {
|
||||
if err := m.refreshCache(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.cacheMu.RLock()
|
||||
defer m.cacheMu.RUnlock()
|
||||
|
||||
keys := make([]ApiKey, 0, len(m.keyCache))
|
||||
for _, apiKey := range m.keyCache {
|
||||
keyCopy := apiKey
|
||||
keyCopy.Key = "redacted"
|
||||
keys = append(keys, keyCopy)
|
||||
}
|
||||
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
return keys[j].CreatedAt.Before(keys[i].CreatedAt)
|
||||
})
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
9
backend/api/key/models.go
Normal file
9
backend/api/key/models.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package key
|
||||
|
||||
import "time"
|
||||
|
||||
type APIKey struct {
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
Name string `json:"name"`
|
||||
Key string `json:"key"`
|
||||
}
|
||||
55
backend/api/key/repository.go
Normal file
55
backend/api/key/repository.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package key
|
||||
|
||||
import (
|
||||
"goaway/backend/database"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Repository handles all database operations for API keys
|
||||
type Repository interface {
|
||||
Create(apiKey *database.APIKey) error
|
||||
FindByKey(key string) (*database.APIKey, error)
|
||||
FindAll() ([]database.APIKey, error)
|
||||
DeleteByName(name string) error
|
||||
CountByKey(key string) (int64, error)
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewRepository(db *gorm.DB) Repository {
|
||||
return &repository{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *repository) Create(apiKey *database.APIKey) error {
|
||||
return r.db.Create(apiKey).Error
|
||||
}
|
||||
|
||||
func (r *repository) FindByKey(key string) (*database.APIKey, error) {
|
||||
var apiKey database.APIKey
|
||||
err := r.db.Where("key = ?", key).First(&apiKey).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &apiKey, nil
|
||||
}
|
||||
|
||||
func (r *repository) FindAll() ([]database.APIKey, error) {
|
||||
var apiKeys []database.APIKey
|
||||
err := r.db.Find(&apiKeys).Error
|
||||
return apiKeys, err
|
||||
}
|
||||
|
||||
func (r *repository) DeleteByName(name string) error {
|
||||
return r.db.Where("name = ?", name).Delete(&database.APIKey{}).Error
|
||||
}
|
||||
|
||||
func (r *repository) CountByKey(key string) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&database.APIKey{}).Where("key = ?", key).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
176
backend/api/key/service.go
Normal file
176
backend/api/key/service.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package key
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"goaway/backend/database"
|
||||
"goaway/backend/logging"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Service handles business logic for API keys
|
||||
type Service struct {
|
||||
repository Repository
|
||||
cacheTime time.Time
|
||||
cacheTTL time.Duration
|
||||
|
||||
cacheMu sync.RWMutex
|
||||
keyCache map[string]APIKey
|
||||
}
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
func NewService(repo Repository) *Service {
|
||||
return &Service{
|
||||
repository: repo,
|
||||
keyCache: make(map[string]APIKey),
|
||||
cacheTTL: 1 * time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) VerifyKey(apiKey string) bool {
|
||||
if err := s.refreshCache(); err != nil {
|
||||
log.Warning("Failed to refresh API key cache: %v", err)
|
||||
|
||||
count, err := s.repository.CountByKey(apiKey)
|
||||
if err != nil {
|
||||
log.Warning("Failed to verify API key in database: %v", err)
|
||||
return false
|
||||
}
|
||||
return count > 0
|
||||
}
|
||||
|
||||
s.cacheMu.RLock()
|
||||
defer s.cacheMu.RUnlock()
|
||||
|
||||
for _, value := range s.keyCache {
|
||||
if value.Key == apiKey {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// CreateKey generates and stores a new API key
|
||||
func (s *Service) CreateKey(name string) (string, error) {
|
||||
apiKey, err := generateKey()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
newAPIKey := database.APIKey{
|
||||
Name: name,
|
||||
Key: apiKey,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := s.repository.Create(&newAPIKey); err != nil {
|
||||
return "", fmt.Errorf("key with name '%s' already exists", name)
|
||||
}
|
||||
|
||||
s.cacheMu.Lock()
|
||||
s.keyCache[apiKey] = APIKey{
|
||||
Name: name,
|
||||
Key: apiKey,
|
||||
CreatedAt: newAPIKey.CreatedAt,
|
||||
}
|
||||
s.cacheMu.Unlock()
|
||||
|
||||
log.Info("Created new API key with name: %s", name)
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// DeleteKey removes an API key by name
|
||||
func (s *Service) DeleteKey(keyName string) error {
|
||||
if err := s.repository.DeleteByName(keyName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.cacheMu.Lock()
|
||||
for key, value := range s.keyCache {
|
||||
if value.Name == keyName {
|
||||
delete(s.keyCache, key)
|
||||
break
|
||||
}
|
||||
}
|
||||
s.cacheMu.Unlock()
|
||||
|
||||
if err := s.refreshCache(); err != nil {
|
||||
log.Warning("%v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllKeys returns all API keys with redacted key values
|
||||
func (s *Service) GetAllKeys() ([]APIKey, error) {
|
||||
if err := s.refreshCache(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.cacheMu.RLock()
|
||||
defer s.cacheMu.RUnlock()
|
||||
|
||||
keys := make([]APIKey, 0, len(s.keyCache))
|
||||
for _, apiKey := range s.keyCache {
|
||||
keyCopy := apiKey
|
||||
keyCopy.Key = "redacted"
|
||||
keys = append(keys, keyCopy)
|
||||
}
|
||||
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
return keys[j].CreatedAt.Before(keys[i].CreatedAt)
|
||||
})
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// refreshCache updates the in-memory cache from the database
|
||||
func (s *Service) refreshCache() error {
|
||||
s.cacheMu.RLock()
|
||||
if time.Since(s.cacheTime) < s.cacheTTL && len(s.keyCache) > 0 {
|
||||
s.cacheMu.RUnlock()
|
||||
return nil
|
||||
}
|
||||
s.cacheMu.RUnlock()
|
||||
|
||||
s.cacheMu.Lock()
|
||||
defer s.cacheMu.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if time.Since(s.cacheTime) < s.cacheTTL && len(s.keyCache) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
apiKeys, err := s.repository.FindAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newCache := make(map[string]APIKey)
|
||||
for _, apiKey := range apiKeys {
|
||||
newCache[apiKey.Key] = APIKey{
|
||||
Name: apiKey.Name,
|
||||
Key: apiKey.Key,
|
||||
CreatedAt: apiKey.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
s.keyCache = newCache
|
||||
s.cacheTime = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateKey creates a random hex-encoded API key
|
||||
func generateKey() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
@@ -48,13 +48,13 @@ func (api *API) updateCustom(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err = api.Blacklist.AddCustomDomains(request.Domains)
|
||||
err = api.BlacklistService.AddCustomDomains(context.Background(), request.Domains)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to update custom blocklist."})
|
||||
return
|
||||
}
|
||||
|
||||
api.DNSServer.Audits.CreateAudit(&audit.Entry{
|
||||
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
|
||||
Topic: audit.TopicList,
|
||||
Message: fmt.Sprintf("Added %d domains to custom blacklist", len(request.Domains)),
|
||||
})
|
||||
@@ -62,7 +62,7 @@ func (api *API) updateCustom(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (api *API) getLists(c *gin.Context) {
|
||||
lists, err := api.Blacklist.GetAllListStatistics()
|
||||
lists, err := api.BlacklistService.GetAllListStatistics(context.Background())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -85,46 +85,46 @@ func (api *API) addList(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err = api.ValidateURLAndName(newList.URL, newList.Name, c)
|
||||
err = api.validateURLAndName(newList.URL, newList.Name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err = api.Blacklist.FetchAndLoadHosts(newList.URL, newList.Name); err != nil {
|
||||
if err = api.BlacklistService.FetchAndLoadHosts(context.Background(), newList.URL, newList.Name); err != nil {
|
||||
log.Error("Failed to fetch and load hosts: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := api.Blacklist.PopulateBlocklistCache(); err != nil {
|
||||
if err := api.BlacklistService.PopulateCache(context.Background()); err != nil {
|
||||
log.Error("Failed to populate blocklist cache: %v", err)
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := api.Blacklist.AddSource(newList.Name, newList.URL); err != nil {
|
||||
if err := api.BlacklistService.AddSource(context.Background(), newList.Name, newList.URL); err != nil {
|
||||
log.Error("Failed to add source: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if !newList.Active {
|
||||
if err := api.Blacklist.ToggleBlocklistStatus(newList.Name); err != nil {
|
||||
if err := api.BlacklistService.ToggleBlocklistStatus(context.Background(), newList.Name); err != nil {
|
||||
log.Error("Failed to toggle blocklist status: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to toggle status for " + newList.Name})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
_, addedList, err := api.Blacklist.GetListStatistics(newList.Name)
|
||||
_, addedList, err := api.BlacklistService.GetListStatistics(context.Background(), newList.Name)
|
||||
if err != nil {
|
||||
log.Error("Failed to get list statistics: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get list statistics"})
|
||||
return
|
||||
}
|
||||
|
||||
api.DNSServer.Audits.CreateAudit(&audit.Entry{
|
||||
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
|
||||
Topic: audit.TopicList,
|
||||
Message: fmt.Sprintf("New blacklist with name '%s' was added", addedList.Name),
|
||||
})
|
||||
@@ -150,31 +150,31 @@ func (api *API) addLists(c *gin.Context) {
|
||||
var addedList []NewList
|
||||
var ignoredList []NewList
|
||||
for _, list := range payload.Lists {
|
||||
if api.Blacklist.URLExists(list.URL) {
|
||||
if api.BlacklistService.URLExists(list.URL) {
|
||||
ignoredList = append(ignoredList, list)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := api.Blacklist.FetchAndLoadHosts(list.URL, list.Name); err != nil {
|
||||
if err := api.BlacklistService.FetchAndLoadHosts(context.Background(), list.URL, list.Name); err != nil {
|
||||
log.Error("Failed to fetch and load hosts: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := api.Blacklist.PopulateBlocklistCache(); err != nil {
|
||||
if err := api.BlacklistService.PopulateCache(context.Background()); err != nil {
|
||||
log.Error("Failed to populate blocklist cache: %v", err)
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := api.Blacklist.AddSource(list.Name, list.URL); err != nil {
|
||||
if err := api.BlacklistService.AddSource(context.Background(), list.Name, list.URL); err != nil {
|
||||
log.Error("Failed to add source: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if !list.Active {
|
||||
if err := api.Blacklist.ToggleBlocklistStatus(list.Name); err != nil {
|
||||
if err := api.BlacklistService.ToggleBlocklistStatus(context.Background(), list.Name); err != nil {
|
||||
log.Error("Failed to toggle blocklist status: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to toggle status for " + list.Name})
|
||||
return
|
||||
@@ -185,7 +185,7 @@ func (api *API) addLists(c *gin.Context) {
|
||||
}
|
||||
|
||||
if len(addedList) > 0 {
|
||||
api.DNSServer.Audits.CreateAudit(&audit.Entry{
|
||||
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
|
||||
Topic: audit.TopicList,
|
||||
Message: fmt.Sprintf("Added %d new blacklists in bulk", len(addedList)),
|
||||
})
|
||||
@@ -197,19 +197,19 @@ func (api *API) addLists(c *gin.Context) {
|
||||
func (api *API) updateListName(c *gin.Context) {
|
||||
oldName := c.Query("old")
|
||||
newName := c.Query("new")
|
||||
url := c.Query("url")
|
||||
listURL := c.Query("url")
|
||||
|
||||
if oldName == "" || newName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "New and old names are required"})
|
||||
return
|
||||
}
|
||||
|
||||
if !api.Blacklist.NameExists(oldName, url) {
|
||||
if !api.BlacklistService.NameExists(oldName, listURL) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "List with that name and url combination does not exist"})
|
||||
return
|
||||
}
|
||||
|
||||
err := api.Blacklist.UpdateSourceName(oldName, newName, url)
|
||||
err := api.BlacklistService.UpdateSourceName(context.Background(), oldName, newName, listURL)
|
||||
if err != nil {
|
||||
log.Warning("%s", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
@@ -221,19 +221,19 @@ func (api *API) updateListName(c *gin.Context) {
|
||||
|
||||
func (api *API) fetchUpdatedList(c *gin.Context) {
|
||||
name := c.Query("name")
|
||||
url := c.Query("url")
|
||||
listURL := c.Query("url")
|
||||
|
||||
if !api.Blacklist.NameExists(name, url) {
|
||||
if !api.BlacklistService.NameExists(name, listURL) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "List with that name and url combination does not exist"})
|
||||
return
|
||||
}
|
||||
|
||||
if name == "" || url == "" {
|
||||
if name == "" || listURL == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing 'name' or 'url' query parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
availableUpdate, err := api.Blacklist.CheckIfUpdateAvailable(url, name)
|
||||
availableUpdate, err := api.BlacklistService.CheckIfUpdateAvailable(context.Background(), listURL, name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -249,33 +249,33 @@ func (api *API) fetchUpdatedList(c *gin.Context) {
|
||||
|
||||
func (api *API) runUpdateList(c *gin.Context) {
|
||||
name := c.Query("name")
|
||||
url := c.Query("url")
|
||||
listURL := c.Query("url")
|
||||
|
||||
if !api.Blacklist.NameExists(name, url) {
|
||||
if !api.BlacklistService.NameExists(name, listURL) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "List does not exist"})
|
||||
return
|
||||
}
|
||||
|
||||
if name == "" || url == "" {
|
||||
if name == "" || listURL == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing 'name' or 'url' query parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
err := api.Blacklist.RemoveSourceAndDomains(name, url)
|
||||
err := api.BlacklistService.RemoveSourceAndDomains(context.Background(), name, listURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
err = api.Blacklist.FetchAndLoadHosts(url, name)
|
||||
err = api.BlacklistService.FetchAndLoadHosts(context.Background(), listURL, name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = api.DNSServer.Alerts.SendToAll(context.Background(), alert.Message{
|
||||
_ = api.DNSServer.AlertService.SendToAll(context.Background(), alert.Message{
|
||||
Title: "System",
|
||||
Content: fmt.Sprintf("List '%s' with url '%s' was updated! ", name, url),
|
||||
Content: fmt.Sprintf("List '%s' with url '%s' was updated! ", name, listURL),
|
||||
Severity: SeveritySuccess,
|
||||
})
|
||||
}()
|
||||
@@ -291,7 +291,7 @@ func (api *API) toggleBlocklist(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err := api.Blacklist.ToggleBlocklistStatus(blocklist)
|
||||
err := api.BlacklistService.ToggleBlocklistStatus(context.Background(), blocklist)
|
||||
if err != nil {
|
||||
log.Error("Failed to toggle blocklist status: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Failed to toggle status for %s", blocklist)})
|
||||
@@ -309,9 +309,9 @@ func (api *API) handleUpdateBlockStatus(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
action := map[string]func(string) error{
|
||||
"true": api.Blacklist.AddBlacklistedDomain,
|
||||
"false": api.Blacklist.RemoveDomain,
|
||||
action := map[string]func(context.Context, string) error{
|
||||
"true": api.BlacklistService.AddBlacklistedDomain,
|
||||
"false": api.BlacklistService.RemoveDomain,
|
||||
}[blocked]
|
||||
|
||||
if action == nil {
|
||||
@@ -319,7 +319,7 @@ func (api *API) handleUpdateBlockStatus(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := action(domain); err != nil {
|
||||
if err := action(context.Background(), domain); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -334,42 +334,42 @@ func (api *API) handleUpdateBlockStatus(c *gin.Context) {
|
||||
|
||||
func (api *API) removeList(c *gin.Context) {
|
||||
name := c.Query("name")
|
||||
url := c.Query("url")
|
||||
listURL := c.Query("url")
|
||||
|
||||
if !api.Blacklist.NameExists(name, url) {
|
||||
if !api.BlacklistService.NameExists(name, listURL) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "List does not exist"})
|
||||
return
|
||||
}
|
||||
|
||||
err := api.Blacklist.RemoveSourceAndDomains(name, url)
|
||||
err := api.BlacklistService.RemoveSourceAndDomains(context.Background(), name, listURL)
|
||||
if err != nil {
|
||||
log.Error("%v", err.Error())
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
}
|
||||
|
||||
if removed := api.Blacklist.RemoveSourceByNameAndURL(name, url); !removed {
|
||||
log.Error("Failed to remove source with name '%s' and url '%s'", name, url)
|
||||
if removed := api.BlacklistService.RemoveSourceByNameAndURL(name, listURL); !removed {
|
||||
log.Error("Failed to remove source with name '%s' and url '%s'", name, listURL)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to remove the list"})
|
||||
return
|
||||
}
|
||||
|
||||
api.DNSServer.Audits.CreateAudit(&audit.Entry{
|
||||
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
|
||||
Topic: audit.TopicList,
|
||||
Message: fmt.Sprintf("Blacklist with name '%s' was deleted", name),
|
||||
})
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
func (api *API) ValidateURLAndName(URL, name string, c *gin.Context) error {
|
||||
if name == "" || URL == "" {
|
||||
func (api *API) validateURLAndName(listURL, name string) error {
|
||||
if name == "" || listURL == "" {
|
||||
return fmt.Errorf("name and URL are required")
|
||||
}
|
||||
|
||||
if _, err := url.ParseRequestURI(URL); err != nil {
|
||||
if _, err := url.ParseRequestURI(listURL); err != nil {
|
||||
return fmt.Errorf("invalid URL format")
|
||||
}
|
||||
|
||||
if api.Blacklist.URLExists(URL) {
|
||||
if api.BlacklistService.URLExists(listURL) {
|
||||
return fmt.Errorf("list with the same URL already exists")
|
||||
}
|
||||
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
TokenDuration = 5 * time.Minute
|
||||
Secret = "kMNSRwKip7Yet4rb2z8"
|
||||
tokenDuration = 5 * time.Minute
|
||||
secret = "kMNSRwKip7Yet4rb2z8"
|
||||
)
|
||||
|
||||
func (api *API) authMiddleware() gin.HandlerFunc {
|
||||
@@ -23,7 +23,7 @@ func (api *API) authMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
|
||||
if apiKey := c.GetHeader("api-key"); apiKey != "" {
|
||||
if api.KeyManager.VerifyKey(apiKey) {
|
||||
if api.KeyService.VerifyKey(apiKey) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -62,7 +62,7 @@ func (api *API) authMiddleware() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
halfDurationSeconds := int64(TokenDuration.Seconds() / 2)
|
||||
halfDurationSeconds := int64(tokenDuration.Seconds() / 2)
|
||||
timeUntilExpiration := expiration - now
|
||||
|
||||
if timeUntilExpiration <= halfDurationSeconds {
|
||||
@@ -85,7 +85,7 @@ func parseToken(tokenString string) (jwt.MapClaims, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return []byte(Secret), nil
|
||||
return []byte(secret), nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
return nil, err
|
||||
@@ -103,11 +103,11 @@ func generateToken(username string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := jwt.MapClaims{
|
||||
"username": username,
|
||||
"exp": now.Add(TokenDuration).Unix(),
|
||||
"exp": now.Add(tokenDuration).Unix(),
|
||||
"iat": now.Unix(),
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(Secret))
|
||||
return token.SignedString([]byte(secret))
|
||||
}
|
||||
|
||||
func setAuthCookie(w http.ResponseWriter, token string) {
|
||||
@@ -118,7 +118,7 @@ func setAuthCookie(w http.ResponseWriter, token string) {
|
||||
HttpOnly: true,
|
||||
Secure: false,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Expires: time.Now().Add(TokenDuration),
|
||||
MaxAge: int(TokenDuration.Seconds()),
|
||||
Expires: time.Now().Add(tokenDuration),
|
||||
MaxAge: int(tokenDuration.Seconds()),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package models
|
||||
|
||||
type QueryParams struct {
|
||||
Page int
|
||||
PageSize int
|
||||
Offset int
|
||||
Search string
|
||||
Column string
|
||||
Direction string
|
||||
FilterClient string
|
||||
Page int
|
||||
PageSize int
|
||||
Offset int
|
||||
}
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -16,7 +14,7 @@ func (api *API) registerNotificationRoutes() {
|
||||
}
|
||||
|
||||
func (api *API) fetchNotifications(c *gin.Context) {
|
||||
notifications, err := api.Notifications.ReadNotifications()
|
||||
notifications, err := api.NotificationService.GetNotifications()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
}
|
||||
@@ -28,21 +26,14 @@ func (api *API) markNotificationAsRead(c *gin.Context) {
|
||||
NotificationIDs []int `json:"notificationIds"`
|
||||
}
|
||||
|
||||
notificationsRead, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
log.Error("Failed to read request body: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
var request NotificationsRead
|
||||
if err := json.Unmarshal(notificationsRead, &request); err != nil {
|
||||
log.Error("Failed to parse JSON: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid JSON format"})
|
||||
err := c.BindJSON(&request)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Unable to parse request"})
|
||||
return
|
||||
}
|
||||
|
||||
err = api.Notifications.MarkNotificationsAsRead(request.NotificationIDs)
|
||||
err = api.NotificationService.MarkNotificationsAsRead(request.NotificationIDs)
|
||||
if err != nil {
|
||||
log.Warning("Unable to mark notifications as read %v", err)
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("Unable to mark notifications as read %v", err.Error())})
|
||||
|
||||
@@ -15,12 +15,12 @@ type RateLimiterConfig struct {
|
||||
}
|
||||
|
||||
type RateLimiter struct {
|
||||
Config RateLimiterConfig
|
||||
mutex *sync.RWMutex
|
||||
attempts map[string][]time.Time
|
||||
Config RateLimiterConfig
|
||||
}
|
||||
|
||||
func NewRateLimiter(enabled bool, maxTries int, window int) *RateLimiter {
|
||||
func NewRateLimiter(enabled bool, maxTries, window int) *RateLimiter {
|
||||
config := RateLimiterConfig{
|
||||
Enabled: enabled,
|
||||
MaxTries: maxTries,
|
||||
|
||||
@@ -3,7 +3,6 @@ package api
|
||||
import (
|
||||
"fmt"
|
||||
"goaway/backend/audit"
|
||||
"goaway/backend/dns/database"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -31,14 +30,14 @@ func (api *API) createResolution(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err := database.CreateNewResolution(api.DBManager.Conn, newResolution.IP, newResolution.Domain)
|
||||
err := api.ResolutionService.CreateResolution(newResolution.IP, newResolution.Domain)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
api.DNSServer.RemoveCachedDomain(newResolution.Domain)
|
||||
|
||||
api.DNSServer.Audits.CreateAudit(&audit.Entry{
|
||||
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
|
||||
Topic: audit.TopicResolution,
|
||||
Message: fmt.Sprintf("Added new resolution '%s'", newResolution.Domain),
|
||||
})
|
||||
@@ -46,7 +45,7 @@ func (api *API) createResolution(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (api *API) getResolutions(c *gin.Context) {
|
||||
resolutions, err := database.FetchResolutions(api.DBManager.Conn)
|
||||
resolutions, err := api.ResolutionService.GetResolutions()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -59,7 +58,7 @@ func (api *API) deleteResolution(c *gin.Context) {
|
||||
domain := c.Query("domain")
|
||||
ip := c.Query("ip")
|
||||
|
||||
rowsAffected, err := database.DeleteResolution(api.DBManager.Conn, ip, domain)
|
||||
rowsAffected, err := api.ResolutionService.DeleteResolution(ip, domain)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": err.Error(),
|
||||
@@ -74,7 +73,7 @@ func (api *API) deleteResolution(c *gin.Context) {
|
||||
|
||||
api.DNSServer.RemoveCachedDomain(domain)
|
||||
|
||||
api.DNSServer.Audits.CreateAudit(&audit.Entry{
|
||||
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
|
||||
Topic: audit.TopicResolution,
|
||||
Message: fmt.Sprintf("Removed resolution '%s'", domain),
|
||||
})
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"goaway/backend/dns/database"
|
||||
"goaway/backend/dns/server"
|
||||
"goaway/backend/updater"
|
||||
"net/http"
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
@@ -18,7 +19,7 @@ import (
|
||||
)
|
||||
|
||||
func (api *API) registerServerRoutes() {
|
||||
api.setupWSLiveCommunication(api.PrefetchedDomainsManager.DNS)
|
||||
api.setupWSLiveCommunication(api.DNS)
|
||||
|
||||
api.router.GET("/api/server", api.handleServer)
|
||||
api.router.GET("/api/dnsMetrics", api.handleMetrics)
|
||||
@@ -47,7 +48,7 @@ func (api *API) handleServer(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"portDNS": api.Config.DNS.Port,
|
||||
"portDNS": api.Config.DNS.Ports.TCPUDP,
|
||||
"portWebsite": api.DNSPort,
|
||||
"totalMem": float64(vMem.Total) / 1024 / 1024 / 1024,
|
||||
"usedMem": float64(vMem.Used) / 1024 / 1024 / 1024,
|
||||
@@ -56,7 +57,7 @@ func (api *API) handleServer(c *gin.Context) {
|
||||
"cpuTemp": temp,
|
||||
"dbSize": dbSize,
|
||||
"version": api.Version,
|
||||
"inAppUpdate": api.Config.InAppUpdate,
|
||||
"inAppUpdate": api.Config.Misc.InAppUpdate,
|
||||
"commit": api.Commit,
|
||||
"date": api.Date,
|
||||
})
|
||||
@@ -110,7 +111,7 @@ func getDBSizeMB() (float64, error) {
|
||||
}
|
||||
|
||||
func (api *API) handleMetrics(c *gin.Context) {
|
||||
allowed, blocked, err := api.Blacklist.GetAllowedAndBlocked()
|
||||
allowed, blocked, err := api.BlacklistService.GetAllowedAndBlocked(context.Background())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -122,14 +123,14 @@ func (api *API) handleMetrics(c *gin.Context) {
|
||||
percentageBlocked = (float64(blocked) / float64(total)) * 100
|
||||
}
|
||||
|
||||
domainsLength, _ := api.Blacklist.CountDomains()
|
||||
domainsLength, _ := api.BlacklistService.CountDomains(context.Background())
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"allowed": allowed,
|
||||
"blocked": blocked,
|
||||
"total": total,
|
||||
"percentageBlocked": percentageBlocked,
|
||||
"domainBlockLen": domainsLength,
|
||||
"clients": database.GetDistinctRequestIP(api.DBManager.Conn),
|
||||
"clients": api.RequestService.GetDistinctRequestIP(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -169,16 +170,47 @@ func (api *API) setupWSLiveCommunication(dnsServer *server.DNSServer) {
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(_ *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
conn.SetPongHandler(func(string) error {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
return nil
|
||||
})
|
||||
|
||||
api.WSCommunication = conn
|
||||
|
||||
if dnsServer != nil {
|
||||
dnsServer.WSCommunication = conn
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
if dnsServer != nil {
|
||||
dnsServer.WSCommunicationLock.Lock()
|
||||
dnsServer.WSCommunication = nil
|
||||
dnsServer.WSCommunicationLock.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
_, _, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
log.Warning("Websocket closed unexpectedly: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -36,11 +36,11 @@ func (api *API) updateSettings(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
api.Config.UpdateSettings(updatedSettings)
|
||||
settingsJson, _ := json.MarshalIndent(updatedSettings, "", " ")
|
||||
log.Debug("%s", string(settingsJson))
|
||||
api.Config.Update(updatedSettings)
|
||||
settingsJSON, _ := json.MarshalIndent(updatedSettings, "", " ")
|
||||
log.Debug("%s", string(settingsJSON))
|
||||
|
||||
api.DNSServer.Audits.CreateAudit(&audit.Entry{
|
||||
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
|
||||
Topic: audit.TopicSettings,
|
||||
Message: "Settings was updated",
|
||||
})
|
||||
@@ -65,7 +65,7 @@ func (api *API) exportDatabase(c *gin.Context) {
|
||||
_ = os.Remove(tempExport)
|
||||
|
||||
// Create a new connection to a temp file and vacuum into it
|
||||
if err := api.DBManager.Conn.Exec(fmt.Sprintf("VACUUM INTO '%s';", tempExport)).Error; err != nil {
|
||||
if err := api.DBConn.Exec(fmt.Sprintf("VACUUM INTO '%s';", tempExport)).Error; err != nil {
|
||||
log.Error("Failed to write WAL to temp export: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to prepare database for export"})
|
||||
return
|
||||
@@ -116,7 +116,7 @@ func (api *API) exportDatabase(c *gin.Context) {
|
||||
return n > 0
|
||||
})
|
||||
|
||||
api.DNSServer.Audits.CreateAudit(&audit.Entry{
|
||||
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
|
||||
Topic: audit.TopicDatabase,
|
||||
Message: "Database was exported",
|
||||
})
|
||||
@@ -125,7 +125,7 @@ func (api *API) exportDatabase(c *gin.Context) {
|
||||
func validateSQLiteFile(filePath string) error {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot open file: %v", err)
|
||||
return fmt.Errorf("cannot open file: %w", err)
|
||||
}
|
||||
go func() {
|
||||
_ = file.Close()
|
||||
@@ -133,7 +133,7 @@ func validateSQLiteFile(filePath string) error {
|
||||
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot stat file: %v", err)
|
||||
return fmt.Errorf("cannot stat file: %w", err)
|
||||
}
|
||||
|
||||
if stat.Size() < 50 {
|
||||
@@ -143,7 +143,7 @@ func validateSQLiteFile(filePath string) error {
|
||||
header := make([]byte, 16)
|
||||
_, err = file.Read(header)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot read file header: %v", err)
|
||||
return fmt.Errorf("cannot read file header: %w", err)
|
||||
}
|
||||
|
||||
expectedHeader := "SQLite format 3\x00"
|
||||
@@ -184,7 +184,7 @@ func (api *API) importDatabase(c *gin.Context) {
|
||||
}
|
||||
_, err = io.Copy(tempFile, file)
|
||||
|
||||
defer func(tempfile *os.File) {
|
||||
defer func(tempFile *os.File) {
|
||||
_ = tempFile.Close()
|
||||
}(tempFile)
|
||||
|
||||
@@ -229,7 +229,7 @@ func (api *API) importDatabase(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
sqlDB, err := api.DBManager.Conn.DB()
|
||||
sqlDB, err := api.DBConn.DB()
|
||||
if err != nil {
|
||||
log.Error("Failed to get underlying sql.DB: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to close current database"})
|
||||
@@ -260,15 +260,15 @@ func (api *API) importDatabase(c *gin.Context) {
|
||||
if err != nil {
|
||||
log.Error("Failed to open imported database with GORM: %v", err)
|
||||
_ = copyFile(backupPath, currentDBPath)
|
||||
api.DBManager.Conn, _ = gorm.Open(sqlite.Open(currentDBPath), &gorm.Config{})
|
||||
api.DBConn, _ = gorm.Open(sqlite.Open(currentDBPath), &gorm.Config{})
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to open imported database, restored from backup"})
|
||||
return
|
||||
}
|
||||
|
||||
api.DBManager.Conn = newDB
|
||||
*api.DBConn = *newDB
|
||||
|
||||
log.Info("Database imported successfully from %s", header.Filename)
|
||||
api.DNSServer.Audits.CreateAudit(&audit.Entry{
|
||||
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
|
||||
Topic: audit.TopicDatabase,
|
||||
Message: "Database was imported",
|
||||
})
|
||||
|
||||
@@ -18,14 +18,14 @@ import (
|
||||
probing "github.com/prometheus-community/pro-bing"
|
||||
)
|
||||
|
||||
type PingResult struct {
|
||||
Duration time.Duration
|
||||
type pingResult struct {
|
||||
Error error
|
||||
Method string
|
||||
Duration time.Duration
|
||||
Successful bool
|
||||
}
|
||||
|
||||
func (pr PingResult) String() string {
|
||||
func (pr pingResult) String() string {
|
||||
if !pr.Successful {
|
||||
return fmt.Sprintf("Failed (%s)", pr.Method)
|
||||
}
|
||||
@@ -68,16 +68,16 @@ func (api *API) createUpstream(c *gin.Context) {
|
||||
upstream += ":53"
|
||||
}
|
||||
|
||||
if slices.Contains(api.Config.DNS.UpstreamDNS, upstream) {
|
||||
if slices.Contains(api.Config.DNS.Upstream.Fallback, upstream) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Upstream already exists"})
|
||||
return
|
||||
}
|
||||
|
||||
api.Config.DNS.UpstreamDNS = append(api.Config.DNS.UpstreamDNS, upstream)
|
||||
api.Config.DNS.Upstream.Fallback = append(api.Config.DNS.Upstream.Fallback, upstream)
|
||||
api.Config.Save()
|
||||
|
||||
log.Info("Added %s as a new upstream", upstream)
|
||||
api.DNSServer.Audits.CreateAudit(&audit.Entry{
|
||||
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
|
||||
Topic: audit.TopicUpstream,
|
||||
Message: fmt.Sprintf("Added a new upstream '%s'", request.Upstream),
|
||||
})
|
||||
@@ -87,9 +87,9 @@ func (api *API) createUpstream(c *gin.Context) {
|
||||
|
||||
func (api *API) getUpstreams(c *gin.Context) {
|
||||
var (
|
||||
upstreams = api.Config.DNS.UpstreamDNS
|
||||
upstreams = api.Config.DNS.Upstream.Fallback
|
||||
results = make([]map[string]any, len(upstreams))
|
||||
preferredUpstream = api.Config.DNS.PreferredUpstream
|
||||
preferredUpstream = api.Config.DNS.Upstream.Preferred
|
||||
wg sync.WaitGroup
|
||||
)
|
||||
wg.Add(len(upstreams))
|
||||
@@ -158,7 +158,7 @@ func resolveHostname(host string) string {
|
||||
return "No IP found"
|
||||
}
|
||||
|
||||
func measureDNSPing(upstream string) PingResult {
|
||||
func measureDNSPing(upstream string) pingResult {
|
||||
var (
|
||||
testDomains = []string{"google.com", "cloudflare.com", "quad9.net"}
|
||||
client = &dns.Client{
|
||||
@@ -193,7 +193,7 @@ func measureDNSPing(upstream string) PingResult {
|
||||
}
|
||||
|
||||
if successCount == 0 {
|
||||
return PingResult{
|
||||
return pingResult{
|
||||
Duration: 0,
|
||||
Error: lastError,
|
||||
Method: "dns",
|
||||
@@ -202,7 +202,7 @@ func measureDNSPing(upstream string) PingResult {
|
||||
}
|
||||
|
||||
avgDuration := totalDuration / time.Duration(successCount)
|
||||
return PingResult{
|
||||
return pingResult{
|
||||
Duration: avgDuration,
|
||||
Error: nil,
|
||||
Method: "dns",
|
||||
@@ -210,7 +210,7 @@ func measureDNSPing(upstream string) PingResult {
|
||||
}
|
||||
}
|
||||
|
||||
func measureICMPPing(host string) PingResult {
|
||||
func measureICMPPing(host string) pingResult {
|
||||
icmpResult := tryICMPPing(host)
|
||||
if icmpResult.Successful {
|
||||
return icmpResult
|
||||
@@ -220,10 +220,10 @@ func measureICMPPing(host string) PingResult {
|
||||
return tcpResult
|
||||
}
|
||||
|
||||
func tryICMPPing(host string) PingResult {
|
||||
func tryICMPPing(host string) pingResult {
|
||||
pinger, err := probing.NewPinger(host)
|
||||
if err != nil {
|
||||
return PingResult{
|
||||
return pingResult{
|
||||
Duration: 0,
|
||||
Error: err,
|
||||
Method: "icmp",
|
||||
@@ -237,7 +237,7 @@ func tryICMPPing(host string) PingResult {
|
||||
|
||||
err = pinger.Run()
|
||||
if err != nil {
|
||||
return PingResult{
|
||||
return pingResult{
|
||||
Duration: 0,
|
||||
Error: err,
|
||||
Method: "icmp",
|
||||
@@ -247,7 +247,7 @@ func tryICMPPing(host string) PingResult {
|
||||
|
||||
stats := pinger.Statistics()
|
||||
if stats.PacketsRecv == 0 {
|
||||
return PingResult{
|
||||
return pingResult{
|
||||
Duration: 0,
|
||||
Error: fmt.Errorf("no packets received"),
|
||||
Method: "icmp",
|
||||
@@ -255,7 +255,7 @@ func tryICMPPing(host string) PingResult {
|
||||
}
|
||||
}
|
||||
|
||||
return PingResult{
|
||||
return pingResult{
|
||||
Duration: stats.AvgRtt,
|
||||
Error: nil,
|
||||
Method: "icmp",
|
||||
@@ -263,12 +263,12 @@ func tryICMPPing(host string) PingResult {
|
||||
}
|
||||
}
|
||||
|
||||
func tryTCPPing(host string) PingResult {
|
||||
func tryTCPPing(host string) pingResult {
|
||||
start := time.Now()
|
||||
|
||||
conn, err := net.DialTimeout("tcp", net.JoinHostPort(host, "53"), 2*time.Second)
|
||||
if err != nil {
|
||||
return PingResult{
|
||||
return pingResult{
|
||||
Duration: 0,
|
||||
Error: err,
|
||||
Method: "tcp",
|
||||
@@ -281,7 +281,7 @@ func tryTCPPing(host string) PingResult {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
return PingResult{
|
||||
return pingResult{
|
||||
Duration: duration,
|
||||
Error: nil,
|
||||
Method: "tcp",
|
||||
@@ -333,22 +333,22 @@ func (api *API) updatePreferredUpstream(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !slices.Contains(api.Config.DNS.UpstreamDNS, request.Upstream) {
|
||||
if !slices.Contains(api.Config.DNS.Upstream.Fallback, request.Upstream) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Upstream not found"})
|
||||
return
|
||||
}
|
||||
|
||||
if api.Config.DNS.PreferredUpstream == request.Upstream {
|
||||
if api.Config.DNS.Upstream.Preferred == request.Upstream {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Preferred upstream already set to %s", request.Upstream)})
|
||||
return
|
||||
}
|
||||
|
||||
api.Config.DNS.PreferredUpstream = request.Upstream
|
||||
api.Config.DNS.Upstream.Preferred = request.Upstream
|
||||
message := fmt.Sprintf("Preferred upstream set to %s", request.Upstream)
|
||||
log.Info("%s", message)
|
||||
|
||||
api.Config.Save()
|
||||
api.DNSServer.Audits.CreateAudit(&audit.Entry{
|
||||
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
|
||||
Topic: audit.TopicUpstream,
|
||||
Message: fmt.Sprintf("New preferred upstream '%s'", request.Upstream),
|
||||
})
|
||||
@@ -364,23 +364,23 @@ func (api *API) deleteUpstream(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !slices.Contains(api.Config.DNS.UpstreamDNS, upstreamToDelete) {
|
||||
if !slices.Contains(api.Config.DNS.Upstream.Fallback, upstreamToDelete) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("Upstream %s not found", upstreamToDelete)})
|
||||
return
|
||||
}
|
||||
|
||||
var updatedUpstreams []string
|
||||
for _, upstream := range api.Config.DNS.UpstreamDNS {
|
||||
for _, upstream := range api.Config.DNS.Upstream.Fallback {
|
||||
if upstream != upstreamToDelete {
|
||||
updatedUpstreams = append(updatedUpstreams, upstream)
|
||||
}
|
||||
}
|
||||
|
||||
api.Config.DNS.UpstreamDNS = updatedUpstreams
|
||||
api.Config.DNS.Upstream.Fallback = updatedUpstreams
|
||||
api.Config.Save()
|
||||
log.Info("Removed upstream: %s", upstreamToDelete)
|
||||
|
||||
api.DNSServer.Audits.CreateAudit(&audit.Entry{
|
||||
api.DNSServer.AuditService.CreateAudit(&audit.Entry{
|
||||
Topic: audit.TopicUpstream,
|
||||
Message: fmt.Sprintf("Removed upstream '%s'", upstreamToDelete),
|
||||
})
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
package user
|
||||
|
||||
type User struct {
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"goaway/backend/dns/database"
|
||||
"goaway/backend/logging"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
type Credentials struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
func (user *User) Create(db *gorm.DB) error {
|
||||
log.Info("Creating a new user with name '%s'", user.Username)
|
||||
|
||||
hashedPassword, err := hashPassword(user.Password)
|
||||
if err != nil {
|
||||
log.Error("Failed to hash password: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
user.Password = hashedPassword
|
||||
|
||||
result := db.Create(user)
|
||||
if result.Error != nil {
|
||||
log.Error("Failed to create user: %v", result.Error)
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
log.Error("User creation failed: no rows affected")
|
||||
return errors.New("user creation failed: no rows affected")
|
||||
}
|
||||
|
||||
log.Debug("User created successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func hashPassword(password string) (string, error) {
|
||||
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
return string(hashed), err
|
||||
}
|
||||
|
||||
func (user *User) Exists(db *gorm.DB) bool {
|
||||
query := database.User{}
|
||||
db.Where("username = ?", user.Username).Find(&query)
|
||||
return query.Username != ""
|
||||
}
|
||||
|
||||
func (user *User) Authenticate(db *gorm.DB) bool {
|
||||
var dbUser User
|
||||
if err := db.Where("username = ?", user.Username).First(&dbUser).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Error("User not found: %s", user.Username)
|
||||
return false
|
||||
}
|
||||
log.Error("Query error: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(dbUser.Password), []byte(user.Password)); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (user *User) UpdatePassword(db *gorm.DB) error {
|
||||
hashedPassword, err := hashPassword(user.Password)
|
||||
if err != nil {
|
||||
log.Error("Failed to hash new password: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
affected, err := gorm.G[database.User](db).Where("username = ?", user.Username).Update(context.Background(), "password", hashedPassword)
|
||||
if err != nil {
|
||||
log.Error("Failed to update password: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if affected == 0 {
|
||||
log.Error("Password update failed: no rows affected")
|
||||
return errors.New("password update failed: no rows affected")
|
||||
}
|
||||
|
||||
log.Debug("Password updated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Credentials) Validate() error {
|
||||
c.Username = strings.TrimSpace(c.Username)
|
||||
c.Password = strings.TrimSpace(c.Password)
|
||||
|
||||
if c.Username == "" || c.Password == "" {
|
||||
return errors.New("username and password cannot be empty")
|
||||
}
|
||||
|
||||
if len(c.Username) > 60 {
|
||||
return errors.New("username too long")
|
||||
}
|
||||
if len(c.Password) > 120 {
|
||||
return errors.New("password too long")
|
||||
}
|
||||
|
||||
for _, r := range c.Username {
|
||||
if r < 32 || r == 127 {
|
||||
return errors.New("username contains invalid characters")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -2,7 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"goaway/backend/dns/database"
|
||||
"goaway/backend/database"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
@@ -34,7 +34,7 @@ func (api *API) addWhitelisted(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err = api.Whitelist.AddDomain(newDomain.Domain)
|
||||
err = api.WhitelistService.AddDomain(newDomain.Domain)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -46,7 +46,7 @@ func (api *API) addWhitelisted(c *gin.Context) {
|
||||
func (api *API) getWhitelistedDomains(c *gin.Context) {
|
||||
var domains []string
|
||||
|
||||
err := api.DBManager.Conn.Model(&database.Whitelist{}).Pluck("domain", &domains).Error
|
||||
err := api.DBConn.Model(&database.Whitelist{}).Pluck("domain", &domains).Error
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "Failed to retrieve whitelisted domains"})
|
||||
return
|
||||
@@ -58,7 +58,7 @@ func (api *API) getWhitelistedDomains(c *gin.Context) {
|
||||
func (api *API) deleteWhitelistedDomain(c *gin.Context) {
|
||||
newDomain := c.Query("domain")
|
||||
|
||||
err := api.Whitelist.RemoveDomain(newDomain)
|
||||
err := api.WhitelistService.RemoveDomain(newDomain)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
|
||||
111
backend/app.go
Normal file
111
backend/app.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"fmt"
|
||||
"goaway/backend/alert"
|
||||
"goaway/backend/api/key"
|
||||
"goaway/backend/audit"
|
||||
"goaway/backend/blacklist"
|
||||
"goaway/backend/lifecycle"
|
||||
"goaway/backend/logging"
|
||||
"goaway/backend/mac"
|
||||
"goaway/backend/notification"
|
||||
"goaway/backend/prefetch"
|
||||
"goaway/backend/request"
|
||||
"goaway/backend/resolution"
|
||||
"goaway/backend/services"
|
||||
"goaway/backend/settings"
|
||||
"goaway/backend/setup"
|
||||
"goaway/backend/user"
|
||||
"goaway/backend/whitelist"
|
||||
)
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
type Application struct {
|
||||
config *settings.Config
|
||||
context *services.AppContext
|
||||
services *services.ServiceRegistry
|
||||
lifecycle *lifecycle.Manager
|
||||
content embed.FS
|
||||
version string
|
||||
commit string
|
||||
date string
|
||||
}
|
||||
|
||||
func New(setFlags *setup.SetFlags, version, commit, date string, content embed.FS) *Application {
|
||||
config := setup.InitializeSettings(setFlags)
|
||||
|
||||
return &Application{
|
||||
config: config,
|
||||
version: version,
|
||||
commit: commit,
|
||||
date: date,
|
||||
content: content,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Application) Start() error {
|
||||
|
||||
ctx, err := services.NewAppContext(a.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize application context: %w", err)
|
||||
}
|
||||
a.context = ctx
|
||||
|
||||
dbConn := a.context.DBConn
|
||||
alertService := alert.NewService(alert.NewRepository(dbConn))
|
||||
auditService := audit.NewService(audit.NewRepository(dbConn))
|
||||
blacklistService := blacklist.NewService(blacklist.NewRepository(dbConn))
|
||||
keyService := key.NewService(key.NewRepository(dbConn))
|
||||
macService := mac.NewService(mac.NewRepository(dbConn))
|
||||
notificationService := notification.NewService(notification.NewRepository(dbConn))
|
||||
prefetchService := prefetch.NewService(prefetch.NewRepository(dbConn), a.context.DNSServer)
|
||||
requestService := request.NewService(request.NewRepository(dbConn))
|
||||
resolutionService := resolution.NewService(resolution.NewRepository(dbConn))
|
||||
userService := user.NewService(user.NewRepository(dbConn))
|
||||
whitelistService := whitelist.NewService(whitelist.NewRepository(dbConn))
|
||||
|
||||
a.context.DNSServer.AlertService = alertService
|
||||
a.context.DNSServer.AuditService = auditService
|
||||
a.context.DNSServer.BlacklistService = blacklistService
|
||||
a.context.DNSServer.MACService = macService
|
||||
a.context.DNSServer.NotificationService = notificationService
|
||||
a.context.DNSServer.RequestService = requestService
|
||||
a.context.DNSServer.UserService = userService
|
||||
a.context.DNSServer.ResolutionService = resolutionService
|
||||
a.context.DNSServer.WhitelistService = whitelistService
|
||||
|
||||
a.displayStartupInfo()
|
||||
|
||||
a.services = services.NewServiceRegistry(a.context, a.version, a.commit, a.date, a.content)
|
||||
a.services.ResolutionService = resolutionService
|
||||
a.services.BlacklistService = blacklistService
|
||||
a.services.NotificationService = notificationService
|
||||
a.services.PrefetchService = prefetchService
|
||||
a.services.RequestService = requestService
|
||||
a.services.UserService = userService
|
||||
a.services.KeyService = keyService
|
||||
a.services.WhitelistService = whitelistService
|
||||
a.lifecycle = lifecycle.NewManager(a.services)
|
||||
|
||||
runServices := a.lifecycle.Run()
|
||||
return runServices
|
||||
}
|
||||
|
||||
func (a *Application) displayStartupInfo() {
|
||||
domains, err := a.context.DNSServer.BlacklistService.CountDomains(context.Background())
|
||||
if err != nil {
|
||||
log.Warning("Failed to count blacklist domains: %v", err)
|
||||
}
|
||||
|
||||
currentVersion := setup.GetVersionOrDefault(a.version)
|
||||
ASCIIArt(
|
||||
a.config,
|
||||
domains,
|
||||
currentVersion.Original(),
|
||||
a.config.API.Authentication,
|
||||
)
|
||||
}
|
||||
@@ -1,7 +1,8 @@
|
||||
package asciiart
|
||||
package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"goaway/backend/logging"
|
||||
"goaway/backend/settings"
|
||||
)
|
||||
|
||||
@@ -15,9 +16,19 @@ const (
|
||||
Magenta = "\033[35m"
|
||||
)
|
||||
|
||||
func AsciiArt(config *settings.Config, blockedDomains int, version string, disableAuth bool, ansi bool) {
|
||||
var art = `
|
||||
__ _ ___ __ ___ ____ _ _ _ DNS port: %s
|
||||
/ _' |/ _ \ / _' \ \ /\ / / _' | | | | Web port: %s
|
||||
| (_| | (_) | (_| |\ V V / (_| | |_| | Upstream: %s
|
||||
\__, |\___/ \__,_| \_/\_/ \__,_|\__, | Authentication: %s
|
||||
__/ | __/ | Cache TTL: %s
|
||||
|___/ %s |___/ Blocked Domains: %s
|
||||
`
|
||||
|
||||
func ASCIIArt(config *settings.Config, blockedDomains int, version string, disableAuth bool) {
|
||||
const versionSpace = 7
|
||||
|
||||
var ansi = logging.GetLogger().Ansi
|
||||
colorize := func(color, text string) string {
|
||||
if !ansi {
|
||||
return text
|
||||
@@ -28,20 +39,21 @@ func AsciiArt(config *settings.Config, blockedDomains int, version string, disab
|
||||
versionFormatted := fmt.Sprintf("%-*s%s%-*s", (versionSpace-len(version))/2, "",
|
||||
colorize(Cyan, version), (versionSpace-len(version)+1)/2, "")
|
||||
|
||||
portFormatted := colorize(Green, fmt.Sprintf("%d", config.DNS.Port))
|
||||
portFormatted := colorize(Green, fmt.Sprintf("%d", config.DNS.Ports.TCPUDP))
|
||||
adminPanelFormatted := colorize(Red, fmt.Sprintf("%d", config.API.Port))
|
||||
upstreamFormatted := colorize(Cyan, config.DNS.PreferredUpstream)
|
||||
upstreamFormatted := colorize(Cyan, config.DNS.Upstream.Preferred)
|
||||
authFormatted := colorize(Yellow, fmt.Sprintf("%v", disableAuth))
|
||||
cacheTTLFormatted := colorize(Blue, fmt.Sprintf("%d", config.DNS.CacheTTL))
|
||||
blockedDomainsFormatted := colorize(Magenta, fmt.Sprintf("%d", blockedDomains))
|
||||
|
||||
fmt.Printf(`
|
||||
__ _ ___ __ ___ ____ _ _ _ DNS port: %s
|
||||
/ _' |/ _ \ / _' \ \ /\ / / _' | | | | Web port: %s
|
||||
| (_| | (_) | (_| |\ V V / (_| | |_| | Upstream: %s
|
||||
\__, |\___/ \__,_| \_/\_/ \__,_|\__, | Authentication: %s
|
||||
__/ | __/ | Cache TTL: %s
|
||||
|___/ %s |___/ Blocked Domains: %s
|
||||
|
||||
`, portFormatted, adminPanelFormatted, upstreamFormatted, authFormatted, cacheTTLFormatted, versionFormatted, blockedDomainsFormatted)
|
||||
fmt.Printf(art,
|
||||
portFormatted,
|
||||
adminPanelFormatted,
|
||||
upstreamFormatted,
|
||||
authFormatted,
|
||||
cacheTTLFormatted,
|
||||
versionFormatted,
|
||||
blockedDomainsFormatted,
|
||||
)
|
||||
fmt.Println()
|
||||
}
|
||||
@@ -1,64 +1,40 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"goaway/backend/dns/database"
|
||||
"goaway/backend/logging"
|
||||
"goaway/backend/database"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
dbManager *database.DatabaseManager
|
||||
dbConn *gorm.DB
|
||||
}
|
||||
|
||||
type Topic string
|
||||
|
||||
const (
|
||||
TopicServer Topic = "server"
|
||||
TopicDNS Topic = "dns"
|
||||
TopicAPI Topic = "api"
|
||||
TopicResolution Topic = "resolution"
|
||||
TopicPrefetch Topic = "prefetch"
|
||||
TopicUpstream Topic = "upstream"
|
||||
TopicUser Topic = "user"
|
||||
TopicList Topic = "list"
|
||||
TopicLogs Topic = "logs"
|
||||
TopicSettings Topic = "settings"
|
||||
TopicDatabase Topic = "database"
|
||||
)
|
||||
|
||||
type Entry struct {
|
||||
Id int `json:"id"`
|
||||
Topic Topic `json:"topic"`
|
||||
Message string `json:"message"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
func NewAuditManager(dbconn *gorm.DB) *Manager {
|
||||
return &Manager{dbConn: dbconn}
|
||||
}
|
||||
|
||||
var logger = logging.GetLogger()
|
||||
|
||||
func NewAuditManager(dbManager *database.DatabaseManager) *Manager {
|
||||
return &Manager{dbManager: dbManager}
|
||||
}
|
||||
|
||||
func (nm *Manager) CreateAudit(newAudit *Entry) {
|
||||
func (m *Manager) CreateAudit(newAudit *Entry) {
|
||||
audit := database.Audit{
|
||||
Topic: string(newAudit.Topic),
|
||||
Message: newAudit.Message,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
result := nm.dbManager.Conn.Create(&audit)
|
||||
result := m.dbConn.Create(&audit)
|
||||
if result.Error != nil {
|
||||
logger.Warning("Unable to create new audit, error: %v", result.Error)
|
||||
log.Warning("Unable to create new audit, error: %v", result.Error)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug("Created new audit, %+v", newAudit)
|
||||
log.Debug("Created new audit, %+v", newAudit)
|
||||
}
|
||||
|
||||
func (nm *Manager) ReadAudits() ([]Entry, error) {
|
||||
func (m *Manager) ReadAudits() ([]Entry, error) {
|
||||
var audits []database.Audit
|
||||
|
||||
result := nm.dbManager.Conn.Order("created_at DESC").Find(&audits)
|
||||
result := m.dbConn.Order("created_at DESC").Find(&audits)
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
@@ -66,7 +42,7 @@ func (nm *Manager) ReadAudits() ([]Entry, error) {
|
||||
entries := make([]Entry, len(audits))
|
||||
for i, audit := range audits {
|
||||
entries[i] = Entry{
|
||||
Id: int(audit.ID),
|
||||
ID: audit.ID,
|
||||
Topic: Topic(audit.Topic),
|
||||
Message: audit.Message,
|
||||
CreatedAt: audit.CreatedAt,
|
||||
|
||||
51
backend/audit/repository.go
Normal file
51
backend/audit/repository.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"goaway/backend/database"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
CreateAudit(audit *Entry) error
|
||||
ReadAudits() ([]Entry, error)
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewRepository(db *gorm.DB) Repository {
|
||||
return &repository{db: db}
|
||||
}
|
||||
|
||||
func (r *repository) CreateAudit(audit *Entry) error {
|
||||
dbAudit := database.Audit{
|
||||
Topic: string(audit.Topic),
|
||||
Message: audit.Message,
|
||||
CreatedAt: audit.CreatedAt,
|
||||
}
|
||||
|
||||
result := r.db.Create(&dbAudit)
|
||||
return result.Error
|
||||
}
|
||||
|
||||
func (r *repository) ReadAudits() ([]Entry, error) {
|
||||
var dbAudits []database.Audit
|
||||
result := r.db.Order("created_at DESC").Find(&dbAudits)
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
audits := make([]Entry, len(dbAudits))
|
||||
for i, dbAudit := range dbAudits {
|
||||
audits[i] = Entry{
|
||||
ID: dbAudit.ID,
|
||||
Topic: Topic(dbAudit.Topic),
|
||||
Message: dbAudit.Message,
|
||||
CreatedAt: dbAudit.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
return audits, nil
|
||||
}
|
||||
50
backend/audit/service.go
Normal file
50
backend/audit/service.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"goaway/backend/logging"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Topic string
|
||||
|
||||
const (
|
||||
TopicServer Topic = "server"
|
||||
TopicDNS Topic = "dns"
|
||||
TopicAPI Topic = "api"
|
||||
TopicResolution Topic = "resolution"
|
||||
TopicPrefetch Topic = "prefetch"
|
||||
TopicUpstream Topic = "upstream"
|
||||
TopicUser Topic = "user"
|
||||
TopicList Topic = "list"
|
||||
TopicLogs Topic = "logs"
|
||||
TopicSettings Topic = "settings"
|
||||
TopicDatabase Topic = "database"
|
||||
)
|
||||
|
||||
type Entry struct {
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
Topic Topic `json:"topic"`
|
||||
Message string `json:"message"`
|
||||
ID uint `json:"id"`
|
||||
}
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
type Service struct {
|
||||
repository Repository
|
||||
}
|
||||
|
||||
func NewService(repo Repository) *Service {
|
||||
return &Service{repository: repo}
|
||||
}
|
||||
|
||||
func (s *Service) CreateAudit(entry *Entry) {
|
||||
err := s.repository.CreateAudit(entry)
|
||||
if err != nil {
|
||||
log.Warning("Could not create audit: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) ReadAudits() ([]Entry, error) {
|
||||
return s.repository.ReadAudits()
|
||||
}
|
||||
16
backend/blacklist/model.go
Normal file
16
backend/blacklist/model.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package blacklist
|
||||
|
||||
type ListUpdateAvailable struct {
|
||||
RemoteChecksum string `json:"remoteChecksum"`
|
||||
DBChecksum string `json:"dbChecksum"`
|
||||
RemoteDomains []string `json:"remoteDomains"`
|
||||
DBDomains []string `json:"dbDomains"`
|
||||
DiffAdded []string `json:"diffAdded"`
|
||||
DiffRemoved []string `json:"diffRemoved"`
|
||||
UpdateAvailable bool `json:"updateAvailable"`
|
||||
}
|
||||
|
||||
type BlocklistSource struct {
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
383
backend/blacklist/repository.go
Normal file
383
backend/blacklist/repository.go
Normal file
@@ -0,0 +1,383 @@
|
||||
package blacklist
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"goaway/backend/database"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type SourceRepository interface {
|
||||
GetSources(ctx context.Context, excludeCustom bool) ([]database.Source, error)
|
||||
GetSourceByName(ctx context.Context, name string) (*database.Source, error)
|
||||
GetSourceByNameAndURL(ctx context.Context, name, url string) (*database.Source, error)
|
||||
CreateOrUpdateSource(ctx context.Context, source *database.Source) error
|
||||
UpdateSourceName(ctx context.Context, oldName, newName, url string) error
|
||||
UpdateSourceLastUpdated(ctx context.Context, url string, timestamp time.Time) error
|
||||
ToggleSourceActive(ctx context.Context, name string) error
|
||||
DeleteSource(ctx context.Context, name, url string) error
|
||||
UpsertSource(ctx context.Context, source *database.Source) error
|
||||
}
|
||||
|
||||
type DomainRepository interface {
|
||||
GetAllDomains(ctx context.Context) ([]string, error)
|
||||
GetDomainsForSource(ctx context.Context, sourceName string) ([]string, error)
|
||||
GetPaginatedDomains(ctx context.Context, page, pageSize int, search string) ([]database.Blacklist, int64, error)
|
||||
CountDomains(ctx context.Context) (int64, error)
|
||||
CreateDomain(ctx context.Context, domain *database.Blacklist) error
|
||||
CreateDomainsInBatches(ctx context.Context, domains []database.Blacklist, batchSize int) error
|
||||
DeleteDomain(ctx context.Context, domain string) error
|
||||
DeleteDomainsBySourceID(ctx context.Context, sourceID uint) error
|
||||
DeleteCustomDomain(ctx context.Context, domain string, sourceID uint) error
|
||||
}
|
||||
|
||||
type StatsRepository interface {
|
||||
GetAllSourceStats(ctx context.Context) ([]SourceWithCount, error)
|
||||
GetSourceStats(ctx context.Context, listname string) (*SourceWithCount, error)
|
||||
GetRequestStats(ctx context.Context) ([]RequestStats, error)
|
||||
}
|
||||
|
||||
type TransactionRepository interface {
|
||||
WithTransaction(ctx context.Context, fn func(*gorm.DB) error) error
|
||||
Vacuum(ctx context.Context) error
|
||||
}
|
||||
|
||||
type Repository interface {
|
||||
SourceRepository
|
||||
DomainRepository
|
||||
StatsRepository
|
||||
TransactionRepository
|
||||
}
|
||||
|
||||
type SourceWithCount struct {
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
ID uint `json:"id"`
|
||||
LastUpdated time.Time `json:"lastUpdated"`
|
||||
BlockedCount int `json:"blockedCount"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
type RequestStats struct {
|
||||
Blocked bool `json:"blocked"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewRepository(db *gorm.DB) Repository {
|
||||
return &repository{db: db}
|
||||
}
|
||||
|
||||
func (r *repository) GetSources(ctx context.Context, excludeCustom bool) ([]database.Source, error) {
|
||||
var sources []database.Source
|
||||
query := r.db.WithContext(ctx)
|
||||
|
||||
if excludeCustom {
|
||||
query = query.Where("name != ?", "Custom")
|
||||
}
|
||||
|
||||
if err := query.Find(&sources).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to query sources: %w", err)
|
||||
}
|
||||
|
||||
return sources, nil
|
||||
}
|
||||
|
||||
func (r *repository) GetSourceByName(ctx context.Context, name string) (*database.Source, error) {
|
||||
var source database.Source
|
||||
if err := r.db.WithContext(ctx).Where("name = ?", name).First(&source).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("source '%s' not found", name)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get source: %w", err)
|
||||
}
|
||||
return &source, nil
|
||||
}
|
||||
|
||||
func (r *repository) GetSourceByNameAndURL(ctx context.Context, name, url string) (*database.Source, error) {
|
||||
var source database.Source
|
||||
if err := r.db.WithContext(ctx).Where("name = ? AND url = ?", name, url).First(&source).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("source '%s' with URL '%s' not found", name, url)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get source: %w", err)
|
||||
}
|
||||
return &source, nil
|
||||
}
|
||||
|
||||
func (r *repository) CreateOrUpdateSource(ctx context.Context, source *database.Source) error {
|
||||
result := r.db.WithContext(ctx).Where(database.Source{Name: source.Name, URL: source.URL}).FirstOrCreate(source)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to create or update source: %w", result.Error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) UpdateSourceName(ctx context.Context, oldName, newName, url string) error {
|
||||
if strings.TrimSpace(newName) == "" {
|
||||
return fmt.Errorf("new name cannot be empty")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).Model(&database.Source{}).
|
||||
Where("name = ? AND url = ?", oldName, url).
|
||||
Update("name", newName)
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to update source name: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("list with name '%s' not found", oldName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) UpdateSourceLastUpdated(ctx context.Context, url string, timestamp time.Time) error {
|
||||
result := r.db.WithContext(ctx).Model(&database.Source{}).
|
||||
Where("url = ?", url).
|
||||
Update("last_updated", timestamp)
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to update source: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) ToggleSourceActive(ctx context.Context, name string) error {
|
||||
var source database.Source
|
||||
if err := r.db.WithContext(ctx).Where("name = ?", name).First(&source).Error; err != nil {
|
||||
return fmt.Errorf("failed to find source %s: %w", name, err)
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).Model(&source).Update("active", !source.Active)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to toggle status for %s: %w", name, result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) DeleteSource(ctx context.Context, name, url string) error {
|
||||
var source database.Source
|
||||
if err := r.db.WithContext(ctx).Where("name = ? AND url = ?", name, url).First(&source).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return fmt.Errorf("source '%s' not found", name)
|
||||
}
|
||||
return fmt.Errorf("failed to get source: %w", err)
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Delete(&source).Error; err != nil {
|
||||
return fmt.Errorf("failed to remove source '%s': %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) GetAllDomains(ctx context.Context) ([]string, error) {
|
||||
var domains []string
|
||||
result := r.db.WithContext(ctx).Model(&database.Blacklist{}).
|
||||
Distinct("domain").
|
||||
Pluck("domain", &domains)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to query blacklist: %w", result.Error)
|
||||
}
|
||||
|
||||
return domains, nil
|
||||
}
|
||||
|
||||
func (r *repository) GetDomainsForSource(ctx context.Context, sourceName string) ([]string, error) {
|
||||
var blacklistEntries []database.Blacklist
|
||||
result := r.db.WithContext(ctx).Select("blacklists.domain").
|
||||
Joins("JOIN sources ON blacklists.source_id = sources.id").
|
||||
Where("sources.name = ?", sourceName).
|
||||
Find(&blacklistEntries)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to query domains for list: %w", result.Error)
|
||||
}
|
||||
|
||||
domains := make([]string, len(blacklistEntries))
|
||||
for i, entry := range blacklistEntries {
|
||||
domains[i] = entry.Domain
|
||||
}
|
||||
|
||||
return domains, nil
|
||||
}
|
||||
|
||||
func (r *repository) GetPaginatedDomains(ctx context.Context, page, pageSize int, search string) ([]database.Blacklist, int64, error) {
|
||||
searchPattern := "%" + search + "%"
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
var blacklistEntries []database.Blacklist
|
||||
result := r.db.WithContext(ctx).Select("domain").
|
||||
Where("domain LIKE ?", searchPattern).
|
||||
Order("domain DESC").
|
||||
Limit(pageSize).
|
||||
Offset(offset).
|
||||
Find(&blacklistEntries)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, 0, fmt.Errorf("failed to query blacklist: %w", result.Error)
|
||||
}
|
||||
|
||||
var total int64
|
||||
countResult := r.db.WithContext(ctx).Model(&database.Blacklist{}).
|
||||
Where("domain LIKE ?", searchPattern).
|
||||
Count(&total)
|
||||
|
||||
if countResult.Error != nil {
|
||||
return nil, 0, fmt.Errorf("failed to count domains: %w", countResult.Error)
|
||||
}
|
||||
|
||||
return blacklistEntries, total, nil
|
||||
}
|
||||
|
||||
func (r *repository) CountDomains(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
result := r.db.WithContext(ctx).Model(&database.Blacklist{}).Count(&count)
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("failed to count domains: %w", result.Error)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *repository) CreateDomain(ctx context.Context, domain *database.Blacklist) error {
|
||||
result := r.db.WithContext(ctx).Create(domain)
|
||||
if result.Error != nil {
|
||||
if strings.Contains(result.Error.Error(), "UNIQUE constraint failed") ||
|
||||
strings.Contains(result.Error.Error(), "duplicate key") {
|
||||
return fmt.Errorf("%s is already blacklisted", domain.Domain)
|
||||
}
|
||||
return fmt.Errorf("failed to add domain to blacklist: %w", result.Error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) CreateDomainsInBatches(ctx context.Context, domains []database.Blacklist, batchSize int) error {
|
||||
if len(domains) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).CreateInBatches(domains, batchSize).Error; err != nil {
|
||||
if !strings.Contains(err.Error(), "UNIQUE constraint failed") &&
|
||||
!strings.Contains(err.Error(), "duplicate key") {
|
||||
return fmt.Errorf("failed to add domains: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) DeleteDomain(ctx context.Context, domain string) error {
|
||||
result := r.db.WithContext(ctx).Where("domain = ?", domain).Delete(&database.Blacklist{})
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to remove domain from blacklist: %w", result.Error)
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("domain not found: %s", domain)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) DeleteDomainsBySourceID(ctx context.Context, sourceID uint) error {
|
||||
if err := r.db.WithContext(ctx).Where("source_id = ?", sourceID).Delete(&database.Blacklist{}).Error; err != nil {
|
||||
return fmt.Errorf("failed to remove domains for source ID %d: %w", sourceID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) DeleteCustomDomain(ctx context.Context, domain string, sourceID uint) error {
|
||||
result := r.db.WithContext(ctx).
|
||||
Where("domain = ? AND source_id = ?", domain, sourceID).
|
||||
Delete(&database.Blacklist{})
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to delete domain '%s': %w", domain, result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("domain '%s' not found in custom blacklist", domain)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) GetAllSourceStats(ctx context.Context) ([]SourceWithCount, error) {
|
||||
var results []SourceWithCount
|
||||
result := r.db.WithContext(ctx).Table("sources s").
|
||||
Select("s.id, s.name, s.url, s.last_updated, s.active, COALESCE(bc.blocked_count, 0) as blocked_count").
|
||||
Joins("LEFT JOIN (SELECT source_id, COUNT(*) as blocked_count FROM blacklists GROUP BY source_id) bc ON s.id = bc.source_id").
|
||||
Order("s.name, s.id").
|
||||
Scan(&results)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to query source statistics: %w", result.Error)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (r *repository) GetSourceStats(ctx context.Context, listname string) (*SourceWithCount, error) {
|
||||
var result SourceWithCount
|
||||
err := r.db.WithContext(ctx).Table("sources s").
|
||||
Select("s.name, s.url, s.last_updated, s.active, COALESCE(bc.blocked_count, 0) as blocked_count").
|
||||
Joins("LEFT JOIN (SELECT source_id, COUNT(*) as blocked_count FROM blacklists GROUP BY source_id) bc ON s.id = bc.source_id").
|
||||
Where("s.name = ?", listname).
|
||||
First(&result).Error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("list not found")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to query list statistics: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (r *repository) GetRequestStats(ctx context.Context) ([]RequestStats, error) {
|
||||
var stats []RequestStats
|
||||
result := r.db.WithContext(ctx).Model(&database.RequestLog{}).
|
||||
Select("blocked, COUNT(*) as count").
|
||||
Group("blocked").
|
||||
Scan(&stats)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to query request_logs: %w", result.Error)
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (r *repository) Vacuum(ctx context.Context) error {
|
||||
if err := r.db.WithContext(ctx).Exec("VACUUM").Error; err != nil {
|
||||
return fmt.Errorf("error while vacuuming database: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) WithTransaction(ctx context.Context, fn func(*gorm.DB) error) error {
|
||||
return r.db.WithContext(ctx).Transaction(fn)
|
||||
}
|
||||
|
||||
func (r *repository) UpsertSource(ctx context.Context, source *database.Source) error {
|
||||
if err := r.db.WithContext(ctx).Clauses(
|
||||
clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "url"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"name", "last_updated", "active"}),
|
||||
},
|
||||
).Create(source).Error; err != nil {
|
||||
return fmt.Errorf("failed to upsert source: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
693
backend/blacklist/service.go
Normal file
693
backend/blacklist/service.go
Normal file
@@ -0,0 +1,693 @@
|
||||
package blacklist
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"goaway/backend/database"
|
||||
"goaway/backend/logging"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type HTTPClient interface {
|
||||
Get(url string) (*http.Response, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
DefaultSources []BlocklistSource
|
||||
CacheTTL time.Duration
|
||||
BatchSize int
|
||||
UpdateInterval time.Duration
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
repository Repository
|
||||
httpClient HTTPClient
|
||||
cache map[string]bool
|
||||
cacheMu sync.RWMutex
|
||||
blocklistURL []BlocklistSource
|
||||
config Config
|
||||
}
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
const (
|
||||
blacklistedIP = "0.0.0.0"
|
||||
IPv4Loopback = "127.0.0.1"
|
||||
defaultBatchSize = 1000
|
||||
)
|
||||
|
||||
var defaultConfig = Config{
|
||||
DefaultSources: []BlocklistSource{
|
||||
{
|
||||
Name: "StevenBlack",
|
||||
URL: "https://raw.githubusercontent.com/StevenBlack/hosts/refs/heads/master/hosts",
|
||||
},
|
||||
},
|
||||
CacheTTL: 24 * time.Hour,
|
||||
BatchSize: defaultBatchSize,
|
||||
UpdateInterval: 24 * time.Hour,
|
||||
}
|
||||
|
||||
func NewService(repo Repository) *Service {
|
||||
config := defaultConfig
|
||||
service := &Service{
|
||||
repository: repo,
|
||||
httpClient: http.DefaultClient,
|
||||
cache: make(map[string]bool),
|
||||
config: config,
|
||||
}
|
||||
|
||||
if len(service.blocklistURL) == 0 {
|
||||
service.blocklistURL = config.DefaultSources
|
||||
}
|
||||
|
||||
if err := service.initialize(context.Background()); err != nil {
|
||||
log.Error("Could not initialize blacklist: %v", err)
|
||||
}
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
func (s *Service) initialize(ctx context.Context) error {
|
||||
count, err := s.repository.CountDomains(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to count domains: %w", err)
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
log.Info("No domains in blacklist. Running initialization...")
|
||||
if err := s.initializeBlockedDomains(ctx); err != nil {
|
||||
return fmt.Errorf("failed to initialize blocked domains: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.InitializeBlocklist(ctx, "Custom", ""); err != nil {
|
||||
return fmt.Errorf("failed to initialize custom blocklist: %w", err)
|
||||
}
|
||||
|
||||
if _, err := s.GetBlocklistUrls(ctx); err != nil {
|
||||
log.Error("Failed to fetch blocklist URLs: %v", err)
|
||||
return fmt.Errorf("failed to fetch blocklist URLs: %w", err)
|
||||
}
|
||||
|
||||
if err := s.PopulateCache(ctx); err != nil {
|
||||
log.Error("Failed to initialize blocklist cache: %v", err)
|
||||
return fmt.Errorf("failed to initialize blocklist cache: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) initializeBlockedDomains(ctx context.Context) error {
|
||||
start := time.Now()
|
||||
for _, source := range s.blocklistURL {
|
||||
if source.Name == "Custom" {
|
||||
continue
|
||||
}
|
||||
if err := s.FetchAndLoadHosts(ctx, source.URL, source.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("Blocked domains initialized in %.2fs", time.Since(start).Seconds())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) PopulateCache(ctx context.Context) error {
|
||||
domains, err := s.repository.GetAllDomains(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to populate cache: %w", err)
|
||||
}
|
||||
|
||||
s.cacheMu.Lock()
|
||||
defer s.cacheMu.Unlock()
|
||||
|
||||
s.cache = make(map[string]bool, len(domains))
|
||||
for _, domain := range domains {
|
||||
s.cache[domain] = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) IsBlacklisted(domain string) bool {
|
||||
s.cacheMu.RLock()
|
||||
defer s.cacheMu.RUnlock()
|
||||
|
||||
if exists, found := s.cache[domain]; found {
|
||||
return exists
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Service) GetBlocklistUrls(ctx context.Context) ([]BlocklistSource, error) {
|
||||
sources, err := s.repository.GetSources(ctx, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get blocklist URLs: %w", err)
|
||||
}
|
||||
|
||||
blocklistURL := make([]BlocklistSource, len(sources))
|
||||
for i, source := range sources {
|
||||
blocklistURL[i] = BlocklistSource{
|
||||
Name: source.Name,
|
||||
URL: source.URL,
|
||||
}
|
||||
}
|
||||
|
||||
s.blocklistURL = blocklistURL
|
||||
return blocklistURL, nil
|
||||
}
|
||||
|
||||
func (s *Service) NameExists(name, url string) bool {
|
||||
for _, source := range s.blocklistURL {
|
||||
if source.Name == name && source.URL == url {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Service) URLExists(url string) bool {
|
||||
for _, source := range s.blocklistURL {
|
||||
if source.URL == url {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Service) CheckIfUpdateAvailable(ctx context.Context, remoteListURL, listName string) (ListUpdateAvailable, error) {
|
||||
listUpdateAvailable := ListUpdateAvailable{}
|
||||
|
||||
remoteDomains, remoteChecksum, err := s.FetchRemoteHostsList(ctx, remoteListURL)
|
||||
if err != nil {
|
||||
log.Warning("Failed to fetch remote hosts list: %v", err)
|
||||
return listUpdateAvailable, fmt.Errorf("failed to fetch remote hosts list: %w", err)
|
||||
}
|
||||
|
||||
dbDomains, dbChecksum, err := s.FetchDBHostsList(ctx, listName)
|
||||
if err != nil {
|
||||
log.Warning("Failed to fetch database hosts list: %v", err)
|
||||
return listUpdateAvailable, fmt.Errorf("failed to fetch database hosts list: %w", err)
|
||||
}
|
||||
|
||||
if remoteChecksum == dbChecksum {
|
||||
log.Debug("No updates available for %s", listName)
|
||||
return listUpdateAvailable, nil
|
||||
}
|
||||
|
||||
diff := func(a, b []string) []string {
|
||||
mb := make(map[string]struct{}, len(b))
|
||||
for _, x := range b {
|
||||
mb[x] = struct{}{}
|
||||
}
|
||||
diff := make([]string, 0)
|
||||
for _, x := range a {
|
||||
if _, found := mb[x]; !found {
|
||||
diff = append(diff, x)
|
||||
}
|
||||
}
|
||||
return diff
|
||||
}
|
||||
|
||||
return ListUpdateAvailable{
|
||||
RemoteDomains: remoteDomains,
|
||||
DBDomains: dbDomains,
|
||||
RemoteChecksum: remoteChecksum,
|
||||
DBChecksum: dbChecksum,
|
||||
UpdateAvailable: true,
|
||||
DiffAdded: diff(remoteDomains, dbDomains),
|
||||
DiffRemoved: diff(dbDomains, remoteDomains),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) FetchRemoteHostsList(ctx context.Context, url string) ([]string, string, error) {
|
||||
resp, err := s.httpClient.Get(url)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to fetch hosts file from %s: %w", url, err)
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
_ = Body.Close()
|
||||
}(resp.Body)
|
||||
|
||||
domains, err := s.ExtractDomains(resp.Body)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to extract domains from %s: %w", url, err)
|
||||
}
|
||||
|
||||
return domains, calculateDomainsChecksum(domains), nil
|
||||
}
|
||||
|
||||
func (s *Service) FetchDBHostsList(ctx context.Context, name string) ([]string, string, error) {
|
||||
domains, err := s.repository.GetDomainsForSource(ctx, name)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("could not fetch domains from database: %w", err)
|
||||
}
|
||||
|
||||
return domains, calculateDomainsChecksum(domains), nil
|
||||
}
|
||||
|
||||
func calculateDomainsChecksum(domains []string) string {
|
||||
sort.Strings(domains)
|
||||
data := strings.Join(domains, "\n")
|
||||
hash := sha256.Sum256([]byte(data))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
func (s *Service) FetchAndLoadHosts(ctx context.Context, url, name string) error {
|
||||
resp, err := s.httpClient.Get(url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch hosts file from %s: %w", url, err)
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
_ = Body.Close()
|
||||
}(resp.Body)
|
||||
|
||||
domains, err := s.ExtractDomains(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to extract domains from %s: %w", url, err)
|
||||
}
|
||||
|
||||
if err := s.InitializeBlocklist(ctx, name, url); err != nil {
|
||||
return fmt.Errorf("failed to initialize blocklist: %w", err)
|
||||
}
|
||||
|
||||
if err := s.AddDomains(ctx, name, domains, url); err != nil {
|
||||
return fmt.Errorf("failed to add domains to database: %w", err)
|
||||
}
|
||||
|
||||
log.Info("Added %d domains from list '%s' with url '%s'", len(domains), name, url)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) isValidDomain(domain string) bool {
|
||||
invalidDomains := map[string]bool{
|
||||
"localhost": true,
|
||||
"localhost.localdomain": true,
|
||||
"broadcasthost": true,
|
||||
"local": true,
|
||||
blacklistedIP: true,
|
||||
}
|
||||
return !invalidDomains[domain]
|
||||
}
|
||||
|
||||
func (s *Service) ExtractDomains(body io.Reader) ([]string, error) {
|
||||
scanner := bufio.NewScanner(body)
|
||||
domainSet := make(map[string]struct{})
|
||||
var domains []string
|
||||
|
||||
for scanner.Scan() {
|
||||
fields := strings.Fields(scanner.Text())
|
||||
if len(fields) == 0 || strings.HasPrefix(fields[0], "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
domain := fields[0]
|
||||
if (domain == blacklistedIP || domain == IPv4Loopback) && len(fields) > 1 {
|
||||
domain = fields[1]
|
||||
if !s.isValidDomain(domain) {
|
||||
continue
|
||||
}
|
||||
} else if domain == blacklistedIP || domain == IPv4Loopback {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, exists := domainSet[domain]; !exists {
|
||||
domainSet[domain] = struct{}{}
|
||||
domains = append(domains, domain)
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error reading hosts file: %w", err)
|
||||
}
|
||||
if len(domains) == 0 {
|
||||
return nil, errors.New("zero results when parsing")
|
||||
}
|
||||
|
||||
return domains, nil
|
||||
}
|
||||
|
||||
func (s *Service) updateCache(domains []string, add bool) {
|
||||
s.cacheMu.Lock()
|
||||
defer s.cacheMu.Unlock()
|
||||
|
||||
for _, domain := range domains {
|
||||
if add {
|
||||
s.cache[domain] = true
|
||||
} else {
|
||||
delete(s.cache, domain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) AddBlacklistedDomain(ctx context.Context, domain string) error {
|
||||
blacklistEntry := &database.Blacklist{Domain: domain}
|
||||
|
||||
if err := s.repository.CreateDomain(ctx, blacklistEntry); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.updateCache([]string{domain}, true)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) AddDomains(ctx context.Context, name string, domains []string, url string) error {
|
||||
return s.repository.WithTransaction(ctx, func(tx *gorm.DB) error {
|
||||
currentTime := time.Now()
|
||||
|
||||
if err := s.repository.UpdateSourceLastUpdated(ctx, url, currentTime); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
source, err := s.repository.GetSourceByNameAndURL(ctx, name, url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
blacklistEntries := make([]database.Blacklist, 0, len(domains))
|
||||
for _, domain := range domains {
|
||||
blacklistEntries = append(blacklistEntries, database.Blacklist{
|
||||
Domain: domain,
|
||||
SourceID: source.ID,
|
||||
})
|
||||
}
|
||||
|
||||
if len(blacklistEntries) > 0 {
|
||||
batchSize := s.config.BatchSize
|
||||
if batchSize == 0 {
|
||||
batchSize = defaultBatchSize
|
||||
}
|
||||
if err := s.repository.CreateDomainsInBatches(ctx, blacklistEntries, batchSize); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
s.updateCache(domains, true)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) RemoveDomain(ctx context.Context, domain string) error {
|
||||
if err := s.repository.DeleteDomain(ctx, domain); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.updateCache([]string{domain}, false)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) AddCustomDomains(ctx context.Context, domains []string) error {
|
||||
return s.repository.WithTransaction(ctx, func(tx *gorm.DB) error {
|
||||
currentTime := time.Now()
|
||||
|
||||
source, err := s.repository.GetSourceByName(ctx, "Custom")
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
newSource := &database.Source{
|
||||
Name: "Custom",
|
||||
LastUpdated: currentTime,
|
||||
Active: true,
|
||||
}
|
||||
if err := s.repository.CreateOrUpdateSource(ctx, newSource); err != nil {
|
||||
return fmt.Errorf("failed to create custom source: %w", err)
|
||||
}
|
||||
source = newSource
|
||||
} else {
|
||||
return fmt.Errorf("failed to get custom source: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := s.repository.UpdateSourceLastUpdated(ctx, "", currentTime); err != nil {
|
||||
return fmt.Errorf("failed to update custom source: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, domain := range domains {
|
||||
entry := &database.Blacklist{
|
||||
Domain: domain,
|
||||
SourceID: source.ID,
|
||||
}
|
||||
// Ignore duplicate errors
|
||||
if err := s.repository.CreateDomain(ctx, entry); err != nil &&
|
||||
!strings.Contains(err.Error(), "already blacklisted") {
|
||||
log.Warning("Failed to add domain %s: %v", domain, err)
|
||||
} else {
|
||||
s.updateCache([]string{domain}, true)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) RemoveCustomDomain(ctx context.Context, domain string) error {
|
||||
source, err := s.repository.GetSourceByName(ctx, "Custom")
|
||||
if err != nil {
|
||||
return fmt.Errorf("custom source not found: %w", err)
|
||||
}
|
||||
|
||||
if err := s.repository.DeleteCustomDomain(ctx, domain, source.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.updateCache([]string{domain}, false)
|
||||
|
||||
currentTime := time.Now()
|
||||
if err := s.repository.UpdateSourceLastUpdated(ctx, "", currentTime); err != nil {
|
||||
log.Warning("Failed to update custom source timestamp: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) InitializeBlocklist(ctx context.Context, name, url string) error {
|
||||
source := &database.Source{
|
||||
Name: name,
|
||||
URL: url,
|
||||
LastUpdated: time.Now(),
|
||||
Active: true,
|
||||
}
|
||||
|
||||
return s.repository.CreateOrUpdateSource(ctx, source)
|
||||
}
|
||||
|
||||
func (s *Service) AddSource(ctx context.Context, name, url string) error {
|
||||
if strings.TrimSpace(name) == "" || strings.TrimSpace(url) == "" {
|
||||
return fmt.Errorf("name and url cannot be empty")
|
||||
}
|
||||
|
||||
source := &database.Source{
|
||||
Name: name,
|
||||
URL: url,
|
||||
LastUpdated: time.Now(),
|
||||
Active: true,
|
||||
}
|
||||
|
||||
if err := s.repository.UpsertSource(ctx, source); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update in-memory list
|
||||
found := false
|
||||
for _, existing := range s.blocklistURL {
|
||||
if existing.Name == name && existing.URL == url {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
s.blocklistURL = append(s.blocklistURL, BlocklistSource{Name: name, URL: url})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) UpdateSourceName(ctx context.Context, oldName, newName, url string) error {
|
||||
if oldName == newName {
|
||||
return fmt.Errorf("new name is the same as the old name")
|
||||
}
|
||||
|
||||
if err := s.repository.UpdateSourceName(ctx, oldName, newName, url); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update in-memory list
|
||||
for i, source := range s.blocklistURL {
|
||||
if source.Name == oldName {
|
||||
s.blocklistURL[i].Name = newName
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("Updated blocklist name from '%s' to '%s'", oldName, newName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) ToggleBlocklistStatus(ctx context.Context, name string) error {
|
||||
return s.repository.ToggleSourceActive(ctx, name)
|
||||
}
|
||||
|
||||
func (s *Service) RemoveSourceAndDomains(ctx context.Context, name, url string) error {
|
||||
return s.repository.WithTransaction(ctx, func(tx *gorm.DB) error {
|
||||
source, err := s.repository.GetSourceByNameAndURL(ctx, name, url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.repository.DeleteDomainsBySourceID(ctx, source.ID); err != nil {
|
||||
return fmt.Errorf("failed to remove domains for source '%s': %w", name, err)
|
||||
}
|
||||
|
||||
if err := s.repository.DeleteSource(ctx, name, url); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) RemoveSourceByNameAndURL(name, url string) bool {
|
||||
for i := len(s.blocklistURL) - 1; i >= 0; i-- {
|
||||
if s.blocklistURL[i].Name == name && s.blocklistURL[i].URL == url {
|
||||
s.blocklistURL = append(s.blocklistURL[:i], s.blocklistURL[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Service) CountDomains(ctx context.Context) (int, error) {
|
||||
count, err := s.repository.CountDomains(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(count), nil
|
||||
}
|
||||
|
||||
func (s *Service) GetAllowedAndBlocked(ctx context.Context) (allowed, blocked int, err error) {
|
||||
stats, err := s.repository.GetRequestStats(ctx)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
for _, stat := range stats {
|
||||
if stat.Blocked {
|
||||
blocked = stat.Count
|
||||
} else {
|
||||
allowed = stat.Count
|
||||
}
|
||||
}
|
||||
|
||||
return allowed, blocked, nil
|
||||
}
|
||||
|
||||
func (s *Service) GetAllListStatistics(ctx context.Context) ([]SourceWithCount, error) {
|
||||
results, err := s.repository.GetAllSourceStats(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stats := make([]SourceWithCount, len(results))
|
||||
for i, r := range results {
|
||||
stats[i] = SourceWithCount{
|
||||
Name: r.Name,
|
||||
URL: r.URL,
|
||||
BlockedCount: r.BlockedCount,
|
||||
LastUpdated: r.LastUpdated,
|
||||
Active: r.Active,
|
||||
}
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *Service) GetListStatistics(ctx context.Context, listname string) (string, SourceWithCount, error) {
|
||||
result, err := s.repository.GetSourceStats(ctx, listname)
|
||||
if err != nil {
|
||||
return "", SourceWithCount{}, err
|
||||
}
|
||||
|
||||
stats := SourceWithCount{
|
||||
URL: result.URL,
|
||||
BlockedCount: result.BlockedCount,
|
||||
LastUpdated: result.LastUpdated,
|
||||
Active: result.Active,
|
||||
}
|
||||
|
||||
return result.Name, stats, nil
|
||||
}
|
||||
|
||||
func (s *Service) LoadPaginatedBlacklist(ctx context.Context, page, pageSize int, search string) ([]string, int, error) {
|
||||
blacklistEntries, total, err := s.repository.GetPaginatedDomains(ctx, page, pageSize, search)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
domains := make([]string, len(blacklistEntries))
|
||||
for i, entry := range blacklistEntries {
|
||||
domains[i] = entry.Domain
|
||||
}
|
||||
|
||||
return domains, int(total), nil
|
||||
}
|
||||
|
||||
func (s *Service) Vacuum(ctx context.Context) {
|
||||
log.Debug("Vacuuming database...")
|
||||
if err := s.repository.Vacuum(ctx); err != nil {
|
||||
log.Warning("Error while vacuuming database: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) ScheduleAutomaticListUpdates() {
|
||||
ticker := time.NewTicker(s.config.UpdateInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
ctx := context.Background()
|
||||
log.Info("Starting automatic list updates...")
|
||||
|
||||
for _, source := range s.blocklistURL {
|
||||
if source.Name == "Custom" {
|
||||
continue
|
||||
}
|
||||
log.Info("Checking for updates for blocklist %s from %s", source.Name, source.URL)
|
||||
|
||||
availableUpdate, err := s.CheckIfUpdateAvailable(ctx, source.URL, source.Name)
|
||||
if err != nil {
|
||||
log.Warning("Failed to check for updates for %s: %v", source.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !availableUpdate.UpdateAvailable {
|
||||
log.Info("No updates available for %s", source.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.RemoveSourceAndDomains(ctx, source.Name, source.URL); err != nil {
|
||||
log.Warning("Failed to remove old domains for %s: %v", source.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.FetchAndLoadHosts(ctx, source.URL, source.Name); err != nil {
|
||||
log.Warning("Failed to fetch and load hosts for %s: %v", source.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Info("Successfully updated %s with %d new domains", source.Name, len(availableUpdate.DiffAdded))
|
||||
}
|
||||
|
||||
if err := s.PopulateCache(ctx); err != nil {
|
||||
log.Warning("Failed to populate blocklist cache after auto-update: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
46
backend/database/database.go
Normal file
46
backend/database/database.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func Initialize() *gorm.DB {
|
||||
if err := os.MkdirAll("data", 0755); err != nil {
|
||||
log.Fatal("failed to create data directory: %w", err)
|
||||
}
|
||||
|
||||
databasePath := filepath.Join("data", "database.db")
|
||||
db, err := gorm.Open(sqlite.Open(databasePath), &gorm.Config{})
|
||||
if err != nil {
|
||||
log.Fatal("failed while initializing database: %w", err)
|
||||
}
|
||||
|
||||
if err := AutoMigrate(db); err != nil {
|
||||
log.Fatal("auto migrate failed: %w", err)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func AutoMigrate(db *gorm.DB) error {
|
||||
return db.AutoMigrate(
|
||||
&Source{},
|
||||
&Blacklist{},
|
||||
&Whitelist{},
|
||||
&RequestLog{},
|
||||
&RequestLogIP{},
|
||||
&Resolution{},
|
||||
&MacAddress{},
|
||||
&User{},
|
||||
&APIKey{},
|
||||
&Notification{},
|
||||
&Prefetch{},
|
||||
&Audit{},
|
||||
&Alert{},
|
||||
)
|
||||
}
|
||||
118
backend/database/model.go
Normal file
118
backend/database/model.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type Source struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Name string `json:"name" validate:"required"`
|
||||
URL string `gorm:"unique;not null" json:"url" validate:"required,url"`
|
||||
Active bool `gorm:"default:true" json:"active"`
|
||||
LastUpdated time.Time `gorm:"not null" json:"lastUpdated"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
|
||||
}
|
||||
|
||||
type Blacklist struct {
|
||||
Domain string `gorm:"primaryKey" json:"domain" validate:"required,fqdn"`
|
||||
SourceID uint `gorm:"primaryKey;not null" json:"sourceID" validate:"required"`
|
||||
Source Source `gorm:"foreignKey:SourceID;constraint:OnDelete:CASCADE" json:"source"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
|
||||
}
|
||||
|
||||
type Whitelist struct {
|
||||
Domain string `gorm:"primaryKey" json:"domain" validate:"required,fqdn"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
|
||||
}
|
||||
|
||||
type RequestLog struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Timestamp time.Time `gorm:"not null;index:idx_timestamp_response_size,priority:1" json:"timestamp"`
|
||||
Domain string `gorm:"type:varchar(255);not null;index:idx_domain_timestamp,priority:1;index:idx_client_ip_domain,priority:2" json:"domain" validate:"required"`
|
||||
ClientIP string `gorm:"type:varchar(45);not null;index:idx_client_ip;index:idx_client_ip_domain,priority:1" json:"clientIP" validate:"required,ip"`
|
||||
ClientName string `gorm:"type:varchar(255)" json:"clientName"`
|
||||
QueryType string `gorm:"type:varchar(10)" json:"queryType"`
|
||||
Status string `gorm:"type:varchar(50)" json:"status"`
|
||||
Protocol string `gorm:"type:varchar(10)" json:"protocol"`
|
||||
ResponseTimeNs int64 `gorm:"not null" json:"repsonseTimeNS"`
|
||||
ResponseSizeBytes int `gorm:"index:idx_timestamp_response_size,priority:2" json:"responseSizeBytes"`
|
||||
Blocked bool `gorm:"not null;index:idx_timestamp_covering,priority:2;default:false" json:"blocked"`
|
||||
Cached bool `gorm:"not null;index:idx_timestamp_covering,priority:3;default:false" json:"cached"`
|
||||
IPs []RequestLogIP `gorm:"foreignKey:RequestLogID;constraint:OnDelete:CASCADE" json:"ips"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
|
||||
}
|
||||
|
||||
type RequestLogIP struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
RequestLogID uint `gorm:"not null;index" json:"requestLogID" validate:"required"`
|
||||
IP string `gorm:"type:varchar(45);not null" json:"ip" validate:"required,ip"`
|
||||
RecordType string `gorm:"type:varchar(10);not null" json:"recordType" validate:"required"`
|
||||
RequestLog RequestLog `gorm:"foreignKey:RequestLogID;constraint:OnDelete:CASCADE" json:"requestLog"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
|
||||
}
|
||||
|
||||
type Resolution struct {
|
||||
Domain string `gorm:"primaryKey" json:"domain" validate:"required,fqdn"`
|
||||
IP string `gorm:"index" json:"ip" validate:"required,ip"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
|
||||
}
|
||||
|
||||
type MacAddress struct {
|
||||
MAC string `gorm:"primaryKey" json:"mac" validate:"required,mac"`
|
||||
IP string `gorm:"index" json:"ip" validate:"required,ip"`
|
||||
Vendor string `json:"vendor"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Username string `gorm:"primaryKey" json:"username" validate:"required,min=3,max=50"`
|
||||
Password string `json:"password" validate:"required,min=8"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
Name string `gorm:"primaryKey" json:"name" validate:"required"`
|
||||
Key string `gorm:"unique;not null" json:"key" validate:"required"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
|
||||
}
|
||||
|
||||
type Notification struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Severity string `gorm:"type:varchar(20);not null" json:"severity" validate:"required,oneof=info warning error critical"`
|
||||
Category string `gorm:"type:varchar(50);not null" json:"category" validate:"required"`
|
||||
Text string `gorm:"type:text;not null" json:"text" validate:"required"`
|
||||
Read bool `gorm:"default:false;index" json:"read"`
|
||||
CreatedAt time.Time `gorm:"not null;index" json:"createdAt"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
|
||||
}
|
||||
|
||||
type Prefetch struct {
|
||||
Domain string `gorm:"primaryKey" json:"domain" validate:"required,fqdn"`
|
||||
QueryType int `gorm:"not null" json:"queryType" validate:"required,min=1"`
|
||||
Refresh int `gorm:"not null" json:"refresh" validate:"required,min=1"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
|
||||
}
|
||||
|
||||
type Audit struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Topic string `gorm:"type:varchar(100);not null;index" json:"topic" validate:"required"`
|
||||
Message string `gorm:"type:text;not null" json:"message" validate:"required"`
|
||||
CreatedAt time.Time `gorm:"not null;index" json:"createdAt"`
|
||||
}
|
||||
|
||||
type Alert struct {
|
||||
Type string `gorm:"primaryKey" json:"type" validate:"required"`
|
||||
Name string `gorm:"not null" json:"name" validate:"required"`
|
||||
Webhook string `json:"webhook" validate:"omitempty,url"`
|
||||
Enabled bool `gorm:"default:false" json:"enabled"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"createdAt"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updatedAt"`
|
||||
}
|
||||
@@ -17,18 +17,34 @@ import (
|
||||
var log = logging.GetLogger()
|
||||
|
||||
type vendorResponse struct {
|
||||
Company string `json:"company"`
|
||||
Success bool `json:"success"`
|
||||
Found bool `json:"found"`
|
||||
Company string `json:"company"`
|
||||
}
|
||||
|
||||
type ARPCache struct {
|
||||
mu sync.RWMutex
|
||||
type Cache struct {
|
||||
table map[string]string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type vendorCacheEntry struct {
|
||||
vendor string
|
||||
err error
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
type VendorCache struct {
|
||||
entries map[string]*vendorCacheEntry
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
var (
|
||||
cache = &ARPCache{table: make(map[string]string)}
|
||||
cache = &Cache{table: make(map[string]string)}
|
||||
vendorCache = &VendorCache{
|
||||
entries: make(map[string]*vendorCacheEntry),
|
||||
ttl: 60 * time.Second,
|
||||
}
|
||||
httpClient = &http.Client{Timeout: 5 * time.Second}
|
||||
)
|
||||
|
||||
@@ -44,6 +60,14 @@ func ProcessARPTable() {
|
||||
}
|
||||
}
|
||||
|
||||
func CleanVendorResponseCache() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
vendorCache.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func updateARPTable() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 40*time.Second)
|
||||
defer cancel()
|
||||
@@ -130,39 +154,56 @@ func GetMacVendor(mac string) (string, error) {
|
||||
mac = strings.ReplaceAll(mac, "-", "")
|
||||
mac = strings.ToLower(mac)
|
||||
|
||||
if vendor, err, found := vendorCache.get(mac); found {
|
||||
return vendor, err
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("https://api.maclookup.app/v2/macs/%s", mac)
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
req, err := http.NewRequest(http.MethodGet, url, http.NoBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
reqErr := fmt.Errorf("failed to create request: %w", err)
|
||||
vendorCache.set(mac, "", reqErr)
|
||||
return "", reqErr
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to fetch MAC vendor: %w", err)
|
||||
apiErr := fmt.Errorf("failed to fetch MAC vendor: %w", err)
|
||||
vendorCache.set(mac, "", apiErr)
|
||||
return "", apiErr
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
_ = Body.Close()
|
||||
}(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
statusErr := fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
vendorCache.set(mac, "", statusErr)
|
||||
return "", statusErr
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response body: %w", err)
|
||||
readErr := fmt.Errorf("failed to read response body: %w", err)
|
||||
vendorCache.set(mac, "", readErr)
|
||||
return "", readErr
|
||||
}
|
||||
|
||||
var result vendorResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return "", fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
unmarshalErr := fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
vendorCache.set(mac, "", unmarshalErr)
|
||||
return "", unmarshalErr
|
||||
}
|
||||
|
||||
if result.Found {
|
||||
vendorCache.set(mac, result.Company, nil)
|
||||
return result.Company, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("vendor not found")
|
||||
notFoundErr := fmt.Errorf("vendor not found for mac %s", mac)
|
||||
vendorCache.set(mac, "", notFoundErr)
|
||||
return "", notFoundErr
|
||||
}
|
||||
|
||||
func isValidMAC(mac string) bool {
|
||||
@@ -171,3 +212,42 @@ func isValidMAC(mac string) bool {
|
||||
|
||||
return len(cleanMAC) == 12 && cleanMAC != "000000000000"
|
||||
}
|
||||
|
||||
func (vc *VendorCache) get(mac string) (string, error, bool) {
|
||||
vc.mu.RLock()
|
||||
defer vc.mu.RUnlock()
|
||||
|
||||
entry, exists := vc.entries[mac]
|
||||
if !exists {
|
||||
return "", nil, false
|
||||
}
|
||||
|
||||
if time.Since(entry.timestamp) > vc.ttl {
|
||||
return "", nil, false
|
||||
}
|
||||
|
||||
return entry.vendor, entry.err, true
|
||||
}
|
||||
|
||||
func (vc *VendorCache) set(mac, vendor string, err error) {
|
||||
vc.mu.Lock()
|
||||
defer vc.mu.Unlock()
|
||||
|
||||
vc.entries[mac] = &vendorCacheEntry{
|
||||
vendor: vendor,
|
||||
err: err,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (vc *VendorCache) cleanup() {
|
||||
vc.mu.Lock()
|
||||
defer vc.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for mac, entry := range vc.entries {
|
||||
if now.Sub(entry.timestamp) > vc.ttl {
|
||||
delete(vc.entries, mac)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,148 +0,0 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type DatabaseManager struct {
|
||||
Conn *gorm.DB
|
||||
Mutex *sync.RWMutex
|
||||
}
|
||||
|
||||
type Source struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Name string `json:"name"`
|
||||
URL string `gorm:"unique" json:"url"`
|
||||
Active bool `json:"active"`
|
||||
LastUpdated int64 `json:"lastUpdated"`
|
||||
}
|
||||
|
||||
type Blacklist struct {
|
||||
Domain string `gorm:"primaryKey" json:"domain"`
|
||||
SourceID uint `gorm:"primaryKey" json:"source_id"`
|
||||
Source Source `gorm:"foreignKey:SourceID;references:ID" json:"source"`
|
||||
}
|
||||
|
||||
type Whitelist struct {
|
||||
Domain string `gorm:"primaryKey" json:"domain"`
|
||||
}
|
||||
|
||||
type RequestLog struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Timestamp time.Time `gorm:"not null;index:idx_request_log_timestamp_covering,priority:1;index:idx_request_log_timestamp_desc;index:idx_request_log_domain_timestamp,priority:2" json:"timestamp"`
|
||||
Domain string `gorm:"type:varchar(255);not null;index:idx_request_log_domain_timestamp,priority:1;index:idx_client_ip_domain,priority:2" json:"domain"`
|
||||
Blocked bool `gorm:"not null;index:idx_request_log_timestamp_covering,priority:2" json:"blocked"`
|
||||
Cached bool `gorm:"not null;index:idx_request_log_timestamp_covering,priority:3" json:"cached"`
|
||||
ResponseTimeNs int64 `gorm:"not null" json:"response_time_ns"`
|
||||
ClientIP string `gorm:"type:varchar(45);index:idx_client_ip;index:idx_client_ip_domain,priority:1" json:"client_ip"`
|
||||
ClientName string `gorm:"type:varchar(255)" json:"client_name"`
|
||||
Status string `gorm:"type:varchar(50)" json:"status"`
|
||||
QueryType string `gorm:"type:varchar(10)" json:"query_type"`
|
||||
ResponseSizeBytes int `json:"response_size_bytes"`
|
||||
Protocol string `gorm:"type:varchar(10)" json:"protocol"`
|
||||
IPs []RequestLogIP `gorm:"foreignKey:RequestLogID;constraint:OnDelete:CASCADE" json:"ips"`
|
||||
}
|
||||
|
||||
type RequestLogIP struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
RequestLogID uint `gorm:"not null;index" json:"request_log_id"`
|
||||
IP string `gorm:"type:varchar(45);not null" json:"ip"`
|
||||
RType string `gorm:"type:varchar(10);not null" json:"rtype"`
|
||||
RequestLog RequestLog `gorm:"foreignKey:RequestLogID;references:ID" json:"request_log"`
|
||||
}
|
||||
|
||||
type Resolution struct {
|
||||
Domain string `gorm:"primaryKey" json:"domain"`
|
||||
IP string `json:"ip"`
|
||||
}
|
||||
|
||||
type MacAddress struct {
|
||||
MAC string `gorm:"primaryKey" json:"mac"`
|
||||
IP string `json:"ip"`
|
||||
Vendor string `json:"vendor"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Username string `gorm:"primaryKey" json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
Name string `gorm:"primaryKey" json:"name"`
|
||||
Key string `json:"key"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
}
|
||||
|
||||
type Notification struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Severity string `json:"severity"`
|
||||
Category string `json:"category"`
|
||||
Text string `json:"text"`
|
||||
Read bool `json:"read"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
}
|
||||
|
||||
type Prefetch struct {
|
||||
Domain string `gorm:"primaryKey" json:"domain"`
|
||||
Refresh int `json:"refresh"`
|
||||
QType int `json:"qtype"`
|
||||
}
|
||||
|
||||
type Audit struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Topic string `json:"topic"`
|
||||
Message string `json:"message"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
}
|
||||
|
||||
type Alert struct {
|
||||
Type string `gorm:"primaryKey" json:"type"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Name string `json:"name"`
|
||||
Webhook string `json:"webhook"`
|
||||
}
|
||||
|
||||
func Initialize() *DatabaseManager {
|
||||
if err := os.MkdirAll("data", 0755); err != nil {
|
||||
log.Fatal("failed to create data directory: %v", err)
|
||||
}
|
||||
|
||||
databasePath := filepath.Join("data", "database.db")
|
||||
db, err := gorm.Open(sqlite.Open(databasePath), &gorm.Config{})
|
||||
if err != nil {
|
||||
log.Fatal("failed while initializing database: %v", err)
|
||||
}
|
||||
|
||||
if err := AutoMigrate(db); err != nil {
|
||||
log.Fatal("auto migrate failed: %v", err)
|
||||
}
|
||||
|
||||
return &DatabaseManager{
|
||||
Conn: db,
|
||||
Mutex: &sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
func AutoMigrate(db *gorm.DB) error {
|
||||
return db.AutoMigrate(
|
||||
&Source{},
|
||||
&Blacklist{},
|
||||
&Whitelist{},
|
||||
&RequestLog{},
|
||||
&RequestLogIP{},
|
||||
&Resolution{},
|
||||
&MacAddress{},
|
||||
&User{},
|
||||
&APIKey{},
|
||||
&Notification{},
|
||||
&Prefetch{},
|
||||
&Audit{},
|
||||
&Alert{},
|
||||
)
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func FindVendor(db *gorm.DB, mac string) (string, error) {
|
||||
var query MacAddress
|
||||
tx := db.Find(&query, "mac = ?", mac)
|
||||
|
||||
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
|
||||
return "", nil
|
||||
}
|
||||
if tx.Error != nil {
|
||||
return "", tx.Error
|
||||
}
|
||||
|
||||
return query.Vendor, nil
|
||||
}
|
||||
|
||||
func SaveMacEntry(db *gorm.DB, clientIP, mac, vendor string) {
|
||||
entry := MacAddress{
|
||||
MAC: mac,
|
||||
IP: clientIP,
|
||||
Vendor: vendor,
|
||||
}
|
||||
tx := db.Create(&entry)
|
||||
|
||||
if tx.Error != nil {
|
||||
log.Warning("Unable to save new MAC entry %v", tx.Error)
|
||||
}
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
package models
|
||||
|
||||
import "time"
|
||||
|
||||
type Client struct {
|
||||
Name, Mac, Vendor string
|
||||
LastSeen time.Time
|
||||
}
|
||||
|
||||
type ClientDetails struct {
|
||||
IP, Name, MAC string
|
||||
}
|
||||
|
||||
type ClientRequestDetails struct {
|
||||
TotalRequests, UniqueDomains, BlockedRequests, CachedRequests int
|
||||
AvgResponseTimeMs float64
|
||||
LastSeen, MostQueriedDomain string
|
||||
}
|
||||
|
||||
type Resolution struct {
|
||||
IP string `json:"ip"`
|
||||
Domain string `json:"domain"`
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func FetchResolutions(db *gorm.DB) ([]Resolution, error) {
|
||||
var resolutions []Resolution
|
||||
if err := db.Find(&resolutions).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch resolutions: %w", err)
|
||||
}
|
||||
return resolutions, nil
|
||||
}
|
||||
|
||||
func FetchResolution(db *gorm.DB, domain string) (string, error) {
|
||||
log.Debug("Finding resolution for domain: %s", domain)
|
||||
var res Resolution
|
||||
|
||||
db.Where("domain = ?", domain).Find(&res)
|
||||
if res.IP != "" {
|
||||
return res.IP, nil
|
||||
}
|
||||
|
||||
parts := strings.Split(domain, ".")
|
||||
for i := 1; i < len(parts); i++ {
|
||||
wildcardDomain := "*." + strings.Join(parts[i:], ".")
|
||||
if err := db.Where("domain = ?", wildcardDomain).Find(&res).Error; err == nil {
|
||||
return res.IP, nil
|
||||
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func CreateNewResolution(db *gorm.DB, ip, domain string) error {
|
||||
res := Resolution{
|
||||
Domain: domain,
|
||||
IP: ip,
|
||||
}
|
||||
|
||||
if err := db.Create(&res).Error; err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE") {
|
||||
return fmt.Errorf("domain already exists, must be unique")
|
||||
}
|
||||
return fmt.Errorf("could not create new resolution: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteResolution(db *gorm.DB, ip, domain string) (int, error) {
|
||||
result := db.Where("ip = ? AND domain = ?", ip, domain).Delete(&Resolution{})
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("could not delete resolution: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
log.Warning("No resolution found with IP: %s and Domain: %s", ip, domain)
|
||||
}
|
||||
|
||||
return int(result.RowsAffected), nil
|
||||
}
|
||||
@@ -1,794 +0,0 @@
|
||||
package lists
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"goaway/backend/dns/database"
|
||||
"goaway/backend/logging"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
type BlocklistSource struct {
|
||||
Name string
|
||||
URL string
|
||||
}
|
||||
|
||||
type Blacklist struct {
|
||||
DBManager *database.DatabaseManager
|
||||
BlocklistURL []BlocklistSource
|
||||
BlacklistCache map[string]bool
|
||||
}
|
||||
|
||||
type SourceStats struct {
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
BlockedCount int `json:"blockedCount"`
|
||||
LastUpdated int64 `json:"lastUpdated"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
type ListUpdateAvailable struct {
|
||||
RemoteDomains []string `json:"remoteDomains"`
|
||||
DBDomains []string `json:"dbDomains"`
|
||||
RemoteChecksum string `json:"remoteChecksum"`
|
||||
DBChecksum string `json:"dbChecksum"`
|
||||
UpdateAvailable bool `json:"updateAvailable"`
|
||||
DiffAdded []string `json:"diffAdded"`
|
||||
DiffRemoved []string `json:"diffRemoved"`
|
||||
}
|
||||
|
||||
func InitializeBlacklist(dbManager *database.DatabaseManager) (*Blacklist, error) {
|
||||
b := &Blacklist{
|
||||
DBManager: dbManager,
|
||||
BlocklistURL: []BlocklistSource{
|
||||
{Name: "StevenBlack", URL: "https://raw.githubusercontent.com/StevenBlack/hosts/refs/heads/master/hosts"},
|
||||
},
|
||||
BlacklistCache: map[string]bool{},
|
||||
}
|
||||
|
||||
if count, _ := b.CountDomains(); count == 0 {
|
||||
log.Info("No domains in blacklist. Running initialization...")
|
||||
if err := b.initializeBlockedDomains(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize blocked domains: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := b.InitializeBlocklist("Custom", ""); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize custom blocklist: %w", err)
|
||||
}
|
||||
|
||||
_, err := b.GetBlocklistUrls()
|
||||
if err != nil {
|
||||
log.Error("Failed to fetch blocklist URLs: %v", err)
|
||||
return nil, fmt.Errorf("failed to fetch blocklist URLs: %w", err)
|
||||
}
|
||||
_, err = b.PopulateBlocklistCache()
|
||||
if err != nil {
|
||||
log.Error("Failed to initialize blocklist cache")
|
||||
return nil, fmt.Errorf("failed to initialize blocklist cache: %w", err)
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (b *Blacklist) initializeBlockedDomains() error {
|
||||
for _, source := range b.BlocklistURL {
|
||||
if source.Name == "Custom" {
|
||||
continue
|
||||
}
|
||||
if err := b.FetchAndLoadHosts(source.URL, source.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Blacklist) Vacuum() {
|
||||
b.DBManager.Mutex.Lock()
|
||||
tx := b.DBManager.Conn.Raw("VACUUM")
|
||||
if err := tx.Error; err != nil {
|
||||
log.Warning("Error while vacuuming database: %v", err)
|
||||
}
|
||||
err := tx.Commit().Error
|
||||
b.DBManager.Mutex.Unlock()
|
||||
if err != nil {
|
||||
log.Warning("Error while vacuuming database: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Blacklist) GetBlocklistUrls() ([]BlocklistSource, error) {
|
||||
var sources []database.Source
|
||||
|
||||
result := b.DBManager.Conn.Where("name != ?", "Custom").Find(&sources)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to query sources: %w", result.Error)
|
||||
}
|
||||
|
||||
blocklistURL := make([]BlocklistSource, len(sources))
|
||||
for i, source := range sources {
|
||||
blocklistURL[i] = BlocklistSource{
|
||||
Name: source.Name,
|
||||
URL: source.URL,
|
||||
}
|
||||
}
|
||||
|
||||
b.BlocklistURL = blocklistURL
|
||||
return blocklistURL, nil
|
||||
}
|
||||
|
||||
func (b *Blacklist) CheckIfUpdateAvailable(remoteListURL, listName string) (ListUpdateAvailable, error) {
|
||||
listUpdateAvailable := ListUpdateAvailable{}
|
||||
remoteDomains, remoteChecksum, err := b.FetchRemoteHostsList(remoteListURL)
|
||||
if err != nil {
|
||||
log.Warning("Failed to fetch remote hosts list: %v", err)
|
||||
return listUpdateAvailable, fmt.Errorf("failed to fetch remote hosts list: %w", err)
|
||||
}
|
||||
|
||||
dbDomains, dbChecksum, err := b.FetchDBHostsList(listName)
|
||||
if err != nil {
|
||||
log.Warning("Failed to fetch database hosts list: %v", err)
|
||||
return listUpdateAvailable, fmt.Errorf("failed to fetch database hosts list: %w", err)
|
||||
}
|
||||
|
||||
if remoteChecksum == dbChecksum {
|
||||
log.Debug("No updates available for %s", listName)
|
||||
return listUpdateAvailable, nil
|
||||
}
|
||||
|
||||
diff := func(a, b []string) []string {
|
||||
mb := make(map[string]struct{}, len(b))
|
||||
for _, x := range b {
|
||||
mb[x] = struct{}{}
|
||||
}
|
||||
diff := make([]string, 0)
|
||||
for _, x := range a {
|
||||
if _, found := mb[x]; !found {
|
||||
diff = append(diff, x)
|
||||
}
|
||||
}
|
||||
return diff
|
||||
}
|
||||
|
||||
return ListUpdateAvailable{
|
||||
RemoteDomains: remoteDomains,
|
||||
DBDomains: dbDomains,
|
||||
RemoteChecksum: remoteChecksum,
|
||||
DBChecksum: dbChecksum,
|
||||
UpdateAvailable: true,
|
||||
DiffAdded: diff(remoteDomains, dbDomains),
|
||||
DiffRemoved: diff(dbDomains, remoteDomains),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *Blacklist) FetchRemoteHostsList(url string) ([]string, string, error) {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to fetch hosts file from %s: %w", url, err)
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
_ = Body.Close()
|
||||
}(resp.Body)
|
||||
|
||||
domains, err := b.ExtractDomains(resp.Body)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to extract domains from %s: %w", url, err)
|
||||
}
|
||||
|
||||
return domains, calculateDomainsChecksum(domains), nil
|
||||
}
|
||||
|
||||
func (b *Blacklist) FetchDBHostsList(name string) ([]string, string, error) {
|
||||
domains, err := b.GetDomainsForList(name)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("could not fetch domains from database")
|
||||
}
|
||||
|
||||
return domains, calculateDomainsChecksum(domains), nil
|
||||
}
|
||||
|
||||
func calculateDomainsChecksum(domains []string) string {
|
||||
sort.Strings(domains)
|
||||
data := strings.Join(domains, "\n")
|
||||
|
||||
hash := sha256.Sum256([]byte(data))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
func (b *Blacklist) FetchAndLoadHosts(url, name string) error {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch hosts file from %s: %w", url, err)
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
_ = Body.Close()
|
||||
}(resp.Body)
|
||||
|
||||
domains, err := b.ExtractDomains(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to extract domains from %s: %w", url, err)
|
||||
}
|
||||
|
||||
_ = b.InitializeBlocklist(name, url)
|
||||
|
||||
if err := b.AddDomains(domains, url); err != nil {
|
||||
return fmt.Errorf("failed to add domains to database: %w", err)
|
||||
}
|
||||
|
||||
log.Info("Added %d domains from list '%s' with url '%s'", len(domains), name, url)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Blacklist) ExtractDomains(body io.Reader) ([]string, error) {
|
||||
scanner := bufio.NewScanner(body)
|
||||
domainSet := make(map[string]struct{})
|
||||
var domains []string
|
||||
|
||||
for scanner.Scan() {
|
||||
fields := strings.Fields(scanner.Text())
|
||||
if len(fields) == 0 || strings.HasPrefix(fields[0], "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
domain := fields[0]
|
||||
if (domain == "0.0.0.0" || domain == "127.0.0.1") && len(fields) > 1 {
|
||||
domain = fields[1]
|
||||
switch domain {
|
||||
case "localhost", "localhost.localdomain", "broadcasthost", "local", "0.0.0.0":
|
||||
continue
|
||||
}
|
||||
} else if domain == "0.0.0.0" || domain == "127.0.0.1" {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, exists := domainSet[domain]; !exists {
|
||||
domainSet[domain] = struct{}{}
|
||||
domains = append(domains, domain)
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error reading hosts file: %w", err)
|
||||
}
|
||||
if len(domains) == 0 {
|
||||
return nil, errors.New("zero results when parsing")
|
||||
}
|
||||
|
||||
return domains, nil
|
||||
}
|
||||
|
||||
func (b *Blacklist) AddBlacklistedDomain(domain string) error {
|
||||
blacklistEntry := database.Blacklist{Domain: domain}
|
||||
|
||||
result := b.DBManager.Conn.Create(&blacklistEntry)
|
||||
if result.Error != nil {
|
||||
if strings.Contains(result.Error.Error(), "UNIQUE constraint failed") ||
|
||||
strings.Contains(result.Error.Error(), "duplicate key") {
|
||||
return fmt.Errorf("%s is already blacklisted", domain)
|
||||
}
|
||||
return fmt.Errorf("failed to add domain to blacklist: %w", result.Error)
|
||||
}
|
||||
|
||||
b.BlacklistCache[domain] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Blacklist) AddDomains(domains []string, url string) error {
|
||||
return b.DBManager.Conn.Transaction(func(tx *gorm.DB) error {
|
||||
var source database.Source
|
||||
currentTime := time.Now().Unix()
|
||||
|
||||
result := tx.Model(&source).Where("url = ?", url).Update("last_updated", currentTime)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to update source: %w", result.Error)
|
||||
}
|
||||
|
||||
if err := tx.Where("url = ?", url).First(&source).Error; err != nil {
|
||||
return fmt.Errorf("failed to find source: %w", err)
|
||||
}
|
||||
|
||||
blacklistEntries := make([]database.Blacklist, 0, len(domains))
|
||||
for _, domain := range domains {
|
||||
blacklistEntries = append(blacklistEntries, database.Blacklist{
|
||||
Domain: domain,
|
||||
SourceID: source.ID,
|
||||
})
|
||||
}
|
||||
|
||||
if len(blacklistEntries) > 0 {
|
||||
if err := tx.CreateInBatches(blacklistEntries, 1000).Error; err != nil {
|
||||
if !strings.Contains(err.Error(), "UNIQUE constraint failed") &&
|
||||
!strings.Contains(err.Error(), "duplicate key") {
|
||||
return fmt.Errorf("failed to add domains: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (b *Blacklist) PopulateBlocklistCache() (int, error) {
|
||||
var databaseDomains []string
|
||||
result := b.DBManager.Conn.Model(&database.Blacklist{}).
|
||||
Distinct("domain").
|
||||
Pluck("domain", &databaseDomains)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("failed to query blacklist: %w", result.Error)
|
||||
}
|
||||
|
||||
b.BlacklistCache = make(map[string]bool, len(databaseDomains))
|
||||
for _, domain := range databaseDomains {
|
||||
b.BlacklistCache[domain] = true
|
||||
}
|
||||
|
||||
return len(b.BlacklistCache), nil
|
||||
}
|
||||
|
||||
func (b *Blacklist) CountDomains() (int, error) {
|
||||
var count int64
|
||||
result := b.DBManager.Conn.Model(&database.Blacklist{}).Count(&count)
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("failed to count domains: %w", result.Error)
|
||||
}
|
||||
return int(count), nil
|
||||
}
|
||||
|
||||
func (b *Blacklist) GetAllowedAndBlocked() (allowed, blocked int, err error) {
|
||||
type RequestStats struct {
|
||||
Blocked bool
|
||||
Count int
|
||||
}
|
||||
|
||||
var stats []RequestStats
|
||||
result := b.DBManager.Conn.Model(&database.RequestLog{}).
|
||||
Select("blocked, COUNT(*) as count").
|
||||
Group("blocked").
|
||||
Scan(&stats)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, 0, fmt.Errorf("failed to query request_logs: %w", result.Error)
|
||||
}
|
||||
|
||||
for _, stat := range stats {
|
||||
if stat.Blocked {
|
||||
blocked = stat.Count
|
||||
} else {
|
||||
allowed = stat.Count
|
||||
}
|
||||
}
|
||||
|
||||
return allowed, blocked, nil
|
||||
}
|
||||
|
||||
func (b *Blacklist) RemoveDomain(domain string) error {
|
||||
result := b.DBManager.Conn.Where("domain = ?", domain).Delete(&database.Blacklist{})
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to remove domain from blacklist: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("%s is already whitelisted", domain)
|
||||
}
|
||||
|
||||
delete(b.BlacklistCache, domain)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Blacklist) UpdateSourceName(oldName, newName, url string) error {
|
||||
if strings.TrimSpace(newName) == "" {
|
||||
return fmt.Errorf("new name cannot be empty")
|
||||
}
|
||||
|
||||
if oldName == newName {
|
||||
return fmt.Errorf("new name is the same as the old name")
|
||||
}
|
||||
|
||||
result := b.DBManager.Conn.Model(&database.Source{}).
|
||||
Where("name = ? AND url = ?", oldName, url).
|
||||
Update("name", newName)
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to update source name: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("list with name '%s' not found", oldName)
|
||||
}
|
||||
|
||||
for i, source := range b.BlocklistURL {
|
||||
if source.Name == oldName {
|
||||
b.BlocklistURL[i].Name = newName
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("Updated blocklist name from '%s' to '%s'", oldName, newName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Blacklist) NameExists(name, url string) bool {
|
||||
for _, source := range b.BlocklistURL {
|
||||
if source.Name == name && source.URL == url {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *Blacklist) URLExists(url string) bool {
|
||||
for _, source := range b.BlocklistURL {
|
||||
if source.URL == url {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *Blacklist) IsBlacklisted(domain string) bool {
|
||||
return b.BlacklistCache[domain]
|
||||
}
|
||||
|
||||
func (b *Blacklist) LoadPaginatedBlacklist(page, pageSize int, search string) ([]string, int, error) {
|
||||
searchPattern := "%" + search + "%"
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
var blacklistEntries []database.Blacklist
|
||||
result := b.DBManager.Conn.Select("domain").
|
||||
Where("domain LIKE ?", searchPattern).
|
||||
Order("domain DESC").
|
||||
Limit(pageSize).
|
||||
Offset(offset).
|
||||
Find(&blacklistEntries)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, 0, fmt.Errorf("failed to query blacklist: %w", result.Error)
|
||||
}
|
||||
|
||||
domains := make([]string, len(blacklistEntries))
|
||||
for i, entry := range blacklistEntries {
|
||||
domains[i] = entry.Domain
|
||||
}
|
||||
|
||||
var total int64
|
||||
countResult := b.DBManager.Conn.Model(&database.Blacklist{}).
|
||||
Where("domain LIKE ?", searchPattern).
|
||||
Count(&total)
|
||||
|
||||
if countResult.Error != nil {
|
||||
return nil, 0, fmt.Errorf("failed to count domains: %w", countResult.Error)
|
||||
}
|
||||
|
||||
return domains, int(total), nil
|
||||
}
|
||||
|
||||
func (b *Blacklist) InitializeBlocklist(name, url string) error {
|
||||
return b.DBManager.Conn.Transaction(func(tx *gorm.DB) error {
|
||||
source := database.Source{
|
||||
Name: name,
|
||||
URL: url,
|
||||
LastUpdated: time.Now().Unix(),
|
||||
Active: true,
|
||||
}
|
||||
|
||||
result := tx.Where(database.Source{Name: name, URL: url}).FirstOrCreate(&source)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to initialize new blocklist: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (b *Blacklist) AddCustomDomains(domains []string) error {
|
||||
return b.DBManager.Conn.Transaction(func(tx *gorm.DB) error {
|
||||
var source database.Source
|
||||
currentTime := time.Now().Unix()
|
||||
|
||||
err := tx.Where("name = ?", "Custom").First(&source).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
source = database.Source{
|
||||
Name: "Custom",
|
||||
LastUpdated: currentTime,
|
||||
}
|
||||
if err := tx.Create(&source).Error; err != nil {
|
||||
return fmt.Errorf("failed to insert custom source: %w", err)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("failed to get custom source ID: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := tx.Model(&source).Update("last_updated", currentTime).Error; err != nil {
|
||||
return fmt.Errorf("failed to update lastUpdated for custom source: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
blacklistEntries := make([]database.Blacklist, 0, len(domains))
|
||||
for _, domain := range domains {
|
||||
blacklistEntries = append(blacklistEntries, database.Blacklist{
|
||||
Domain: domain,
|
||||
SourceID: source.ID,
|
||||
})
|
||||
}
|
||||
|
||||
for _, entry := range blacklistEntries {
|
||||
if err := tx.Where(database.Blacklist{Domain: entry.Domain, SourceID: entry.SourceID}).FirstOrCreate(&entry).Error; err != nil {
|
||||
return fmt.Errorf("failed to add custom domain '%s': %w", entry.Domain, err)
|
||||
}
|
||||
b.BlacklistCache[entry.Domain] = true
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (b *Blacklist) RemoveCustomDomain(domain string) error {
|
||||
b.DBManager.Mutex.Lock()
|
||||
defer b.DBManager.Mutex.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return b.DBManager.Conn.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
var source database.Source
|
||||
err := tx.Where("name = ?", "Custom").First(&source).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return fmt.Errorf("custom source not found")
|
||||
}
|
||||
return fmt.Errorf("failed to get custom source ID: %w", err)
|
||||
}
|
||||
|
||||
result := tx.Where("domain = ? AND source_id = ?", domain, source.ID).Delete(&database.Blacklist{})
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to delete domain '%s': %w", domain, result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("domain '%s' not found in custom blacklist", domain)
|
||||
}
|
||||
|
||||
delete(b.BlacklistCache, domain)
|
||||
|
||||
currentTime := time.Now().Unix()
|
||||
if err := tx.Model(&source).Update("last_updated", currentTime).Error; err != nil {
|
||||
return fmt.Errorf("failed to update lastUpdated for custom source: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (b *Blacklist) GetAllListStatistics() ([]SourceStats, error) {
|
||||
type SourceWithCount struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
LastUpdated int64 `json:"last_updated"`
|
||||
Active bool `json:"active"`
|
||||
BlockedCount int `json:"blocked_count"`
|
||||
}
|
||||
|
||||
var results []SourceWithCount
|
||||
result := b.DBManager.Conn.Table("sources s").
|
||||
Select("s.id, s.name, s.url, s.last_updated, s.active, COALESCE(bc.blocked_count, 0) as blocked_count").
|
||||
Joins("LEFT JOIN (SELECT source_id, COUNT(*) as blocked_count FROM blacklists GROUP BY source_id) bc ON s.id = bc.source_id").
|
||||
Order("s.name, s.id").
|
||||
Scan(&results)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to query source statistics: %w", result.Error)
|
||||
}
|
||||
|
||||
stats := make([]SourceStats, len(results))
|
||||
for i, r := range results {
|
||||
stats[i] = SourceStats{
|
||||
Name: r.Name,
|
||||
URL: r.URL,
|
||||
BlockedCount: r.BlockedCount,
|
||||
LastUpdated: r.LastUpdated,
|
||||
Active: r.Active,
|
||||
}
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (b *Blacklist) GetListStatistics(listname string) (string, SourceStats, error) {
|
||||
type SourceWithCount struct {
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
LastUpdated int64 `json:"last_updated"`
|
||||
Active bool `json:"active"`
|
||||
BlockedCount int `json:"blocked_count"`
|
||||
}
|
||||
|
||||
var result SourceWithCount
|
||||
err := b.DBManager.Conn.Table("sources s").
|
||||
Select("s.name, s.url, s.last_updated, s.active, COALESCE(bc.blocked_count, 0) as blocked_count").
|
||||
Joins("LEFT JOIN (SELECT source_id, COUNT(*) as blocked_count FROM blacklists GROUP BY source_id) bc ON s.id = bc.source_id").
|
||||
Where("s.name = ?", listname).
|
||||
First(&result).Error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", SourceStats{}, fmt.Errorf("list not found")
|
||||
}
|
||||
return "", SourceStats{}, fmt.Errorf("failed to query list statistics: %w", err)
|
||||
}
|
||||
|
||||
stats := SourceStats{
|
||||
URL: result.URL,
|
||||
BlockedCount: result.BlockedCount,
|
||||
LastUpdated: result.LastUpdated,
|
||||
Active: result.Active,
|
||||
}
|
||||
|
||||
return result.Name, stats, nil
|
||||
}
|
||||
|
||||
func (b *Blacklist) GetDomainsForList(list string) ([]string, error) {
|
||||
var blacklistEntries []database.Blacklist
|
||||
result := b.DBManager.Conn.Select("blacklists.domain").
|
||||
Joins("JOIN sources ON blacklists.source_id = sources.id").
|
||||
Where("sources.name = ?", list).
|
||||
Find(&blacklistEntries)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to query domains for list: %w", result.Error)
|
||||
}
|
||||
|
||||
domains := make([]string, len(blacklistEntries))
|
||||
for i, entry := range blacklistEntries {
|
||||
domains[i] = entry.Domain
|
||||
}
|
||||
|
||||
return domains, nil
|
||||
}
|
||||
|
||||
func (b *Blacklist) ToggleBlocklistStatus(name string) error {
|
||||
var source database.Source
|
||||
if err := b.DBManager.Conn.Where("name = ?", name).First(&source).Error; err != nil {
|
||||
return fmt.Errorf("failed to find source %s: %w", name, err)
|
||||
}
|
||||
|
||||
result := b.DBManager.Conn.Model(&source).Update("active", !source.Active)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to toggle status for %s: %w", name, result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Blacklist) RemoveSourceAndDomains(name, url string) error {
|
||||
return b.DBManager.Conn.Transaction(func(tx *gorm.DB) error {
|
||||
var source database.Source
|
||||
err := tx.Where("name = ? AND url = ?", name, url).First(&source).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return fmt.Errorf("source '%s' not found", name)
|
||||
}
|
||||
return fmt.Errorf("failed to get source ID: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Where("source_id = ?", source.ID).Delete(&database.Blacklist{}).Error; err != nil {
|
||||
return fmt.Errorf("failed to remove domains for source '%s': %w", name, err)
|
||||
}
|
||||
|
||||
if err := tx.Delete(&source).Error; err != nil {
|
||||
return fmt.Errorf("failed to remove source '%s': %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (b *Blacklist) RemoveSourceAndDomainsWithCacheRefresh(name, url string) error {
|
||||
if err := b.RemoveSourceAndDomains(name, url); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := b.PopulateBlocklistCache(); err != nil {
|
||||
log.Warning("Failed to clear blocklist cache after removing source: %v", err)
|
||||
}
|
||||
|
||||
log.Info("Removed all domains and source '%s'", name)
|
||||
return nil
|
||||
}
|
||||
func (b *Blacklist) ScheduleAutomaticListUpdates() {
|
||||
for {
|
||||
next := time.Now().Add(24 * time.Hour).Truncate(24 * time.Hour)
|
||||
log.Info("Next auto-update for lists scheduled for: %s", next.Format(time.DateTime))
|
||||
time.Sleep(time.Until(next))
|
||||
|
||||
for _, source := range b.BlocklistURL {
|
||||
if source.Name == "Custom" {
|
||||
continue
|
||||
}
|
||||
log.Info("Checking for updates for blocklist %s from %s", source.Name, source.URL)
|
||||
|
||||
availableUpdate, err := b.CheckIfUpdateAvailable(source.URL, source.Name)
|
||||
if err != nil {
|
||||
log.Warning("Failed to check for updates for %s: %v", source.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !availableUpdate.UpdateAvailable {
|
||||
log.Info("No updates available for %s", source.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := b.RemoveSourceAndDomains(source.Name, source.URL); err != nil {
|
||||
log.Warning("Failed to remove old domains for %s: %v", source.Name, err)
|
||||
continue
|
||||
}
|
||||
if err := b.FetchAndLoadHosts(source.URL, source.Name); err != nil {
|
||||
log.Warning("Failed to fetch and load hosts for %s: %v", source.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Info("Successfully updated %s with %d new domains", source.Name, len(availableUpdate.DiffAdded))
|
||||
}
|
||||
if _, err := b.PopulateBlocklistCache(); err != nil {
|
||||
log.Warning("Failed to populate blocklist cache after auto-update: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Blacklist) AddSource(name, url string) error {
|
||||
if strings.TrimSpace(name) == "" || strings.TrimSpace(url) == "" {
|
||||
return fmt.Errorf("name and url cannot be empty")
|
||||
}
|
||||
|
||||
if err := b.DBManager.Conn.Clauses(
|
||||
clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "url"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"name", "last_updated", "active"}),
|
||||
},
|
||||
).Create(&database.Source{
|
||||
Name: name,
|
||||
URL: url,
|
||||
LastUpdated: time.Now().Unix(),
|
||||
Active: true,
|
||||
}).Error; err != nil {
|
||||
return fmt.Errorf("failed to insert source: %w", err)
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, s := range b.BlocklistURL {
|
||||
if s.Name == name && s.URL == url {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
b.BlocklistURL = append(b.BlocklistURL, BlocklistSource{Name: name, URL: url})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Blacklist) RemoveSourceByNameAndURL(name, url string) bool {
|
||||
for i := len(b.BlocklistURL) - 1; i >= 0; i-- {
|
||||
if b.BlocklistURL[i].Name == name && b.BlocklistURL[i].URL == url {
|
||||
b.BlocklistURL = append(b.BlocklistURL[:i], b.BlocklistURL[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
package lists
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"goaway/backend/dns/database"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type Whitelist struct {
|
||||
DBManager *database.DatabaseManager
|
||||
Cache map[string]bool
|
||||
}
|
||||
|
||||
func InitializeWhitelist(dbManager *database.DatabaseManager) (*Whitelist, error) {
|
||||
w := &Whitelist{
|
||||
DBManager: dbManager,
|
||||
Cache: map[string]bool{},
|
||||
}
|
||||
|
||||
return w, w.refreshCache()
|
||||
}
|
||||
|
||||
func (w *Whitelist) AddDomain(domain string) error {
|
||||
result := w.DBManager.Conn.Clauses(clause.OnConflict{DoNothing: true}).Create(&database.Whitelist{Domain: domain})
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to add domain to whitelist: %w", result.Error)
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("%s is already whitelisted", domain)
|
||||
}
|
||||
|
||||
w.Cache[domain] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Whitelist) RemoveDomain(domain string) error {
|
||||
result := w.DBManager.Conn.Delete(&database.Whitelist{}, "domain = ?", domain)
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to remove domain from whitelist: %w", result.Error)
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("%s does not exist", domain)
|
||||
}
|
||||
|
||||
delete(w.Cache, domain)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Whitelist) refreshCache() error {
|
||||
for k := range w.Cache {
|
||||
delete(w.Cache, k)
|
||||
}
|
||||
|
||||
domains, err := w.GetDomains()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not get whitelisted domains while refreshing cache, %v", err)
|
||||
}
|
||||
|
||||
for domain := range domains {
|
||||
w.Cache[domain] = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Whitelist) GetDomains() (map[string]bool, error) {
|
||||
var records []database.Whitelist
|
||||
if err := w.DBManager.Conn.Find(&records).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to query whitelist: %w", err)
|
||||
}
|
||||
|
||||
domains := make(map[string]bool, len(records))
|
||||
for _, rec := range records {
|
||||
domains[rec.Domain] = true
|
||||
}
|
||||
|
||||
return domains, nil
|
||||
}
|
||||
|
||||
func (w *Whitelist) IsWhitelisted(domain string) bool {
|
||||
return w.Cache[domain]
|
||||
}
|
||||
@@ -32,7 +32,7 @@ func (s *DNSServer) getCachedRecord(cached interface{}) ([]dns.RR, bool) {
|
||||
|
||||
if cachedRecord.Key != "" {
|
||||
log.Debug("Cached entry has expired, removing %s from cache", cachedRecord.Key)
|
||||
s.Cache.Delete(cachedRecord.Key)
|
||||
s.DomainCache.Delete(cachedRecord.Key)
|
||||
}
|
||||
|
||||
return nil, false
|
||||
@@ -43,14 +43,14 @@ func (s *DNSServer) RemoveCachedDomain(domain string) {
|
||||
return
|
||||
}
|
||||
|
||||
s.Cache.Range(func(key, value interface{}) bool {
|
||||
s.DomainCache.Range(func(key, value interface{}) bool {
|
||||
cachedRecord, ok := value.(CachedRecord)
|
||||
if !ok || cachedRecord.Domain != domain+"." {
|
||||
return true
|
||||
}
|
||||
|
||||
log.Debug("Removing cached record for domain %s", domain)
|
||||
s.Cache.Delete(key)
|
||||
s.DomainCache.Delete(key)
|
||||
return true
|
||||
})
|
||||
}
|
||||
@@ -69,7 +69,7 @@ func (s *DNSServer) CacheRecord(cacheKey, domain string, ipAddresses []dns.RR, t
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
s.Cache.Store(cacheKey, CachedRecord{
|
||||
s.DomainCache.Store(cacheKey, CachedRecord{
|
||||
IPAddresses: ipAddresses,
|
||||
ExpiresAt: now.Add(cacheTTL),
|
||||
CachedAt: now,
|
||||
|
||||
@@ -16,11 +16,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
MaxDoHRequestSize = 4096
|
||||
DoHTimeout = 20 * time.Second
|
||||
DoHReadTimeout = 8 * time.Second
|
||||
DoHWriteTimeout = 8 * time.Second
|
||||
MB = 1 << 20
|
||||
maxDoHRequestSize = 4096
|
||||
doHTimeout = 20 * time.Second
|
||||
doHReadTimeout = 8 * time.Second
|
||||
doHWriteTimeout = 8 * time.Second
|
||||
megabyte = 1 << 20
|
||||
)
|
||||
|
||||
func (s *DNSServer) InitDoH(cert tls.Certificate) (*http.Server, error) {
|
||||
@@ -29,7 +29,7 @@ func (s *DNSServer) InitDoH(cert tls.Certificate) (*http.Server, error) {
|
||||
mux.HandleFunc("/health", s.handleHealthCheck)
|
||||
|
||||
server := &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", s.Config.DNS.Address, s.Config.DNS.DoHPort),
|
||||
Addr: fmt.Sprintf("%s:%d", s.Config.DNS.Address, s.Config.DNS.Ports.DoH),
|
||||
Handler: mux,
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
@@ -38,17 +38,17 @@ func (s *DNSServer) InitDoH(cert tls.Certificate) (*http.Server, error) {
|
||||
PreferServerCipherSuites: true,
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
},
|
||||
ReadTimeout: DoHReadTimeout,
|
||||
WriteTimeout: DoHWriteTimeout,
|
||||
ReadTimeout: doHReadTimeout,
|
||||
WriteTimeout: doHWriteTimeout,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
IdleTimeout: 60 * time.Second,
|
||||
MaxHeaderBytes: 1 * MB,
|
||||
MaxHeaderBytes: 1 * megabyte,
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func (s *DNSServer) handleHealthCheck(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *DNSServer) handleHealthCheck(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err := w.Write([]byte(`{"status":"healthy"}`))
|
||||
@@ -60,14 +60,14 @@ func (s *DNSServer) handleHealthCheck(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *DNSServer) handleDoHRequest(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), DoHTimeout)
|
||||
ctx, cancel := context.WithTimeout(r.Context(), doHTimeout)
|
||||
defer cancel()
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
log.Debug("DoH request received: %s %s from %s", r.Method, r.URL.String(), r.RemoteAddr)
|
||||
|
||||
if r.ContentLength > MaxDoHRequestSize {
|
||||
if r.ContentLength > maxDoHRequestSize {
|
||||
log.Warning("DoH request too large: %d bytes from %s", r.ContentLength, r.RemoteAddr)
|
||||
http.Error(w, "Request too large", http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
@@ -79,9 +79,9 @@ func (s *DNSServer) handleDoHRequest(w http.ResponseWriter, r *http.Request) {
|
||||
client model.Client
|
||||
)
|
||||
if xRealIP != "" {
|
||||
go s.WSCom(communicationMessage{true, false, false, xRealIP})
|
||||
go s.WSCom(communicationMessage{IP: xRealIP, Client: true, Upstream: false, DNS: false})
|
||||
} else {
|
||||
go s.WSCom(communicationMessage{true, false, false, clientIP})
|
||||
go s.WSCom(communicationMessage{IP: clientIP, Client: true, Upstream: false, DNS: false})
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -90,9 +90,9 @@ func (s *DNSServer) handleDoHRequest(w http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
|
||||
switch r.Method {
|
||||
case "GET":
|
||||
case http.MethodGet:
|
||||
dnsQuery, err = s.handleDoHGet(r)
|
||||
case "POST":
|
||||
case http.MethodPost:
|
||||
dnsQuery, err = s.handleDoHPost(r)
|
||||
default:
|
||||
log.Warning("DoH request invalid method: %s from %s", r.Method, r.RemoteAddr)
|
||||
@@ -124,7 +124,7 @@ func (s *DNSServer) handleDoHRequest(w http.ResponseWriter, r *http.Request) {
|
||||
responseWriter := &DoHResponseWriter{
|
||||
httpWriter: w,
|
||||
remoteAddr: r.RemoteAddr,
|
||||
DoHPort: s.Config.DNS.DoHPort,
|
||||
DoHPort: s.Config.DNS.Ports.DoH,
|
||||
}
|
||||
|
||||
if xRealIP != "" {
|
||||
@@ -147,7 +147,7 @@ func (s *DNSServer) handleDoHRequest(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
logEntry := s.processQuery(req)
|
||||
|
||||
go s.WSCom(communicationMessage{false, false, true, clientIP})
|
||||
go s.WSCom(communicationMessage{IP: clientIP, Client: false, Upstream: false, DNS: true})
|
||||
|
||||
select {
|
||||
case s.logEntryChannel <- logEntry:
|
||||
@@ -162,7 +162,7 @@ func (s *DNSServer) handleDoHGet(r *http.Request) ([]byte, error) {
|
||||
return nil, fmt.Errorf("missing dns parameter")
|
||||
}
|
||||
|
||||
if len(dnsParam) > MaxDoHRequestSize {
|
||||
if len(dnsParam) > maxDoHRequestSize {
|
||||
return nil, fmt.Errorf("dns parameter too long")
|
||||
}
|
||||
|
||||
@@ -184,7 +184,7 @@ func (s *DNSServer) handleDoHPost(r *http.Request) ([]byte, error) {
|
||||
return nil, fmt.Errorf("invalid content type: %s", contentType)
|
||||
}
|
||||
|
||||
limitedReader := io.LimitReader(r.Body, MaxDoHRequestSize)
|
||||
limitedReader := io.LimitReader(r.Body, maxDoHRequestSize)
|
||||
dnsQuery, err := io.ReadAll(limitedReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||
@@ -266,4 +266,4 @@ func (w *DoHResponseWriter) Header() http.Header { return nil }
|
||||
|
||||
func (w *DoHResponseWriter) Network() string { return "tcp" }
|
||||
|
||||
func (w *DoHResponseWriter) WriteHeader(statusCode int) {}
|
||||
func (w *DoHResponseWriter) WriteHeader(_ int) {}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
func (s *DNSServer) InitDoT(cert tls.Certificate) (*dns.Server, error) {
|
||||
notifyReady := func() {
|
||||
log.Info("Started DoT (dns-over-tls) server on port %d", s.Config.DNS.DoTPort)
|
||||
log.Info("Started DoT (dns-over-tls) server on port %d", s.Config.DNS.Ports.DoT)
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
@@ -18,7 +18,7 @@ func (s *DNSServer) InitDoT(cert tls.Certificate) (*dns.Server, error) {
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
server := &dns.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", s.Config.DNS.Address, s.Config.DNS.DoTPort),
|
||||
Addr: fmt.Sprintf("%s:%d", s.Config.DNS.Address, s.Config.DNS.Ports.DoT),
|
||||
Net: "tcp-tls",
|
||||
Handler: s,
|
||||
TLSConfig: tlsConfig,
|
||||
|
||||
@@ -5,9 +5,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
arp "goaway/backend/dns"
|
||||
"goaway/backend/dns/database"
|
||||
model "goaway/backend/dns/server/models"
|
||||
notification "goaway/backend/notifications"
|
||||
"goaway/backend/notification"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -24,8 +23,13 @@ var (
|
||||
blackholeIPv6 = net.ParseIP("::")
|
||||
)
|
||||
|
||||
const (
|
||||
IPv4Loopback = "127.0.0.1"
|
||||
unknownHostname = "unknown"
|
||||
)
|
||||
|
||||
func trimDomainDot(name string) string {
|
||||
if len(name) > 0 && name[len(name)-1] == '.' {
|
||||
if name != "" && name[len(name)-1] == '.' {
|
||||
return name[:len(name)-1]
|
||||
}
|
||||
return name
|
||||
@@ -37,15 +41,15 @@ func isPTRQuery(request *Request, domainName string) bool {
|
||||
|
||||
func (s *DNSServer) checkAndUpdatePauseStatus() {
|
||||
if s.Config.DNS.Status.Paused &&
|
||||
time.Since(s.Config.DNS.Status.PausedAt).Seconds() >= float64(s.Config.DNS.Status.PauseTime) {
|
||||
s.Config.DNS.Status.PausedAt.After(s.Config.DNS.Status.PauseTime) {
|
||||
s.Config.DNS.Status.Paused = false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DNSServer) shouldBlockQuery(domainName, fullName string) bool {
|
||||
return !s.Config.DNS.Status.Paused &&
|
||||
s.Blacklist.IsBlacklisted(domainName) &&
|
||||
!s.Whitelist.IsWhitelisted(fullName)
|
||||
s.BlacklistService.IsBlacklisted(domainName) &&
|
||||
!s.WhitelistService.IsWhitelisted(fullName)
|
||||
}
|
||||
|
||||
func (s *DNSServer) processQuery(request *Request) model.RequestLogEntry {
|
||||
@@ -77,20 +81,6 @@ func (s *DNSServer) processQuery(request *Request) model.RequestLogEntry {
|
||||
return s.handleStandardQuery(request)
|
||||
}
|
||||
|
||||
func (s *DNSServer) GetVendor(mac string) (string, error) {
|
||||
s.DBManager.Mutex.Lock()
|
||||
defer s.DBManager.Mutex.Unlock()
|
||||
return database.FindVendor(s.DBManager.Conn, mac)
|
||||
}
|
||||
|
||||
func (s *DNSServer) SaveMacVendor(clientIP, mac, vendor string) {
|
||||
s.DBManager.Mutex.Lock()
|
||||
defer s.DBManager.Mutex.Unlock()
|
||||
|
||||
log.Debug("Saving new MAC address: %s %s", mac, vendor)
|
||||
database.SaveMacEntry(s.DBManager.Conn, clientIP, mac, vendor)
|
||||
}
|
||||
|
||||
func (s *DNSServer) reverseHostnameLookup(requestedHostname string) (string, bool) {
|
||||
trimmed := strings.TrimSuffix(requestedHostname, ".")
|
||||
|
||||
@@ -114,24 +104,24 @@ func (s *DNSServer) getClientInfo(remoteAddr string) *model.Client {
|
||||
hostname := s.resolveHostname(clientIP)
|
||||
resultIP := clientIP
|
||||
|
||||
vendor, err := s.GetVendor(macAddress)
|
||||
if macAddress != "unknown" {
|
||||
vendor, err := s.MACService.FindVendor(macAddress)
|
||||
if macAddress != unknownHostname {
|
||||
if err != nil || vendor == "" {
|
||||
log.Debug("Lookup vendor for mac %s", macAddress)
|
||||
vendor, err = arp.GetMacVendor(macAddress)
|
||||
if err == nil {
|
||||
s.SaveMacVendor(clientIP, macAddress, vendor)
|
||||
s.MACService.SaveMac(clientIP, macAddress, vendor)
|
||||
} else {
|
||||
log.Warning("Error while lookup mac address vendor: %v", err)
|
||||
log.Warning("Was not able to find vendor for addr '%s'. %v", remoteAddr, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if clientIP == "127.0.0.1" || clientIP == "::1" || clientIP == "[" {
|
||||
if clientIP == IPv4Loopback || clientIP == "::1" || clientIP == "[" {
|
||||
localIP, err := getLocalIP()
|
||||
if err != nil {
|
||||
log.Warning("Failed to get local IP: %v", err)
|
||||
localIP = "127.0.0.1"
|
||||
localIP = IPv4Loopback
|
||||
}
|
||||
resultIP = localIP
|
||||
|
||||
@@ -145,7 +135,7 @@ func (s *DNSServer) getClientInfo(remoteAddr string) *model.Client {
|
||||
client := model.Client{IP: resultIP, Name: hostname, MAC: macAddress}
|
||||
s.clientCache.Store(clientIP, &client)
|
||||
|
||||
if client.Name != "unknown" {
|
||||
if client.Name != unknownHostname {
|
||||
s.hostnameCache.Store(client.Name, client.IP)
|
||||
}
|
||||
|
||||
@@ -153,19 +143,27 @@ func (s *DNSServer) getClientInfo(remoteAddr string) *model.Client {
|
||||
}
|
||||
|
||||
func (s *DNSServer) resolveHostname(clientIP string) string {
|
||||
if hostname := s.reverseDNSLookup(clientIP); hostname != "unknown" {
|
||||
ip := net.ParseIP(clientIP)
|
||||
if ip.IsLoopback() {
|
||||
hostname, err := os.Hostname()
|
||||
if err == nil {
|
||||
return hostname
|
||||
}
|
||||
}
|
||||
|
||||
if hostname := s.reverseDNSLookup(clientIP); hostname != unknownHostname {
|
||||
return hostname
|
||||
}
|
||||
|
||||
if hostname := s.avahiLookup(clientIP); hostname != "unknown" {
|
||||
if hostname := s.avahiLookup(clientIP); hostname != unknownHostname {
|
||||
return hostname
|
||||
}
|
||||
|
||||
if hostname := s.sshBannerLookup(clientIP); hostname != "unknown" {
|
||||
if hostname := s.sshBannerLookup(clientIP); hostname != unknownHostname {
|
||||
return hostname
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
return unknownHostname
|
||||
}
|
||||
|
||||
func (s *DNSServer) avahiLookup(clientIP string) string {
|
||||
@@ -175,8 +173,8 @@ func (s *DNSServer) avahiLookup(clientIP string) string {
|
||||
cmd := exec.CommandContext(ctx, "avahi-resolve-address", clientIP)
|
||||
output, err := cmd.Output()
|
||||
if err == nil {
|
||||
lines := strings.Split(string(output), "\n")
|
||||
for _, line := range lines {
|
||||
lines := strings.SplitSeq(string(output), "\n")
|
||||
for line := range lines {
|
||||
if strings.Contains(line, clientIP) {
|
||||
parts := strings.Fields(line)
|
||||
if len(parts) >= 2 {
|
||||
@@ -190,7 +188,7 @@ func (s *DNSServer) avahiLookup(clientIP string) string {
|
||||
}
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
return unknownHostname
|
||||
}
|
||||
|
||||
func (s *DNSServer) reverseDNSLookup(clientIP string) string {
|
||||
@@ -206,13 +204,13 @@ func (s *DNSServer) reverseDNSLookup(clientIP string) string {
|
||||
return hostname
|
||||
}
|
||||
}
|
||||
return "unknown"
|
||||
return unknownHostname
|
||||
}
|
||||
|
||||
func (s *DNSServer) sshBannerLookup(clientIP string) string {
|
||||
conn, err := net.DialTimeout("tcp", clientIP+":22", 1*time.Second)
|
||||
if err != nil {
|
||||
return "unknown"
|
||||
return unknownHostname
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
@@ -222,13 +220,13 @@ func (s *DNSServer) sshBannerLookup(clientIP string) string {
|
||||
if err != nil {
|
||||
log.Warning("Failed to set deadline for SSH banner lookup: %v", err)
|
||||
_ = conn.Close()
|
||||
return "unknown"
|
||||
return unknownHostname
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
banner, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return "unknown"
|
||||
return unknownHostname
|
||||
}
|
||||
|
||||
patterns := []*regexp.Regexp{
|
||||
@@ -248,7 +246,7 @@ func (s *DNSServer) sshBannerLookup(clientIP string) string {
|
||||
}
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
return unknownHostname
|
||||
}
|
||||
|
||||
func getLocalIP() (string, error) {
|
||||
@@ -265,7 +263,7 @@ func getLocalIP() (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return "127.0.0.1", fmt.Errorf("no non-loopback IPv4 address found")
|
||||
return IPv4Loopback, fmt.Errorf("no non-loopback IPv4 address found")
|
||||
}
|
||||
|
||||
func (s *DNSServer) handlePTRQuery(request *Request) model.RequestLogEntry {
|
||||
@@ -277,7 +275,7 @@ func (s *DNSServer) handlePTRQuery(request *Request) model.RequestLogEntry {
|
||||
}
|
||||
ipStr := strings.Join(parts, ".")
|
||||
|
||||
if ipStr == "127.0.0.1" {
|
||||
if ipStr == IPv4Loopback {
|
||||
return s.respondWithLocalhost(request)
|
||||
}
|
||||
|
||||
@@ -285,12 +283,12 @@ func (s *DNSServer) handlePTRQuery(request *Request) model.RequestLogEntry {
|
||||
return s.forwardPTRQueryUpstream(request)
|
||||
}
|
||||
|
||||
hostname := database.GetClientNameFromRequestLog(s.DBManager.Conn, ipStr)
|
||||
if hostname == "unknown" {
|
||||
hostname := s.RequestService.GetClientNameFromIP(ipStr)
|
||||
if hostname == unknownHostname {
|
||||
hostname = s.resolveHostname(ipStr)
|
||||
}
|
||||
|
||||
if hostname != "unknown" {
|
||||
if hostname != unknownHostname {
|
||||
return s.respondWithHostnamePTR(request, hostname)
|
||||
}
|
||||
|
||||
@@ -540,11 +538,11 @@ func (s *DNSServer) handleStandardQuery(request *Request) model.RequestLogEntry
|
||||
err := request.ResponseWriter.WriteMsg(request.Msg)
|
||||
if err != nil {
|
||||
log.Warning("Could not write query response. client: [%s] with query [%v], err: %v", request.Client.IP, request.Msg.Answer, err.Error())
|
||||
s.Notifications.CreateNotification(¬ification.Notification{
|
||||
Severity: notification.SeverityWarning,
|
||||
Category: notification.CategoryDNS,
|
||||
Text: fmt.Sprintf("Could not write query response. Client: %s, err: %v", request.Client.IP, err.Error()),
|
||||
})
|
||||
s.NotificationService.SendNotification(
|
||||
notification.SeverityWarning,
|
||||
notification.CategoryDNS,
|
||||
fmt.Sprintf("Could not write query response. Client: %s, err: %v", request.Client.IP, err.Error()),
|
||||
)
|
||||
}
|
||||
|
||||
return model.RequestLogEntry{
|
||||
@@ -563,7 +561,7 @@ func (s *DNSServer) handleStandardQuery(request *Request) model.RequestLogEntry
|
||||
|
||||
func (s *DNSServer) Resolve(req *Request) ([]dns.RR, bool, string) {
|
||||
cacheKey := req.Question.Name + ":" + strconv.Itoa(int(req.Question.Qtype))
|
||||
if cached, found := s.Cache.Load(cacheKey); found {
|
||||
if cached, found := s.DomainCache.Load(cacheKey); found {
|
||||
if ipAddresses, valid := s.getCachedRecord(cached); valid {
|
||||
return ipAddresses, true, dns.RcodeToString[dns.RcodeSuccess]
|
||||
}
|
||||
@@ -593,7 +591,7 @@ func (s *DNSServer) resolveResolution(domain string) ([]dns.RR, uint32, string)
|
||||
status = dns.RcodeToString[dns.RcodeSuccess]
|
||||
)
|
||||
|
||||
ipFound, err := database.FetchResolution(s.DBManager.Conn, domain)
|
||||
ipFound, err := s.ResolutionService.GetResolution(domain)
|
||||
if err != nil {
|
||||
log.Error("Database lookup error for domain (%s): %v", domain, err)
|
||||
return nil, 0, dns.RcodeToString[dns.RcodeServerFailure]
|
||||
@@ -648,14 +646,14 @@ func (s *DNSServer) QueryUpstream(req *Request) ([]dns.RR, uint32, string) {
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
go s.WSCom(communicationMessage{false, true, false, ""})
|
||||
go s.WSCom(communicationMessage{IP: "", Client: false, Upstream: true, DNS: false})
|
||||
|
||||
upstreamMsg := &dns.Msg{}
|
||||
upstreamMsg.SetQuestion(req.Question.Name, req.Question.Qtype)
|
||||
upstreamMsg.RecursionDesired = true
|
||||
upstreamMsg.Id = dns.Id()
|
||||
|
||||
upstream := s.Config.DNS.PreferredUpstream
|
||||
upstream := s.Config.DNS.Upstream.Preferred
|
||||
if s.dnsClient.Net == "tcp-tls" {
|
||||
host, port, err := net.SplitHostPort(upstream)
|
||||
if err != nil {
|
||||
@@ -681,7 +679,7 @@ func (s *DNSServer) QueryUpstream(req *Request) ([]dns.RR, uint32, string) {
|
||||
|
||||
select {
|
||||
case in := <-resultCh:
|
||||
go s.WSCom(communicationMessage{false, false, true, ""})
|
||||
go s.WSCom(communicationMessage{IP: "", Client: false, Upstream: false, DNS: true})
|
||||
|
||||
status := dns.RcodeToString[dns.RcodeServerFailure]
|
||||
if statusStr, ok := dns.RcodeToString[in.Rcode]; ok {
|
||||
@@ -713,11 +711,11 @@ func (s *DNSServer) QueryUpstream(req *Request) ([]dns.RR, uint32, string) {
|
||||
|
||||
case err := <-errCh:
|
||||
log.Warning("Resolution error for domain (%s): %v", req.Question.Name, err)
|
||||
s.Notifications.CreateNotification(¬ification.Notification{
|
||||
Severity: notification.SeverityWarning,
|
||||
Category: notification.CategoryDNS,
|
||||
Text: fmt.Sprintf("Resolution error for domain (%s)", req.Question.Name),
|
||||
})
|
||||
s.NotificationService.SendNotification(
|
||||
notification.SeverityWarning,
|
||||
notification.CategoryDNS,
|
||||
fmt.Sprintf("Resolution error for domain (%s)", req.Question.Name),
|
||||
)
|
||||
return nil, 0, dns.RcodeToString[dns.RcodeServerFailure]
|
||||
|
||||
case <-time.After(5 * time.Second):
|
||||
@@ -733,8 +731,13 @@ func (s *DNSServer) LocalForwardLookup(req *Request) (model.RequestLogEntry, err
|
||||
hostname += "."
|
||||
}
|
||||
|
||||
queryType := req.Question.Qtype
|
||||
if queryType == 0 {
|
||||
queryType = dns.TypeA
|
||||
}
|
||||
|
||||
dnsMsg := new(dns.Msg)
|
||||
dnsMsg.SetQuestion(hostname, dns.TypeA)
|
||||
dnsMsg.SetQuestion(hostname, queryType)
|
||||
|
||||
client := &dns.Client{Net: "udp"}
|
||||
start := time.Now()
|
||||
@@ -759,7 +762,7 @@ func (s *DNSServer) LocalForwardLookup(req *Request) (model.RequestLogEntry, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
if len(ips) == 0 && queryType == dns.TypeA {
|
||||
return model.RequestLogEntry{}, fmt.Errorf("no A records found for hostname: %s", hostname)
|
||||
}
|
||||
|
||||
@@ -772,7 +775,7 @@ func (s *DNSServer) LocalForwardLookup(req *Request) (model.RequestLogEntry, err
|
||||
entry := model.RequestLogEntry{
|
||||
Domain: req.Question.Name,
|
||||
Status: dns.RcodeToString[in.Rcode],
|
||||
QueryType: dns.TypeToString[dns.TypeA],
|
||||
QueryType: dns.TypeToString[queryType],
|
||||
IP: ips,
|
||||
ResponseSizeBytes: in.Len(),
|
||||
Timestamp: start,
|
||||
|
||||
@@ -2,14 +2,13 @@ package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"goaway/backend/dns/database"
|
||||
model "goaway/backend/dns/server/models"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const BatchSize = 1000
|
||||
const batchSize = 1000
|
||||
|
||||
func (s *DNSServer) ProcessLogEntries() {
|
||||
var batch []model.RequestLogEntry
|
||||
@@ -26,7 +25,7 @@ func (s *DNSServer) ProcessLogEntries() {
|
||||
}
|
||||
|
||||
batch = append(batch, entry)
|
||||
if len(batch) >= BatchSize {
|
||||
if len(batch) >= batchSize {
|
||||
s.saveBatch(batch)
|
||||
batch = nil
|
||||
}
|
||||
@@ -40,14 +39,13 @@ func (s *DNSServer) ProcessLogEntries() {
|
||||
}
|
||||
|
||||
func (s *DNSServer) saveBatch(entries []model.RequestLogEntry) {
|
||||
s.DBManager.Mutex.Lock()
|
||||
err := database.SaveRequestLog(s.DBManager.Conn, entries)
|
||||
s.DBManager.Mutex.Unlock()
|
||||
err := s.RequestService.SaveRequestLog(entries)
|
||||
if err != nil {
|
||||
log.Warning("Error while saving logs, reason: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Removes old log entries based on the configured retention period.
|
||||
func (s *DNSServer) ClearOldEntries() {
|
||||
const (
|
||||
maxRetries = 10
|
||||
@@ -56,10 +54,10 @@ func (s *DNSServer) ClearOldEntries() {
|
||||
)
|
||||
|
||||
for {
|
||||
requestThreshold := ((60 * 60) * 24) * s.Config.StatisticsRetention
|
||||
requestThreshold := ((60 * 60) * 24) * s.Config.Misc.StatisticsRetention
|
||||
log.Debug("Next cleanup running at %s", time.Now().Add(cleanupInterval).Format(time.DateTime))
|
||||
time.Sleep(cleanupInterval)
|
||||
|
||||
database.DeleteRequestLogsTimebased(s.Blacklist.Vacuum, s.DBManager.Conn, requestThreshold, maxRetries, retryDelay)
|
||||
s.RequestService.DeleteRequestLogsTimebased(s.BlacklistService.Vacuum, requestThreshold, maxRetries, retryDelay)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,18 +3,18 @@ package model
|
||||
import "time"
|
||||
|
||||
type RequestLogEntry struct {
|
||||
ID int64 `json:"id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
ClientInfo *Client `json:"client"`
|
||||
Domain string `json:"domain"`
|
||||
Status string `json:"status"`
|
||||
QueryType string `json:"queryType"`
|
||||
Protocol Protocol `json:"protocol"`
|
||||
IP []ResolvedIP `json:"ip"`
|
||||
ID uint `json:"id"`
|
||||
ResponseSizeBytes int `json:"responseSizeBytes"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
ResponseTime time.Duration `json:"responseTimeNS"`
|
||||
Blocked bool `json:"blocked"`
|
||||
Cached bool `json:"cached"`
|
||||
ClientInfo *Client `json:"client"`
|
||||
Protocol Protocol `json:"protocol"`
|
||||
}
|
||||
|
||||
type Protocol string
|
||||
@@ -39,8 +39,8 @@ type RequestLogIntervalSummary struct {
|
||||
}
|
||||
|
||||
type ResponseSizeSummary struct {
|
||||
StartUnix int64 `json:"-"`
|
||||
Start time.Time `json:"start"`
|
||||
StartUnix int64 `json:"-"`
|
||||
TotalSizeBytes int `json:"total_size_bytes"`
|
||||
AvgResponseSizeBytes int `json:"avg_response_size_bytes"`
|
||||
MinResponseSizeBytes int `json:"min_response_size_bytes"`
|
||||
|
||||
@@ -1,204 +0,0 @@
|
||||
package prefetch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"goaway/backend/dns/database"
|
||||
"goaway/backend/dns/server"
|
||||
"goaway/backend/logging"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
type Manager struct {
|
||||
dbManager *database.DatabaseManager
|
||||
DNS *server.DNSServer
|
||||
Domains map[string]PrefetchedDomain
|
||||
}
|
||||
|
||||
type PrefetchedDomain struct {
|
||||
Domain string `json:"domain"`
|
||||
Refresh int `json:"refresh"`
|
||||
Qtype int `json:"qtype"`
|
||||
}
|
||||
|
||||
func (manager *Manager) LoadPrefetchedDomains() {
|
||||
var prefetched []database.Prefetch
|
||||
if err := manager.dbManager.Conn.Find(&prefetched).Error; err != nil {
|
||||
log.Warning("Failed to query prefetch table: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, p := range prefetched {
|
||||
manager.Domains[p.Domain] = PrefetchedDomain{
|
||||
Domain: p.Domain,
|
||||
Refresh: p.Refresh,
|
||||
Qtype: p.QType,
|
||||
}
|
||||
}
|
||||
|
||||
if len(manager.Domains) > 0 {
|
||||
log.Info("Loaded %d prefetched domain(s)", len(manager.Domains))
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *Manager) AddPrefetchedDomain(domain string, refresh, qtype int) error {
|
||||
prefetch := database.Prefetch{
|
||||
Domain: domain,
|
||||
Refresh: refresh,
|
||||
QType: qtype,
|
||||
}
|
||||
|
||||
result := manager.dbManager.Conn.FirstOrCreate(&prefetch, database.Prefetch{Domain: domain})
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to add new domain to prefetch table: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("%s already exists", domain)
|
||||
}
|
||||
|
||||
manager.Domains[domain] = PrefetchedDomain{
|
||||
Domain: domain,
|
||||
Refresh: refresh,
|
||||
Qtype: qtype,
|
||||
}
|
||||
|
||||
log.Info("%s was added as a prefetched domain", domain)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *Manager) RemovePrefetchedDomain(domain string) error {
|
||||
result := manager.dbManager.Conn.Delete(&database.Prefetch{}, "domain = ?", domain)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to remove %s from prefetch table: %w", domain, result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("%s does not exist in the database", domain)
|
||||
}
|
||||
|
||||
delete(manager.Domains, domain)
|
||||
log.Info("%s was removed as a prefetched domain", domain)
|
||||
return nil
|
||||
}
|
||||
|
||||
func New(dnsServer *server.DNSServer) Manager {
|
||||
manager := Manager{
|
||||
dbManager: dnsServer.DBManager,
|
||||
DNS: dnsServer,
|
||||
Domains: make(map[string]PrefetchedDomain),
|
||||
}
|
||||
|
||||
manager.LoadPrefetchedDomains()
|
||||
return manager
|
||||
}
|
||||
|
||||
func (manager *Manager) Run() {
|
||||
ticker := time.NewTicker(500 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
manager.checkNewDomains()
|
||||
manager.processExpiredEntries()
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *Manager) checkNewDomains() {
|
||||
for domain, prefetchDomain := range manager.Domains {
|
||||
cacheKey := manager.buildCacheKey(domain, dns.Type(prefetchDomain.Qtype))
|
||||
if _, exists := manager.DNS.Cache.Load(cacheKey); !exists {
|
||||
log.Debug("Prefetching new/missing domain: %s", domain)
|
||||
manager.prefetchDomain(prefetchDomain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *Manager) processExpiredEntries() {
|
||||
now := time.Now()
|
||||
var expiredKeys []interface{}
|
||||
var removeFromDomains []string
|
||||
|
||||
manager.DNS.Cache.Range(func(key, value interface{}) bool {
|
||||
cachedDomain, ok := value.(server.CachedRecord)
|
||||
if !ok {
|
||||
log.Debug("Cache entry type assertion failed for key: %v", key)
|
||||
return true
|
||||
}
|
||||
|
||||
if manager.isExpired(cachedDomain, now) {
|
||||
expiredKeys = append(expiredKeys, key)
|
||||
|
||||
if _, isPrefetched := manager.Domains[cachedDomain.Domain]; !isPrefetched {
|
||||
removeFromDomains = append(removeFromDomains, cachedDomain.Domain)
|
||||
log.Debug("Non-prefetch entry '%v' expired and will be removed", key)
|
||||
} else {
|
||||
log.Debug("Prefetch entry '%v' expired and will be refreshed", key)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
manager.handleExpiredKeys(expiredKeys)
|
||||
manager.removeNonPrefetchDomains(removeFromDomains)
|
||||
}
|
||||
|
||||
func (manager *Manager) isExpired(record server.CachedRecord, now time.Time) bool {
|
||||
return now.After(record.ExpiresAt) || now.Equal(record.ExpiresAt)
|
||||
}
|
||||
|
||||
func (manager *Manager) handleExpiredKeys(expiredKeys []interface{}) {
|
||||
for _, key := range expiredKeys {
|
||||
if value, exists := manager.DNS.Cache.Load(key); exists {
|
||||
if cachedDomain, ok := value.(server.CachedRecord); ok {
|
||||
manager.DNS.Cache.Delete(key)
|
||||
manager.handleExpiredEntry(cachedDomain)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *Manager) removeNonPrefetchDomains(domains []string) {
|
||||
for _, domain := range domains {
|
||||
delete(manager.Domains, domain)
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *Manager) prefetchDomain(prefetchDomain PrefetchedDomain) {
|
||||
question := dns.Question{
|
||||
Name: prefetchDomain.Domain,
|
||||
Qtype: uint16(prefetchDomain.Qtype),
|
||||
Qclass: 1,
|
||||
}
|
||||
|
||||
request := &server.Request{
|
||||
Msg: &dns.Msg{Question: []dns.Question{question}},
|
||||
Question: question,
|
||||
Sent: time.Now(),
|
||||
Prefetch: true,
|
||||
}
|
||||
|
||||
answers, ttl, _ := manager.DNS.QueryUpstream(request)
|
||||
cacheKey := manager.buildCacheKey(question.Name, dns.Type(question.Qtype))
|
||||
manager.DNS.CacheRecord(cacheKey, prefetchDomain.Domain, answers, ttl)
|
||||
}
|
||||
|
||||
func (manager *Manager) handleExpiredEntry(record server.CachedRecord) {
|
||||
domain := record.IPAddresses[0].Header().Name
|
||||
prefetchDomain, exists := manager.Domains[domain]
|
||||
|
||||
if !exists {
|
||||
log.Debug("%s not set to be prefetched", domain)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("Prefetching expired domain: %s", domain)
|
||||
manager.prefetchDomain(prefetchDomain)
|
||||
}
|
||||
|
||||
func (manager *Manager) buildCacheKey(domain string, qtype dns.Type) string {
|
||||
return domain + ":" + strconv.Itoa(int(qtype))
|
||||
}
|
||||
@@ -3,22 +3,25 @@ package server
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"goaway/backend/alert"
|
||||
"goaway/backend/audit"
|
||||
"goaway/backend/dns/database"
|
||||
"goaway/backend/dns/lists"
|
||||
"goaway/backend/blacklist"
|
||||
model "goaway/backend/dns/server/models"
|
||||
"goaway/backend/logging"
|
||||
notification "goaway/backend/notifications"
|
||||
"goaway/backend/mac"
|
||||
"goaway/backend/notification"
|
||||
"goaway/backend/request"
|
||||
"goaway/backend/resolution"
|
||||
"goaway/backend/settings"
|
||||
"goaway/backend/user"
|
||||
"goaway/backend/whitelist"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/miekg/dns"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -26,74 +29,66 @@ var (
|
||||
)
|
||||
|
||||
type DNSServer struct {
|
||||
DBConn *gorm.DB
|
||||
dnsClient *dns.Client
|
||||
Config *settings.Config
|
||||
Blacklist *lists.Blacklist
|
||||
Whitelist *lists.Whitelist
|
||||
DBManager *database.DatabaseManager
|
||||
logIntervalSeconds int
|
||||
lastLogTime time.Time
|
||||
Cache sync.Map
|
||||
clientCache sync.Map
|
||||
hostnameCache sync.Map
|
||||
WebServer *gin.Engine
|
||||
logEntryChannel chan model.RequestLogEntry
|
||||
WSQueries *websocket.Conn
|
||||
WSCommunication *websocket.Conn
|
||||
hostnameCache sync.Map
|
||||
clientCache sync.Map
|
||||
DomainCache sync.Map
|
||||
WSCommunicationLock sync.Mutex
|
||||
dnsClient *dns.Client
|
||||
Notifications *notification.Manager
|
||||
Alerts *alert.Manager
|
||||
Audits *audit.Manager
|
||||
|
||||
RequestService *request.Service
|
||||
AuditService *audit.Service
|
||||
UserService *user.Service
|
||||
AlertService *alert.Service
|
||||
MACService *mac.Service
|
||||
ResolutionService *resolution.Service
|
||||
NotificationService *notification.Service
|
||||
BlacklistService *blacklist.Service
|
||||
WhitelistService *whitelist.Service
|
||||
}
|
||||
|
||||
type CachedRecord struct {
|
||||
IPAddresses []dns.RR
|
||||
ExpiresAt time.Time
|
||||
CachedAt time.Time
|
||||
OriginalTTL uint32
|
||||
Key string
|
||||
Domain string
|
||||
IPAddresses []dns.RR
|
||||
OriginalTTL uint32
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
Sent time.Time
|
||||
ResponseWriter dns.ResponseWriter
|
||||
Msg *dns.Msg
|
||||
Question dns.Question
|
||||
Sent time.Time
|
||||
Client *model.Client
|
||||
Prefetch bool
|
||||
Protocol model.Protocol
|
||||
Question dns.Question
|
||||
Prefetch bool
|
||||
}
|
||||
|
||||
type communicationMessage struct {
|
||||
IP string `json:"ip"`
|
||||
Client bool `json:"client"`
|
||||
Upstream bool `json:"upstream"`
|
||||
DNS bool `json:"dns"`
|
||||
Ip string `json:"ip"`
|
||||
}
|
||||
|
||||
func NewDNSServer(config *settings.Config, dbManager *database.DatabaseManager, notificationsManager *notification.Manager, alertManager *alert.Manager, auditManager *audit.Manager, cert tls.Certificate) (*DNSServer, error) {
|
||||
whitelistEntry, err := lists.InitializeWhitelist(dbManager)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize whitelist: %w", err)
|
||||
}
|
||||
|
||||
func NewDNSServer(config *settings.Config, dbconn *gorm.DB, cert tls.Certificate) (*DNSServer, error) {
|
||||
var client dns.Client
|
||||
if cert.Certificate != nil {
|
||||
client = dns.Client{Net: "tcp-tls"}
|
||||
}
|
||||
|
||||
server := &DNSServer{
|
||||
Config: config,
|
||||
Whitelist: whitelistEntry,
|
||||
DBManager: dbManager,
|
||||
logIntervalSeconds: 1,
|
||||
lastLogTime: time.Now(),
|
||||
logEntryChannel: make(chan model.RequestLogEntry, 1000),
|
||||
dnsClient: &client,
|
||||
Notifications: notificationsManager,
|
||||
Alerts: alertManager,
|
||||
Audits: auditManager,
|
||||
Config: config,
|
||||
DBConn: dbconn,
|
||||
logEntryChannel: make(chan model.RequestLogEntry, 1000),
|
||||
dnsClient: &client,
|
||||
DomainCache: sync.Map{},
|
||||
}
|
||||
|
||||
return server, nil
|
||||
@@ -101,7 +96,7 @@ func NewDNSServer(config *settings.Config, dbManager *database.DatabaseManager,
|
||||
|
||||
func (s *DNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) != 1 {
|
||||
log.Warning("Query container more than one question, ignoring!")
|
||||
log.Warning("Query contains more than one question, ignoring!")
|
||||
r.SetRcode(r, dns.RcodeFormatError)
|
||||
_ = w.WriteMsg(r)
|
||||
return
|
||||
@@ -114,7 +109,7 @@ func (s *DNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
Client: true,
|
||||
Upstream: false,
|
||||
DNS: false,
|
||||
Ip: client.IP,
|
||||
IP: client.IP,
|
||||
})
|
||||
|
||||
entry := s.processQuery(&Request{
|
||||
@@ -131,7 +126,7 @@ func (s *DNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
Client: false,
|
||||
Upstream: false,
|
||||
DNS: true,
|
||||
Ip: client.IP,
|
||||
IP: client.IP,
|
||||
})
|
||||
s.logEntryChannel <- entry
|
||||
}
|
||||
@@ -154,27 +149,12 @@ func (s *DNSServer) detectProtocol(w dns.ResponseWriter) model.Protocol {
|
||||
}
|
||||
|
||||
func (s *DNSServer) PopulateHostnameCache() error {
|
||||
type Result struct {
|
||||
ClientIP string
|
||||
ClientName string
|
||||
}
|
||||
|
||||
var results []Result
|
||||
|
||||
if err := s.DBManager.Conn.
|
||||
Model(&database.RequestLog{}).
|
||||
Select("DISTINCT client_ip, client_name").
|
||||
Where("client_name IS NOT NULL AND client_name != ?", "unknown").
|
||||
Find(&results).Error; err != nil {
|
||||
return fmt.Errorf("failed to fetch hostnames: %w", err)
|
||||
}
|
||||
|
||||
for _, r := range results {
|
||||
if _, exists := s.hostnameCache.Load(r.ClientName); !exists {
|
||||
s.hostnameCache.Store(r.ClientName, r.ClientIP)
|
||||
}
|
||||
uniqueClients := s.RequestService.GetUniqueClientNameAndIP()
|
||||
for _, client := range uniqueClients {
|
||||
_, _ = s.hostnameCache.LoadOrStore(client.IP, client.Name)
|
||||
}
|
||||
|
||||
log.Debug("Populated hostname cache with %d client(s)", len(uniqueClients))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -186,12 +166,7 @@ func (s *DNSServer) WSCom(message communicationMessage) {
|
||||
s.WSCommunicationLock.Lock()
|
||||
defer s.WSCommunicationLock.Unlock()
|
||||
|
||||
if err := s.WSCommunication.WriteControl(
|
||||
websocket.PingMessage,
|
||||
nil,
|
||||
time.Now().Add(2*time.Second),
|
||||
); err != nil {
|
||||
log.Debug("Websocket connection not alive, skipping message: %v", err)
|
||||
if s.WSCommunication == nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -201,11 +176,13 @@ func (s *DNSServer) WSCom(message communicationMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.WSCommunication.SetWriteDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
if err := s.WSCommunication.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
log.Warning("Failed to set websocket write deadline: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.WSCommunication.WriteMessage(websocket.TextMessage, entryWSJson); err != nil {
|
||||
log.Debug("Failed to write websocket message: %v", err)
|
||||
s.WSCommunication = nil
|
||||
}
|
||||
}
|
||||
|
||||
76
backend/jobs/background.go
Normal file
76
backend/jobs/background.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package jobs
|
||||
|
||||
import (
|
||||
arp "goaway/backend/dns"
|
||||
"goaway/backend/logging"
|
||||
"goaway/backend/services"
|
||||
)
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
type BackgroundJobs struct {
|
||||
registry *services.ServiceRegistry
|
||||
}
|
||||
|
||||
func NewBackgroundJobs(registry *services.ServiceRegistry) *BackgroundJobs {
|
||||
return &BackgroundJobs{
|
||||
registry: registry,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BackgroundJobs) Start(readyChan <-chan struct{}) {
|
||||
b.startHostnameCachePopulation()
|
||||
b.cleanVendorResponseCache(readyChan)
|
||||
b.startARPProcessing(readyChan)
|
||||
b.startScheduledUpdates(readyChan)
|
||||
b.startCacheCleanup(readyChan)
|
||||
b.startPrefetcher(readyChan)
|
||||
}
|
||||
|
||||
func (b *BackgroundJobs) startHostnameCachePopulation() {
|
||||
if err := b.registry.Context.DNSServer.PopulateHostnameCache(); err != nil {
|
||||
log.Warning("Unable to populate hostname cache: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BackgroundJobs) startARPProcessing(readyChan <-chan struct{}) {
|
||||
go func() {
|
||||
<-readyChan
|
||||
log.Debug("Starting ARP table processing...")
|
||||
arp.ProcessARPTable()
|
||||
}()
|
||||
}
|
||||
|
||||
func (b *BackgroundJobs) cleanVendorResponseCache(readyChan <-chan struct{}) {
|
||||
go func() {
|
||||
<-readyChan
|
||||
log.Debug("Starting vendor response table processing...")
|
||||
arp.CleanVendorResponseCache()
|
||||
}()
|
||||
}
|
||||
|
||||
func (b *BackgroundJobs) startScheduledUpdates(readyChan <-chan struct{}) {
|
||||
go func() {
|
||||
<-readyChan
|
||||
if b.registry.Context.Config.Misc.ScheduledBlacklistUpdates {
|
||||
log.Debug("Starting scheduler for automatic list updates...")
|
||||
b.registry.BlacklistService.ScheduleAutomaticListUpdates()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (b *BackgroundJobs) startCacheCleanup(readyChan <-chan struct{}) {
|
||||
go func() {
|
||||
<-readyChan
|
||||
log.Debug("Starting cache cleanup routine...")
|
||||
b.registry.Context.DNSServer.ClearOldEntries()
|
||||
}()
|
||||
}
|
||||
|
||||
func (b *BackgroundJobs) startPrefetcher(readyChan <-chan struct{}) {
|
||||
go func() {
|
||||
<-readyChan
|
||||
log.Debug("Starting prefetcher...")
|
||||
b.registry.PrefetchService.Run()
|
||||
}()
|
||||
}
|
||||
61
backend/lifecycle/manager.go
Normal file
61
backend/lifecycle/manager.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"goaway/backend/jobs"
|
||||
"goaway/backend/logging"
|
||||
"goaway/backend/services"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
// Coordinates startup, shutdown, and signal handling
|
||||
type Manager struct {
|
||||
services *services.ServiceRegistry
|
||||
backgroundJobs *jobs.BackgroundJobs
|
||||
signalChan chan os.Signal
|
||||
}
|
||||
|
||||
func NewManager(registry *services.ServiceRegistry) *Manager {
|
||||
return &Manager{
|
||||
services: registry,
|
||||
signalChan: make(chan os.Signal, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Run() error {
|
||||
if err := m.services.Initialize(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.backgroundJobs = jobs.NewBackgroundJobs(m.services)
|
||||
|
||||
signal.Notify(m.signalChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
m.services.StartAll()
|
||||
m.backgroundJobs.Start(m.services.ReadyChannel())
|
||||
|
||||
go m.services.WaitGroup().Wait()
|
||||
|
||||
return m.waitForTermination()
|
||||
}
|
||||
|
||||
func (m *Manager) waitForTermination() error {
|
||||
select {
|
||||
case err := <-m.services.ErrorChannel():
|
||||
log.Error("%s server failed: %s", err.Service, err.Err)
|
||||
log.Fatal("Server failure detected. Exiting.")
|
||||
return err.Err
|
||||
case <-m.signalChan:
|
||||
log.Info("Received interrupt. Shutting down.")
|
||||
m.shutdown()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) shutdown() {
|
||||
// TODO: Add graceful shutdown logic
|
||||
os.Exit(0)
|
||||
}
|
||||
@@ -9,11 +9,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
ColorReset = "\033[0m"
|
||||
ColorGray = "\033[90m"
|
||||
ColorWhite = "\033[97m"
|
||||
ColorYellow = "\033[33m"
|
||||
ColorRed = "\033[31m"
|
||||
colorReset = "\033[0m"
|
||||
colorGray = "\033[90m"
|
||||
colorWhite = "\033[97m"
|
||||
colorYellow = "\033[33m"
|
||||
colorRed = "\033[31m"
|
||||
)
|
||||
|
||||
type LogLevel int
|
||||
@@ -27,12 +27,12 @@ const (
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
logger *log.Logger
|
||||
JSONLoggerInstance *slog.Logger
|
||||
logLevel LogLevel
|
||||
LoggingEnabled bool
|
||||
Ansi bool
|
||||
logger *log.Logger
|
||||
JSON bool
|
||||
JSONLoggerInstance *slog.Logger
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -66,7 +66,7 @@ func (l *Logger) SetLevel(level LogLevel) {
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) SetJson(json bool) {
|
||||
func (l *Logger) SetJSON(json bool) {
|
||||
l.JSON = json
|
||||
}
|
||||
|
||||
@@ -74,10 +74,10 @@ func (l *Logger) SetAnsi(ansi bool) {
|
||||
l.Ansi = ansi
|
||||
}
|
||||
|
||||
func (l *Logger) log(level string, color string, message string, msgLevel LogLevel) {
|
||||
func (l *Logger) log(level, color, message string, msgLevel LogLevel) {
|
||||
if !l.JSON {
|
||||
if l.Ansi {
|
||||
l.logger.Printf("%s%s%s%s", color, level, message, ColorReset)
|
||||
l.logger.Printf("%s%s%s%s", color, level, message, colorReset)
|
||||
} else {
|
||||
l.logger.Printf("%s%s", level, message)
|
||||
}
|
||||
@@ -91,6 +91,8 @@ func (l *Logger) log(level string, color string, message string, msgLevel LogLev
|
||||
l.JSONLoggerInstance.Warn(message)
|
||||
case ERROR:
|
||||
l.JSONLoggerInstance.Error(message)
|
||||
default:
|
||||
l.JSONLoggerInstance.Info(message)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -105,9 +107,9 @@ func (l *Logger) Debug(format string, args ...interface{}) {
|
||||
}
|
||||
if len(args) > 0 {
|
||||
message := fmt.Sprintf(format, args...)
|
||||
l.log("[DEBUG] ", ColorGray, message, DEBUG)
|
||||
l.log("[DEBUG] ", colorGray, message, DEBUG)
|
||||
} else {
|
||||
l.log("[DEBUG] ", ColorGray, format, DEBUG)
|
||||
l.log("[DEBUG] ", colorGray, format, DEBUG)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,9 +119,9 @@ func (l *Logger) Info(format string, args ...interface{}) {
|
||||
}
|
||||
if len(args) > 0 {
|
||||
message := fmt.Sprintf(format, args...)
|
||||
l.log("[INFO] ", ColorWhite, message, INFO)
|
||||
l.log("[INFO] ", colorWhite, message, INFO)
|
||||
} else {
|
||||
l.log("[INFO] ", ColorWhite, format, INFO)
|
||||
l.log("[INFO] ", colorWhite, format, INFO)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,9 +131,9 @@ func (l *Logger) Warning(format string, args ...interface{}) {
|
||||
}
|
||||
if len(args) > 0 {
|
||||
message := fmt.Sprintf(format, args...)
|
||||
l.log("[WARN] ", ColorYellow, message, WARNING)
|
||||
l.log("[WARN] ", colorYellow, message, WARNING)
|
||||
} else {
|
||||
l.log("[WARN] ", ColorYellow, format, WARNING)
|
||||
l.log("[WARN] ", colorYellow, format, WARNING)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,9 +143,9 @@ func (l *Logger) Error(format string, args ...interface{}) {
|
||||
}
|
||||
if len(args) > 0 {
|
||||
message := fmt.Sprintf(format, args...)
|
||||
l.log("[ERROR] ", ColorRed, message, ERROR)
|
||||
l.log("[ERROR] ", colorRed, message, ERROR)
|
||||
} else {
|
||||
l.log("[ERROR] ", ColorRed, format, ERROR)
|
||||
l.log("[ERROR] ", colorRed, format, ERROR)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,39 +155,9 @@ func (l *Logger) Fatal(format string, args ...interface{}) {
|
||||
}
|
||||
if len(args) > 0 {
|
||||
message := fmt.Sprintf(format, args...)
|
||||
l.log("[FATAL] ", ColorRed, message, FATAL)
|
||||
l.log("[FATAL] ", colorRed, message, FATAL)
|
||||
} else {
|
||||
l.log("[FATAL] ", ColorRed, format, FATAL)
|
||||
l.log("[FATAL] ", colorRed, format, FATAL)
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func FromString(logLevel string) LogLevel {
|
||||
switch logLevel {
|
||||
case "DEBUG":
|
||||
return 0
|
||||
case "INFO":
|
||||
return 1
|
||||
case "WARNING":
|
||||
return 2
|
||||
case "ERROR":
|
||||
return 3
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
func (l LogLevel) String() string {
|
||||
switch l {
|
||||
case DEBUG:
|
||||
return "DEBUG"
|
||||
case INFO:
|
||||
return "INFO"
|
||||
case WARNING:
|
||||
return "WARNING"
|
||||
case ERROR:
|
||||
return "ERROR"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
51
backend/mac/repository.go
Normal file
51
backend/mac/repository.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package mac
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"goaway/backend/database"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
FindVendor(mac string) (string, error)
|
||||
SaveMac(clientIP, mac, vendor string) error
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewRepository(db *gorm.DB) Repository {
|
||||
return &repository{db: db}
|
||||
}
|
||||
|
||||
func (r *repository) FindVendor(mac string) (string, error) {
|
||||
var query database.MacAddress
|
||||
tx := r.db.Find(&query, "mac = ?", mac)
|
||||
|
||||
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
|
||||
return "", nil
|
||||
}
|
||||
if tx.Error != nil {
|
||||
return "", tx.Error
|
||||
}
|
||||
|
||||
return query.Vendor, nil
|
||||
}
|
||||
|
||||
func (r *repository) SaveMac(clientIP, mac, vendor string) error {
|
||||
entry := database.MacAddress{
|
||||
MAC: mac,
|
||||
IP: clientIP,
|
||||
Vendor: vendor,
|
||||
}
|
||||
tx := r.db.Save(&entry)
|
||||
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("unable to save new MAC entry %v", tx.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
24
backend/mac/service.go
Normal file
24
backend/mac/service.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package mac
|
||||
|
||||
import "goaway/backend/logging"
|
||||
|
||||
type Service struct {
|
||||
repository Repository
|
||||
}
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
func NewService(repo Repository) *Service {
|
||||
return &Service{repository: repo}
|
||||
}
|
||||
|
||||
func (s *Service) FindVendor(mac string) (string, error) {
|
||||
return s.repository.FindVendor(mac)
|
||||
}
|
||||
|
||||
func (s *Service) SaveMac(clientIP, mac, vendor string) {
|
||||
err := s.repository.SaveMac(clientIP, mac, vendor)
|
||||
if err != nil {
|
||||
log.Warning("Could not save MAC address, %v", err)
|
||||
}
|
||||
}
|
||||
53
backend/notification/repository.go
Normal file
53
backend/notification/repository.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"goaway/backend/database"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
CreateNotification(newNotification *database.Notification) error
|
||||
GetNotifications() ([]database.Notification, error)
|
||||
MarkNotificationsAsRead(notificationIDs []int) error
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewRepository(db *gorm.DB) Repository {
|
||||
return &repository{db: db}
|
||||
}
|
||||
|
||||
func (r *repository) CreateNotification(newNotification *database.Notification) error {
|
||||
tx := r.db.Create(&newNotification)
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
func (r *repository) GetNotifications() ([]database.Notification, error) {
|
||||
var notifications []database.Notification
|
||||
|
||||
result := r.db.Where("read = ?", false).Find(¬ifications)
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return notifications, nil
|
||||
}
|
||||
|
||||
func (r *repository) MarkNotificationsAsRead(notificationIDs []int) error {
|
||||
if len(notificationIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := r.db.Model(&database.Notification{}).
|
||||
Where("id IN ?", notificationIDs).
|
||||
Update("read", true)
|
||||
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
62
backend/notification/service.go
Normal file
62
backend/notification/service.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"goaway/backend/database"
|
||||
"goaway/backend/logging"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
repository Repository
|
||||
}
|
||||
|
||||
type Severity string
|
||||
type Category string
|
||||
|
||||
// Severity level of notification
|
||||
// SeverityInfo: Server was upgraded, password changed...
|
||||
// SeverityWarning: An error occurred on startup, database lock...
|
||||
// SeverityError: Server cant start, requests cant be handled...
|
||||
const (
|
||||
SeverityInfo Severity = "info"
|
||||
SeverityWarning Severity = "warning"
|
||||
SeverityError Severity = "error"
|
||||
)
|
||||
|
||||
// Categories to describe what area the notification covers
|
||||
const (
|
||||
CategoryServer Category = "server"
|
||||
CategoryDNS Category = "dns"
|
||||
CategoryAPI Category = "api"
|
||||
)
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
func NewService(repo Repository) *Service {
|
||||
return &Service{repository: repo}
|
||||
}
|
||||
|
||||
func (s *Service) SendNotification(severity Severity, category Category, text string) {
|
||||
notification := &database.Notification{
|
||||
Severity: string(severity),
|
||||
Category: string(category),
|
||||
Text: text,
|
||||
Read: false,
|
||||
}
|
||||
|
||||
err := s.repository.CreateNotification(notification)
|
||||
if err != nil {
|
||||
log.Warning("Could not send notification, %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("New notification created, severity: %s", severity)
|
||||
}
|
||||
|
||||
func (s *Service) GetNotifications() ([]database.Notification, error) {
|
||||
return s.repository.GetNotifications()
|
||||
}
|
||||
|
||||
func (s *Service) MarkNotificationsAsRead(notificationIDs []int) error {
|
||||
log.Info("Notifications have been marked as read")
|
||||
return s.repository.MarkNotificationsAsRead(notificationIDs)
|
||||
}
|
||||
@@ -1,89 +0,0 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"goaway/backend/dns/database"
|
||||
"goaway/backend/logging"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
dbManager *database.DatabaseManager
|
||||
}
|
||||
|
||||
type Severity string
|
||||
type Category string
|
||||
|
||||
// Severity level of notification
|
||||
// SeverityInfo: Server was upgraded, password changed...
|
||||
// SeverityWarning: An error occurred on startup, database lock...
|
||||
// SeverityError: Server cant start, requests cant be handled...
|
||||
const (
|
||||
SeverityInfo Severity = "info"
|
||||
SeverityWarning Severity = "warning"
|
||||
SeverityError Severity = "error"
|
||||
)
|
||||
|
||||
// Categories to describe what area the notification covers
|
||||
const (
|
||||
CategoryServer Category = "server"
|
||||
CategoryDNS Category = "dns"
|
||||
CategoryAPI Category = "api"
|
||||
)
|
||||
|
||||
type Notification struct {
|
||||
Id int `json:"id"`
|
||||
Severity Severity `json:"severity"`
|
||||
Category Category `json:"category"`
|
||||
Text string `json:"text"`
|
||||
Read bool `json:"read"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
var logger = logging.GetLogger()
|
||||
|
||||
func NewNotificationManager(dbManager *database.DatabaseManager) *Manager {
|
||||
return &Manager{dbManager: dbManager}
|
||||
}
|
||||
|
||||
func (nm *Manager) CreateNotification(newNotification *Notification) {
|
||||
tx := nm.dbManager.Conn.Create(&database.Notification{
|
||||
Severity: string(newNotification.Severity),
|
||||
Category: string(newNotification.Category),
|
||||
Text: newNotification.Text,
|
||||
Read: false,
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
if tx.Error != nil {
|
||||
logger.Warning("Unable to create new notification, error: %v", tx.Error)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug("Created new notification, %+v", newNotification)
|
||||
}
|
||||
|
||||
func (nm *Manager) ReadNotifications() ([]database.Notification, error) {
|
||||
var notifications []database.Notification
|
||||
|
||||
result := nm.dbManager.Conn.Where("read = ?", true).Find(¬ifications)
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return notifications, nil
|
||||
}
|
||||
|
||||
func (nm *Manager) MarkNotificationsAsRead(notificationIDs []int) error {
|
||||
if len(notificationIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := nm.dbManager.Conn.Model(&database.Notification{}).
|
||||
Where("id IN ?", notificationIDs).
|
||||
Update("read", true)
|
||||
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
53
backend/prefetch/repository.go
Normal file
53
backend/prefetch/repository.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package prefetch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"goaway/backend/database"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
GetAll() ([]database.Prefetch, error)
|
||||
Create(prefetch *database.Prefetch) error
|
||||
Delete(domain string) error
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewRepository(db *gorm.DB) Repository {
|
||||
return &repository{db: db}
|
||||
}
|
||||
|
||||
func (r *repository) Create(prefetch *database.Prefetch) error {
|
||||
result := r.db.Create(prefetch)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) GetAll() ([]database.Prefetch, error) {
|
||||
var prefetched []database.Prefetch
|
||||
result := r.db.Model(&database.Prefetch{}).Find(&prefetched)
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
return prefetched, nil
|
||||
}
|
||||
|
||||
func (r *repository) Delete(domain string) error {
|
||||
result := r.db.Delete(&database.Prefetch{}, "domain = ?", domain)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("%s does not exist in the database", domain)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
182
backend/prefetch/service.go
Normal file
182
backend/prefetch/service.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package prefetch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"goaway/backend/database"
|
||||
"goaway/backend/dns/server"
|
||||
"goaway/backend/logging"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
repository Repository
|
||||
DNS *server.DNSServer
|
||||
Domains map[string]database.Prefetch
|
||||
}
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
func NewService(repo Repository, dnsServer *server.DNSServer) *Service {
|
||||
service := &Service{
|
||||
repository: repo,
|
||||
DNS: dnsServer,
|
||||
Domains: make(map[string]database.Prefetch),
|
||||
}
|
||||
|
||||
service.LoadPrefetchedDomains()
|
||||
return service
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
ticker := time.NewTicker(500 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
s.checkNewDomains()
|
||||
s.processExpiredEntries()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) checkNewDomains() {
|
||||
for domain, prefetchDomain := range s.Domains {
|
||||
cacheKey := s.buildCacheKey(domain, dns.Type(prefetchDomain.QueryType))
|
||||
if _, exists := s.DNS.DomainCache.Load(cacheKey); !exists {
|
||||
log.Debug("Prefetching new/missing domain: %s", domain)
|
||||
s.prefetchDomain(prefetchDomain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) processExpiredEntries() {
|
||||
now := time.Now()
|
||||
var expiredKeys []interface{}
|
||||
var removeFromDomains []string
|
||||
|
||||
s.DNS.DomainCache.Range(func(key, value interface{}) bool {
|
||||
cachedDomain, ok := value.(server.CachedRecord)
|
||||
if !ok {
|
||||
log.Debug("Cache entry type assertion failed for key: %v", key)
|
||||
return true
|
||||
}
|
||||
|
||||
if s.isExpired(cachedDomain, now) {
|
||||
expiredKeys = append(expiredKeys, key)
|
||||
|
||||
if _, isPrefetched := s.Domains[cachedDomain.Domain]; !isPrefetched {
|
||||
removeFromDomains = append(removeFromDomains, cachedDomain.Domain)
|
||||
log.Debug("Non-prefetch entry '%v' expired and will be removed", key)
|
||||
} else {
|
||||
log.Debug("Prefetch entry '%v' expired and will be refreshed", key)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
s.handleExpiredKeys(expiredKeys)
|
||||
s.removeNonPrefetchDomains(removeFromDomains)
|
||||
}
|
||||
|
||||
func (s *Service) isExpired(record server.CachedRecord, now time.Time) bool {
|
||||
return now.After(record.ExpiresAt) || now.Equal(record.ExpiresAt)
|
||||
}
|
||||
|
||||
func (s *Service) handleExpiredKeys(expiredKeys []interface{}) {
|
||||
for _, key := range expiredKeys {
|
||||
if value, exists := s.DNS.DomainCache.Load(key); exists {
|
||||
if cachedDomain, ok := value.(server.CachedRecord); ok {
|
||||
s.DNS.DomainCache.Delete(key)
|
||||
s.handleExpiredEntry(cachedDomain)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) removeNonPrefetchDomains(domains []string) {
|
||||
for _, domain := range domains {
|
||||
delete(s.Domains, domain)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) prefetchDomain(prefetchDomain database.Prefetch) {
|
||||
question := dns.Question{
|
||||
Name: prefetchDomain.Domain,
|
||||
Qtype: uint16(prefetchDomain.QueryType),
|
||||
Qclass: 1,
|
||||
}
|
||||
|
||||
request := &server.Request{
|
||||
Msg: &dns.Msg{Question: []dns.Question{question}},
|
||||
Question: question,
|
||||
Sent: time.Now(),
|
||||
Prefetch: true,
|
||||
}
|
||||
|
||||
answers, ttl, _ := s.DNS.QueryUpstream(request)
|
||||
cacheKey := s.buildCacheKey(question.Name, dns.Type(question.Qtype))
|
||||
s.DNS.CacheRecord(cacheKey, prefetchDomain.Domain, answers, ttl)
|
||||
}
|
||||
|
||||
func (s *Service) buildCacheKey(domain string, qtype dns.Type) string {
|
||||
return domain + ":" + strconv.Itoa(int(qtype))
|
||||
}
|
||||
|
||||
func (s *Service) handleExpiredEntry(record server.CachedRecord) {
|
||||
domain := record.IPAddresses[0].Header().Name
|
||||
prefetchDomain, exists := s.Domains[domain]
|
||||
|
||||
if !exists {
|
||||
log.Debug("%s not set to be prefetched", domain)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("Prefetching expired domain: %s", domain)
|
||||
s.prefetchDomain(prefetchDomain)
|
||||
}
|
||||
|
||||
func (s *Service) LoadPrefetchedDomains() {
|
||||
prefetched, err := s.repository.GetAll()
|
||||
if err != nil {
|
||||
log.Error("failed to load prefetched domains: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, p := range prefetched {
|
||||
s.Domains[p.Domain] = p
|
||||
}
|
||||
|
||||
if len(s.Domains) > 0 {
|
||||
log.Info("Loaded %d prefetched domain(s)", len(s.Domains))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) AddPrefetchedDomain(domain string, refresh, qtype int) error {
|
||||
prefetch := database.Prefetch{
|
||||
Domain: domain,
|
||||
Refresh: refresh,
|
||||
QueryType: qtype,
|
||||
}
|
||||
|
||||
err := s.repository.Create(&prefetch)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add new domain to prefetch table: %w", err)
|
||||
}
|
||||
|
||||
s.Domains[domain] = prefetch
|
||||
|
||||
log.Info("%s was added as a prefetched domain", domain)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) RemovePrefetchedDomain(domain string) error {
|
||||
err := s.repository.Delete(domain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove %s from prefetch table: %w", domain, err)
|
||||
}
|
||||
|
||||
delete(s.Domains, domain)
|
||||
log.Info("%s was removed as a prefetched domain", domain)
|
||||
return nil
|
||||
}
|
||||
15
backend/request/model.go
Normal file
15
backend/request/model.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package request
|
||||
|
||||
import "time"
|
||||
|
||||
type Client struct {
|
||||
LastSeen time.Time
|
||||
Name string
|
||||
Mac string
|
||||
Vendor string
|
||||
}
|
||||
|
||||
type ClientNameAndIP struct {
|
||||
Name string
|
||||
IP string
|
||||
}
|
||||
@@ -1,24 +1,58 @@
|
||||
package database
|
||||
package request
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"goaway/backend/api/models"
|
||||
dbModel "goaway/backend/dns/database/models"
|
||||
"goaway/backend/database"
|
||||
model "goaway/backend/dns/server/models"
|
||||
"goaway/backend/logging"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var log = logging.GetLogger()
|
||||
type Repository interface {
|
||||
GetClientName(ip string) string
|
||||
GetDistinctRequestIP() int
|
||||
GetRequestSummaryByInterval(interval int) ([]model.RequestLogIntervalSummary, error)
|
||||
GetResponseSizeSummaryByInterval(intervalMinutes int) ([]model.ResponseSizeSummary, error)
|
||||
GetUniqueQueryTypes() ([]interface{}, error)
|
||||
FetchQueries(q models.QueryParams) ([]model.RequestLogEntry, error)
|
||||
GetUniqueClientNameAndIP() []database.RequestLog
|
||||
FetchAllClients() (map[string]Client, error)
|
||||
GetClientDetailsWithDomains(clientIP string) (ClientRequestDetails, string, map[string]int, error)
|
||||
GetTopBlockedDomains(blockedRequests int) ([]map[string]interface{}, error)
|
||||
GetTopClients() ([]map[string]interface{}, error)
|
||||
CountQueries(search string) (int, error)
|
||||
|
||||
func GetClientNameFromRequestLog(db *gorm.DB, ip string) string {
|
||||
SaveRequestLog(entries []model.RequestLogEntry) error
|
||||
|
||||
DeleteRequestLogsTimebased(vacuum vacuumFunc, requestThreshold, maxRetries int, retryDelay time.Duration) error
|
||||
}
|
||||
|
||||
type ClientRequestDetails struct {
|
||||
LastSeen string
|
||||
MostQueriedDomain string
|
||||
TotalRequests int
|
||||
UniqueDomains int
|
||||
BlockedRequests int
|
||||
CachedRequests int
|
||||
AvgResponseTimeMs float64
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewRepository(db *gorm.DB) *repository {
|
||||
return &repository{db: db}
|
||||
}
|
||||
|
||||
func (r *repository) GetClientName(ip string) string {
|
||||
var hostname string
|
||||
|
||||
err := db.Model(&RequestLog{}).
|
||||
err := r.db.Model(&database.RequestLog{}).
|
||||
Select("client_name").
|
||||
Where("client_ip = ? AND client_name != ?", ip, "unknown").
|
||||
Limit(1).
|
||||
@@ -31,10 +65,10 @@ func GetClientNameFromRequestLog(db *gorm.DB, ip string) string {
|
||||
return strings.TrimSuffix(hostname, ".")
|
||||
}
|
||||
|
||||
func GetDistinctRequestIP(db *gorm.DB) int {
|
||||
func (r *repository) GetDistinctRequestIP() int {
|
||||
var count int64
|
||||
|
||||
err := db.Model(&RequestLog{}).
|
||||
err := r.db.Model(&database.RequestLog{}).
|
||||
Select("COUNT(DISTINCT client_ip)").
|
||||
Scan(&count).Error
|
||||
if err != nil {
|
||||
@@ -44,33 +78,39 @@ func GetDistinctRequestIP(db *gorm.DB) int {
|
||||
return int(count)
|
||||
}
|
||||
|
||||
func GetRequestSummaryByInterval(interval int, db *gorm.DB) ([]model.RequestLogIntervalSummary, error) {
|
||||
func (r *repository) GetRequestSummaryByInterval(interval int) ([]model.RequestLogIntervalSummary, error) {
|
||||
minutes := interval * 60
|
||||
|
||||
var rawSummaries []model.RequestLogIntervalSummary
|
||||
err := db.Table("request_logs").
|
||||
Select(`
|
||||
DATETIME((STRFTIME('%s', timestamp) / ?) * ?, 'unixepoch') AS interval_start,
|
||||
type tempSummary struct {
|
||||
IntervalStartUnix int64 `gorm:"column:interval_start_unix"`
|
||||
BlockedCount int `gorm:"column:blocked_count"`
|
||||
CachedCount int `gorm:"column:cached_count"`
|
||||
AllowedCount int `gorm:"column:allowed_count"`
|
||||
}
|
||||
|
||||
var rawSummaries []tempSummary
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
((CAST(STRFTIME('%%s', timestamp) AS INTEGER) / %d) * %d) AS interval_start_unix,
|
||||
SUM(blocked) AS blocked_count,
|
||||
SUM(cached) AS cached_count,
|
||||
SUM(NOT blocked AND NOT cached) AS allowed_count
|
||||
`, minutes, minutes, minutes).
|
||||
Where("timestamp >= DATETIME('now', '-1 day')").
|
||||
Group("(STRFTIME('%s', timestamp) / ?)").
|
||||
Order("interval_start").
|
||||
Scan(&rawSummaries).Error
|
||||
FROM request_logs
|
||||
WHERE timestamp >= DATETIME('now', '-1 day')
|
||||
GROUP BY (CAST(STRFTIME('%%s', timestamp) AS INTEGER) / %d)
|
||||
ORDER BY interval_start_unix ASC
|
||||
`, minutes, minutes, minutes)
|
||||
|
||||
err := r.db.Raw(query).Scan(&rawSummaries).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
summaries := make([]model.RequestLogIntervalSummary, len(rawSummaries))
|
||||
for i := range rawSummaries {
|
||||
t, err := time.Parse("2006-01-02 15:04:05", rawSummaries[i].IntervalStart)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
summaries[i] = model.RequestLogIntervalSummary{
|
||||
IntervalStart: t.String(),
|
||||
IntervalStart: time.Unix(rawSummaries[i].IntervalStartUnix, 0).Format("2006-01-02 15:04:05"),
|
||||
BlockedCount: rawSummaries[i].BlockedCount,
|
||||
CachedCount: rawSummaries[i].CachedCount,
|
||||
AllowedCount: rawSummaries[i].AllowedCount,
|
||||
@@ -80,27 +120,38 @@ func GetRequestSummaryByInterval(interval int, db *gorm.DB) ([]model.RequestLogI
|
||||
return summaries, nil
|
||||
}
|
||||
|
||||
func GetResponseSizeSummaryByInterval(intervalMinutes int, db *gorm.DB) ([]model.ResponseSizeSummary, error) {
|
||||
func (r *repository) GetResponseSizeSummaryByInterval(intervalMinutes int) ([]model.ResponseSizeSummary, error) {
|
||||
intervalSeconds := int64(intervalMinutes * 60)
|
||||
twentyFourHoursAgo := time.Now().Add(-24 * time.Hour).Unix()
|
||||
twentyFourHoursAgo := time.Now().Add(-24 * time.Hour)
|
||||
|
||||
var summaries []model.ResponseSizeSummary
|
||||
|
||||
query := `
|
||||
WITH logs_unix AS (
|
||||
SELECT
|
||||
CAST(strftime('%s', timestamp) AS INTEGER) AS ts_unix,
|
||||
response_size_bytes
|
||||
FROM request_logs
|
||||
WHERE timestamp >= ? AND response_size_bytes IS NOT NULL
|
||||
)
|
||||
SELECT
|
||||
((strftime('%s', timestamp) / ?) * ?) AS start_unix,
|
||||
(ts_unix / ?) * ? AS start_unix,
|
||||
SUM(response_size_bytes) AS total_size_bytes,
|
||||
ROUND(AVG(response_size_bytes)) AS avg_response_size_bytes,
|
||||
MIN(response_size_bytes) AS min_response_size_bytes,
|
||||
MAX(response_size_bytes) AS max_response_size_bytes
|
||||
FROM request_logs
|
||||
WHERE strftime('%s', timestamp) >= ? AND response_size_bytes IS NOT NULL
|
||||
GROUP BY (strftime('%s', timestamp) / ?)
|
||||
FROM logs_unix
|
||||
GROUP BY ts_unix / ?
|
||||
ORDER BY start_unix ASC
|
||||
`
|
||||
|
||||
err := db.Raw(query, intervalSeconds, intervalSeconds, twentyFourHoursAgo, intervalSeconds).
|
||||
Scan(&summaries).Error
|
||||
err := r.db.Raw(query,
|
||||
twentyFourHoursAgo,
|
||||
intervalSeconds,
|
||||
intervalSeconds,
|
||||
intervalSeconds,
|
||||
).Scan(&summaries).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query response size summary: %w", err)
|
||||
}
|
||||
@@ -112,7 +163,7 @@ func GetResponseSizeSummaryByInterval(intervalMinutes int, db *gorm.DB) ([]model
|
||||
return summaries, nil
|
||||
}
|
||||
|
||||
func GetUniqueQueryTypes(db *gorm.DB) ([]interface{}, error) {
|
||||
func (r *repository) GetUniqueQueryTypes() ([]interface{}, error) {
|
||||
query := `
|
||||
SELECT COUNT(*) AS count, query_type
|
||||
FROM request_logs
|
||||
@@ -120,7 +171,7 @@ func GetUniqueQueryTypes(db *gorm.DB) ([]interface{}, error) {
|
||||
GROUP BY query_type
|
||||
ORDER BY count DESC`
|
||||
|
||||
rows, err := db.Raw(query).Rows()
|
||||
rows, err := r.db.Raw(query).Rows()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -134,8 +185,8 @@ func GetUniqueQueryTypes(db *gorm.DB) ([]interface{}, error) {
|
||||
var queries []any
|
||||
for rows.Next() {
|
||||
query := struct {
|
||||
Count int `json:"count"`
|
||||
QueryType string `json:"queryType"`
|
||||
Count int `json:"count"`
|
||||
}{}
|
||||
|
||||
if err := rows.Scan(&query.Count, &query.QueryType); err != nil {
|
||||
@@ -152,9 +203,9 @@ func GetUniqueQueryTypes(db *gorm.DB) ([]interface{}, error) {
|
||||
return queries, nil
|
||||
}
|
||||
|
||||
func FetchQueries(db *gorm.DB, q models.QueryParams) ([]model.RequestLogEntry, error) {
|
||||
var logs []RequestLog
|
||||
query := db.Model(&RequestLog{})
|
||||
func (r *repository) FetchQueries(q models.QueryParams) ([]model.RequestLogEntry, error) {
|
||||
var logs []database.RequestLog
|
||||
query := r.db.Model(&database.RequestLog{})
|
||||
|
||||
if q.Column == "ip" {
|
||||
query = query.Joins("LEFT JOIN request_log_ips ri ON request_logs.id = ri.request_log_id")
|
||||
@@ -191,7 +242,7 @@ func FetchQueries(db *gorm.DB, q models.QueryParams) ([]model.RequestLogEntry, e
|
||||
results := make([]model.RequestLogEntry, len(logs))
|
||||
for i, log := range logs {
|
||||
results[i] = model.RequestLogEntry{
|
||||
ID: int64(log.ID),
|
||||
ID: log.ID,
|
||||
Timestamp: log.Timestamp,
|
||||
Domain: log.Domain,
|
||||
Blocked: log.Blocked,
|
||||
@@ -206,14 +257,14 @@ func FetchQueries(db *gorm.DB, q models.QueryParams) ([]model.RequestLogEntry, e
|
||||
}
|
||||
|
||||
for j, ip := range log.IPs {
|
||||
results[i].IP[j] = model.ResolvedIP{IP: ip.IP, RType: ip.RType}
|
||||
results[i].IP[j] = model.ResolvedIP{IP: ip.IP, RType: ip.RecordType}
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func FetchAllClients(db *gorm.DB) (map[string]dbModel.Client, error) {
|
||||
func (r *repository) FetchAllClients() (map[string]Client, error) {
|
||||
var rows []struct {
|
||||
ClientIP string `gorm:"column:client_ip"`
|
||||
ClientName string `gorm:"column:client_name"`
|
||||
@@ -222,11 +273,11 @@ func FetchAllClients(db *gorm.DB) (map[string]dbModel.Client, error) {
|
||||
Vendor sql.NullString `gorm:"column:vendor"`
|
||||
}
|
||||
|
||||
subquery := db.Table("request_logs").
|
||||
subquery := r.db.Table("request_logs").
|
||||
Select("client_ip, MAX(timestamp) as max_timestamp").
|
||||
Group("client_ip")
|
||||
|
||||
if err := db.Table("request_logs r").
|
||||
if err := r.db.Table("request_logs r").
|
||||
Select("r.client_ip, r.client_name, r.timestamp, m.mac, m.vendor").
|
||||
Joins("INNER JOIN (?) latest ON r.client_ip = latest.client_ip AND r.timestamp = latest.max_timestamp", subquery).
|
||||
Joins("LEFT JOIN mac_addresses m ON r.client_ip = m.ip").
|
||||
@@ -234,7 +285,7 @@ func FetchAllClients(db *gorm.DB) (map[string]dbModel.Client, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uniqueClients := make(map[string]dbModel.Client, len(rows))
|
||||
uniqueClients := make(map[string]Client, len(rows))
|
||||
for _, row := range rows {
|
||||
macStr := ""
|
||||
vendorStr := ""
|
||||
@@ -245,7 +296,7 @@ func FetchAllClients(db *gorm.DB) (map[string]dbModel.Client, error) {
|
||||
vendorStr = row.Vendor.String
|
||||
}
|
||||
|
||||
uniqueClients[row.ClientIP] = dbModel.Client{
|
||||
uniqueClients[row.ClientIP] = Client{
|
||||
Name: row.ClientName,
|
||||
LastSeen: row.Timestamp,
|
||||
Mac: macStr,
|
||||
@@ -256,9 +307,20 @@ func FetchAllClients(db *gorm.DB) (map[string]dbModel.Client, error) {
|
||||
return uniqueClients, nil
|
||||
}
|
||||
|
||||
func GetClientDetailsWithDomains(db *gorm.DB, clientIP string) (dbModel.ClientRequestDetails, string, map[string]int, error) {
|
||||
var crd dbModel.ClientRequestDetails
|
||||
err := db.Table("request_logs").
|
||||
func (r *repository) GetUniqueClientNameAndIP() []database.RequestLog {
|
||||
var results []database.RequestLog
|
||||
|
||||
r.db.Model(&database.RequestLog{}).
|
||||
Select("DISTINCT client_ip, client_name").
|
||||
Where("client_name IS NOT NULL AND client_name != ?", "unknown").
|
||||
Find(&results)
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
func (r *repository) GetClientDetailsWithDomains(clientIP string) (ClientRequestDetails, string, map[string]int, error) {
|
||||
var crd ClientRequestDetails
|
||||
err := r.db.Table("request_logs").
|
||||
Select(`
|
||||
COUNT(*) as total_requests,
|
||||
COUNT(DISTINCT domain) as unique_domains,
|
||||
@@ -278,7 +340,7 @@ func GetClientDetailsWithDomains(db *gorm.DB, clientIP string) (dbModel.ClientRe
|
||||
Count int `gorm:"column:query_count"`
|
||||
}
|
||||
|
||||
err = db.Table("request_logs").
|
||||
err = r.db.Table("request_logs").
|
||||
Select("domain, COUNT(*) as query_count").
|
||||
Where("client_ip = ?", clientIP).
|
||||
Group("domain").
|
||||
@@ -306,13 +368,13 @@ func GetClientDetailsWithDomains(db *gorm.DB, clientIP string) (dbModel.ClientRe
|
||||
return crd, mostQueriedDomain, domainQueryCounts, nil
|
||||
}
|
||||
|
||||
func GetTopBlockedDomains(db *gorm.DB, blockedRequests int) ([]map[string]interface{}, error) {
|
||||
func (r *repository) GetTopBlockedDomains(blockedRequests int) ([]map[string]interface{}, error) {
|
||||
var rows []struct {
|
||||
Domain string `gorm:"column:domain"`
|
||||
Hits int `gorm:"column:hits"`
|
||||
}
|
||||
|
||||
if err := db.Table("request_logs").
|
||||
if err := r.db.Table("request_logs").
|
||||
Select("domain, COUNT(*) as hits").
|
||||
Where("blocked = ?", true).
|
||||
Group("domain").
|
||||
@@ -337,9 +399,9 @@ func GetTopBlockedDomains(db *gorm.DB, blockedRequests int) ([]map[string]interf
|
||||
return topBlockedDomains, nil
|
||||
}
|
||||
|
||||
func GetTopClients(db *gorm.DB) ([]map[string]interface{}, error) {
|
||||
func (r *repository) GetTopClients() ([]map[string]interface{}, error) {
|
||||
var total int64
|
||||
if err := db.Table("request_logs").Count(&total).Error; err != nil {
|
||||
if err := r.db.Table("request_logs").Count(&total).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -349,7 +411,7 @@ func GetTopClients(db *gorm.DB) ([]map[string]interface{}, error) {
|
||||
Frequency float32 `gorm:"column:frequency"`
|
||||
}
|
||||
|
||||
if err := db.Table("request_logs").
|
||||
if err := r.db.Table("request_logs").
|
||||
Select("? as frequency, client_ip, COUNT(*) as request_count", 0).
|
||||
Group("client_ip").
|
||||
Order("request_count DESC").
|
||||
@@ -370,18 +432,18 @@ func GetTopClients(db *gorm.DB) ([]map[string]interface{}, error) {
|
||||
return clients, nil
|
||||
}
|
||||
|
||||
func CountQueries(db *gorm.DB, search string) (int, error) {
|
||||
func (r *repository) CountQueries(search string) (int, error) {
|
||||
var total int64
|
||||
err := db.Table("request_logs").
|
||||
err := r.db.Table("request_logs").
|
||||
Where("domain LIKE ?", "%"+search+"%").
|
||||
Count(&total).Error
|
||||
return int(total), err
|
||||
}
|
||||
|
||||
func SaveRequestLog(db *gorm.DB, entries []model.RequestLogEntry) error {
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
func (r *repository) SaveRequestLog(entries []model.RequestLogEntry) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
for _, entry := range entries {
|
||||
rl := RequestLog{
|
||||
rl := database.RequestLog{
|
||||
Timestamp: entry.Timestamp,
|
||||
Domain: entry.Domain,
|
||||
Blocked: entry.Blocked,
|
||||
@@ -396,41 +458,40 @@ func SaveRequestLog(db *gorm.DB, entries []model.RequestLogEntry) error {
|
||||
}
|
||||
|
||||
for _, resolvedIP := range entry.IP {
|
||||
rl.IPs = append(rl.IPs, RequestLogIP{
|
||||
IP: resolvedIP.IP,
|
||||
RType: resolvedIP.RType,
|
||||
rl.IPs = append(rl.IPs, database.RequestLogIP{
|
||||
IP: resolvedIP.IP,
|
||||
RecordType: resolvedIP.RType,
|
||||
})
|
||||
}
|
||||
|
||||
if err := tx.Create(&rl).Error; err != nil {
|
||||
return fmt.Errorf("could not save request log: %v", err)
|
||||
return fmt.Errorf("could not save request log: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
type vacuumFunc func()
|
||||
|
||||
func DeleteRequestLogsTimebased(vacuum vacuumFunc, db *gorm.DB, requestThreshold, maxRetries int, retryDelay time.Duration) {
|
||||
func (r *repository) DeleteRequestLogsTimebased(vacuum vacuumFunc, requestThreshold, maxRetries int, retryDelay time.Duration) error {
|
||||
cutoffTime := time.Now().Add(-time.Duration(requestThreshold) * time.Second)
|
||||
|
||||
for retryCount := range maxRetries {
|
||||
result := db.Where("timestamp < ?", cutoffTime).Delete(&RequestLog{})
|
||||
result := r.db.Where("timestamp < ?", cutoffTime).Delete(&database.RequestLog{})
|
||||
if result.Error != nil {
|
||||
if result.Error.Error() == "database is locked" {
|
||||
log.Warning("Database is locked; retrying (%d/%d)", retryCount+1, maxRetries)
|
||||
time.Sleep(retryDelay)
|
||||
continue
|
||||
}
|
||||
log.Error("Failed to clear old entries: %s", result.Error)
|
||||
break
|
||||
return fmt.Errorf("failed to clear old entries: %w", result.Error)
|
||||
}
|
||||
|
||||
if affected := result.RowsAffected; affected > 0 {
|
||||
vacuum()
|
||||
vacuum(context.Background())
|
||||
log.Debug("Cleared %d old entries", affected)
|
||||
}
|
||||
break
|
||||
return nil // Success
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to delete after %d retries", maxRetries)
|
||||
}
|
||||
89
backend/request/service.go
Normal file
89
backend/request/service.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package request
|
||||
|
||||
import (
|
||||
"context"
|
||||
"goaway/backend/api/models"
|
||||
model "goaway/backend/dns/server/models"
|
||||
"goaway/backend/logging"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
repository Repository
|
||||
}
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
func NewService(repo Repository) *Service {
|
||||
return &Service{repository: repo}
|
||||
}
|
||||
|
||||
func (s *Service) GetClientNameFromIP(ip string) string {
|
||||
return s.repository.GetClientName(ip)
|
||||
}
|
||||
|
||||
func (s *Service) GetDistinctRequestIP() int {
|
||||
return s.repository.GetDistinctRequestIP()
|
||||
}
|
||||
|
||||
func (s *Service) GetRequestSummaryByInterval(interval int) ([]model.RequestLogIntervalSummary, error) {
|
||||
return s.repository.GetRequestSummaryByInterval(interval)
|
||||
}
|
||||
|
||||
func (s *Service) GetResponseSizeSummaryByInterval(intervalMinutes int) ([]model.ResponseSizeSummary, error) {
|
||||
return s.repository.GetResponseSizeSummaryByInterval(intervalMinutes)
|
||||
}
|
||||
|
||||
func (s *Service) GetUniqueQueryTypes() ([]interface{}, error) {
|
||||
return s.repository.GetUniqueQueryTypes()
|
||||
}
|
||||
|
||||
func (s *Service) FetchQueries(q models.QueryParams) ([]model.RequestLogEntry, error) {
|
||||
return s.repository.FetchQueries(q)
|
||||
}
|
||||
|
||||
func (s *Service) FetchAllClients() (map[string]Client, error) {
|
||||
return s.repository.FetchAllClients()
|
||||
}
|
||||
|
||||
func (s *Service) GetUniqueClientNameAndIP() []ClientNameAndIP {
|
||||
queryResult := s.repository.GetUniqueClientNameAndIP()
|
||||
var uniqueClients []ClientNameAndIP
|
||||
|
||||
for _, client := range queryResult {
|
||||
uniqueClients = append(uniqueClients, ClientNameAndIP{
|
||||
Name: client.ClientName,
|
||||
IP: client.ClientIP,
|
||||
})
|
||||
}
|
||||
|
||||
return uniqueClients
|
||||
}
|
||||
|
||||
func (s *Service) GetClientDetailsWithDomains(clientIP string) (ClientRequestDetails, string, map[string]int, error) {
|
||||
return s.repository.GetClientDetailsWithDomains(clientIP)
|
||||
}
|
||||
|
||||
func (s *Service) GetTopBlockedDomains(blockedRequests int) ([]map[string]interface{}, error) {
|
||||
return s.repository.GetTopBlockedDomains(blockedRequests)
|
||||
}
|
||||
|
||||
func (s *Service) GetTopClients() ([]map[string]interface{}, error) {
|
||||
return s.repository.GetTopClients()
|
||||
}
|
||||
|
||||
func (s *Service) CountQueries(search string) (int, error) {
|
||||
return s.repository.CountQueries(search)
|
||||
}
|
||||
|
||||
func (s *Service) SaveRequestLog(entries []model.RequestLogEntry) error {
|
||||
return s.repository.SaveRequestLog(entries)
|
||||
}
|
||||
|
||||
type vacuumFunc func(ctx context.Context)
|
||||
|
||||
func (s *Service) DeleteRequestLogsTimebased(vacuum vacuumFunc, requestThreshold, maxRetries int, retryDelay time.Duration) {
|
||||
if err := s.repository.DeleteRequestLogsTimebased(vacuum, requestThreshold, maxRetries, retryDelay); err != nil {
|
||||
log.Warning("Error while deleting old request logs: %v", err)
|
||||
}
|
||||
}
|
||||
77
backend/resolution/repository.go
Normal file
77
backend/resolution/repository.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package resolution
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"goaway/backend/database"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
CreateResolution(ip, domain string) error
|
||||
FindResolution(domain string) (string, error)
|
||||
FindResolutions() ([]database.Resolution, error)
|
||||
DeleteResolution(ip, domain string) (int, error)
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewRepository(db *gorm.DB) Repository {
|
||||
return &repository{db: db}
|
||||
}
|
||||
|
||||
func (r *repository) CreateResolution(ip, domain string) error {
|
||||
res := database.Resolution{
|
||||
Domain: domain,
|
||||
IP: ip,
|
||||
}
|
||||
|
||||
if err := r.db.Create(&res).Error; err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE") {
|
||||
return errors.New("domain already exists, must be unique")
|
||||
}
|
||||
return fmt.Errorf("could not create new resolution: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) FindResolution(domain string) (string, error) {
|
||||
var res database.Resolution
|
||||
|
||||
r.db.Where("domain = ?", domain).Find(&res)
|
||||
if res.IP != "" {
|
||||
return res.IP, nil
|
||||
}
|
||||
|
||||
parts := strings.Split(domain, ".")
|
||||
for i := 1; i < len(parts); i++ {
|
||||
wildcardDomain := "*." + strings.Join(parts[i:], ".")
|
||||
if err := r.db.Where("domain = ?", wildcardDomain).Find(&res).Error; err == nil {
|
||||
return res.IP, nil
|
||||
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (r *repository) FindResolutions() ([]database.Resolution, error) {
|
||||
var resolutions []database.Resolution
|
||||
if err := r.db.Find(&resolutions).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resolutions, nil
|
||||
}
|
||||
|
||||
func (r *repository) DeleteResolution(ip, domain string) (int, error) {
|
||||
result := r.db.Where("domain = ? AND ip = ?", domain, ip).Delete(&database.Resolution{})
|
||||
if result.Error != nil {
|
||||
return 0, result.Error
|
||||
}
|
||||
return int(result.RowsAffected), nil
|
||||
}
|
||||
34
backend/resolution/service.go
Normal file
34
backend/resolution/service.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package resolution
|
||||
|
||||
import (
|
||||
"goaway/backend/database"
|
||||
"goaway/backend/logging"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
repository Repository
|
||||
}
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
func NewService(repo Repository) *Service {
|
||||
return &Service{repository: repo}
|
||||
}
|
||||
|
||||
func (s *Service) CreateResolution(ip, domain string) error {
|
||||
log.Debug("Creating new resolution '%s' -> '%s'", domain, ip)
|
||||
return s.repository.CreateResolution(ip, domain)
|
||||
}
|
||||
|
||||
func (s *Service) GetResolution(domain string) (string, error) {
|
||||
log.Debug("Finding resolution for domain: %s", domain)
|
||||
return s.repository.FindResolution(domain)
|
||||
}
|
||||
|
||||
func (s *Service) GetResolutions() ([]database.Resolution, error) {
|
||||
return s.repository.FindResolutions()
|
||||
}
|
||||
|
||||
func (s *Service) DeleteResolution(ip, domain string) (int, error) {
|
||||
return s.repository.DeleteResolution(ip, domain)
|
||||
}
|
||||
51
backend/services/context.go
Normal file
51
backend/services/context.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"goaway/backend/database"
|
||||
"goaway/backend/dns/server"
|
||||
"goaway/backend/settings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type AppContext struct {
|
||||
Config *settings.Config
|
||||
DBConn *gorm.DB
|
||||
Certificate tls.Certificate
|
||||
DNSServer *server.DNSServer
|
||||
}
|
||||
|
||||
func NewAppContext(config *settings.Config) (*AppContext, error) {
|
||||
ctx := &AppContext{Config: config}
|
||||
if err := ctx.initialize(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (ctx *AppContext) initialize() error {
|
||||
ctx.DBConn = database.Initialize()
|
||||
|
||||
cert, err := ctx.Config.GetCertificate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get certificate: %w", err)
|
||||
}
|
||||
ctx.Certificate = cert
|
||||
|
||||
dnsServer, err := server.NewDNSServer(
|
||||
ctx.Config,
|
||||
ctx.DBConn,
|
||||
cert,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize DNS server: %w", err)
|
||||
}
|
||||
ctx.DNSServer = dnsServer
|
||||
|
||||
go dnsServer.ProcessLogEntries()
|
||||
|
||||
return nil
|
||||
}
|
||||
235
backend/services/registry.go
Normal file
235
backend/services/registry.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"goaway/backend/api"
|
||||
"goaway/backend/api/key"
|
||||
"goaway/backend/blacklist"
|
||||
"goaway/backend/logging"
|
||||
"goaway/backend/notification"
|
||||
"goaway/backend/prefetch"
|
||||
"goaway/backend/request"
|
||||
"goaway/backend/resolution"
|
||||
"goaway/backend/user"
|
||||
"goaway/backend/whitelist"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
// Manages all servers and services
|
||||
type ServiceRegistry struct {
|
||||
APIServer *api.API
|
||||
errorChan chan ServiceError
|
||||
dotServer *dns.Server
|
||||
readyChan chan struct{}
|
||||
content embed.FS
|
||||
udpServer *dns.Server
|
||||
dohServer *http.Server
|
||||
tcpServer *dns.Server
|
||||
Context *AppContext
|
||||
version string
|
||||
date string
|
||||
commit string
|
||||
wg sync.WaitGroup
|
||||
|
||||
ResolutionService *resolution.Service
|
||||
RequestService *request.Service
|
||||
PrefetchService *prefetch.Service
|
||||
UserService *user.Service
|
||||
KeyService *key.Service
|
||||
NotificationService *notification.Service
|
||||
BlacklistService *blacklist.Service
|
||||
WhitelistService *whitelist.Service
|
||||
}
|
||||
|
||||
type ServiceError struct {
|
||||
Err error
|
||||
Service string
|
||||
}
|
||||
|
||||
func NewServiceRegistry(ctx *AppContext, version, commit, date string, content embed.FS) *ServiceRegistry {
|
||||
return &ServiceRegistry{
|
||||
Context: ctx,
|
||||
version: version,
|
||||
commit: commit,
|
||||
date: date,
|
||||
content: content,
|
||||
readyChan: make(chan struct{}),
|
||||
errorChan: make(chan ServiceError, 10),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ServiceRegistry) Initialize() error {
|
||||
r.setupDNSServers()
|
||||
|
||||
if r.Context.Certificate.Certificate != nil {
|
||||
if err := r.setupSecureServers(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
r.setupAPIServer()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *ServiceRegistry) setupDNSServers() {
|
||||
config := r.Context.Config
|
||||
|
||||
notifyReady := func() {
|
||||
log.Info("Started DNS server on: %s:%d", config.DNS.Address, config.DNS.Ports.TCPUDP)
|
||||
close(r.readyChan)
|
||||
}
|
||||
|
||||
r.udpServer = &dns.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", config.DNS.Address, config.DNS.Ports.TCPUDP),
|
||||
Net: "udp",
|
||||
Handler: r.Context.DNSServer,
|
||||
ReusePort: true,
|
||||
UDPSize: config.DNS.UDPSize,
|
||||
}
|
||||
|
||||
r.tcpServer = &dns.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", config.DNS.Address, config.DNS.Ports.TCPUDP),
|
||||
Net: "tcp",
|
||||
Handler: r.Context.DNSServer,
|
||||
ReusePort: true,
|
||||
UDPSize: config.DNS.UDPSize,
|
||||
NotifyStartedFunc: notifyReady,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ServiceRegistry) setupSecureServers() error {
|
||||
dotServer, err := r.Context.DNSServer.InitDoT(r.Context.Certificate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize DoT server: %w", err)
|
||||
}
|
||||
r.dotServer = dotServer
|
||||
|
||||
dohServer, err := r.Context.DNSServer.InitDoH(r.Context.Certificate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize DoH server: %w", err)
|
||||
}
|
||||
r.dohServer = dohServer
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *ServiceRegistry) setupAPIServer() {
|
||||
r.APIServer = &api.API{
|
||||
DNS: r.Context.DNSServer,
|
||||
Authentication: r.Context.Config.API.Authentication,
|
||||
Config: r.Context.Config,
|
||||
DNSPort: r.Context.Config.DNS.Ports.TCPUDP,
|
||||
Version: r.version,
|
||||
Commit: r.commit,
|
||||
Date: r.date,
|
||||
DNSServer: r.Context.DNSServer,
|
||||
DBConn: r.Context.DBConn,
|
||||
WSQueries: r.Context.DNSServer.WSQueries,
|
||||
WSCommunication: r.Context.DNSServer.WSCommunication,
|
||||
|
||||
ResolutionService: r.ResolutionService,
|
||||
RequestService: r.RequestService,
|
||||
PrefetchService: r.PrefetchService,
|
||||
NotificationService: r.NotificationService,
|
||||
UserService: r.UserService,
|
||||
KeyService: r.KeyService,
|
||||
BlacklistService: r.BlacklistService,
|
||||
WhitelistService: r.WhitelistService,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ServiceRegistry) StartAll() {
|
||||
r.startDNSServers()
|
||||
|
||||
if r.Context.Certificate.Certificate != nil {
|
||||
r.startSecureServers()
|
||||
}
|
||||
|
||||
r.startAPIServer()
|
||||
}
|
||||
|
||||
func (r *ServiceRegistry) startDNSServers() {
|
||||
r.wg.Add(1)
|
||||
go func() {
|
||||
defer r.wg.Done()
|
||||
if err := r.udpServer.ListenAndServe(); err != nil {
|
||||
r.errorChan <- ServiceError{Service: "UDP", Err: err}
|
||||
}
|
||||
}()
|
||||
|
||||
r.wg.Add(1)
|
||||
go func() {
|
||||
defer r.wg.Done()
|
||||
if err := r.tcpServer.ListenAndServe(); err != nil {
|
||||
r.errorChan <- ServiceError{Service: "TCP", Err: err}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (r *ServiceRegistry) startSecureServers() {
|
||||
r.wg.Add(1)
|
||||
go func() {
|
||||
defer r.wg.Done()
|
||||
if err := r.dotServer.ListenAndServe(); err != nil {
|
||||
r.errorChan <- ServiceError{Service: "DoT", Err: err}
|
||||
}
|
||||
}()
|
||||
|
||||
r.wg.Add(1)
|
||||
go func() {
|
||||
defer r.wg.Done()
|
||||
|
||||
if serverIP, err := api.GetServerIP(); err == nil {
|
||||
log.Info("DoH (dns-over-https) server running at https://%s:%d/dns-query",
|
||||
serverIP, r.Context.Config.DNS.Ports.DoH)
|
||||
} else {
|
||||
log.Info("DoH (dns-over-https) server running on port :%d", r.Context.Config.DNS.Ports.DoH)
|
||||
}
|
||||
|
||||
if err := r.dohServer.ListenAndServeTLS(
|
||||
r.Context.Config.DNS.TLS.Cert,
|
||||
r.Context.Config.DNS.TLS.Key,
|
||||
); err != nil {
|
||||
r.errorChan <- ServiceError{Service: "DoH", Err: err}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (r *ServiceRegistry) startAPIServer() {
|
||||
r.wg.Add(1)
|
||||
go func() {
|
||||
defer r.wg.Done()
|
||||
<-r.readyChan
|
||||
|
||||
errorChan := make(chan struct{}, 1)
|
||||
go func() {
|
||||
<-errorChan
|
||||
r.errorChan <- ServiceError{Service: "API", Err: fmt.Errorf("API server stopped")}
|
||||
}()
|
||||
|
||||
r.APIServer.Start(r.content, errorChan)
|
||||
}()
|
||||
}
|
||||
|
||||
func (r *ServiceRegistry) WaitGroup() *sync.WaitGroup {
|
||||
return &r.wg
|
||||
}
|
||||
|
||||
func (r *ServiceRegistry) ReadyChannel() <-chan struct{} {
|
||||
return r.readyChan
|
||||
}
|
||||
|
||||
func (r *ServiceRegistry) ErrorChannel() <-chan ServiceError {
|
||||
return r.errorChan
|
||||
}
|
||||
|
||||
func (r *ServiceRegistry) GetPrefetcher() *prefetch.Service {
|
||||
return r.APIServer.PrefetchService
|
||||
}
|
||||
69
backend/settings/model.go
Normal file
69
backend/settings/model.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package settings
|
||||
|
||||
import "time"
|
||||
|
||||
type Status struct {
|
||||
PausedAt time.Time `json:"pausedAt"`
|
||||
PauseTime time.Time `json:"pauseTime"`
|
||||
Paused bool `json:"paused"`
|
||||
}
|
||||
|
||||
type TLSConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
Cert string `yaml:"cert" json:"cert"`
|
||||
Key string `yaml:"key" json:"key"`
|
||||
}
|
||||
|
||||
type UpstreamConfig struct {
|
||||
Preferred string `yaml:"preferred" json:"preferred"`
|
||||
Fallback []string `yaml:"fallback" json:"fallback"`
|
||||
}
|
||||
|
||||
type PortsConfig struct {
|
||||
TCPUDP int `yaml:"udptcp" json:"udptcp"`
|
||||
DoT int `yaml:"dot" json:"dot"`
|
||||
DoH int `yaml:"doh" json:"doh"`
|
||||
}
|
||||
|
||||
type DNSConfig struct {
|
||||
Status Status `yaml:"-" json:"status"`
|
||||
Address string `yaml:"address" json:"address"`
|
||||
Gateway string `yaml:"gateway" json:"gateway"`
|
||||
CacheTTL int `yaml:"cacheTTL" json:"cacheTTL"`
|
||||
UDPSize int `yaml:"udpSize" json:"udpSize"`
|
||||
TLS TLSConfig `yaml:"tls" json:"tls"`
|
||||
Upstream UpstreamConfig `yaml:"upstream" json:"upstream"`
|
||||
Ports PortsConfig `yaml:"ports" json:"ports"`
|
||||
}
|
||||
|
||||
type RateLimitConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
MaxTries int `yaml:"maxTries" json:"maxTries"`
|
||||
Window int `yaml:"window" json:"window"`
|
||||
}
|
||||
|
||||
type APIConfig struct {
|
||||
Port int `yaml:"port" json:"port"`
|
||||
Authentication bool `yaml:"authentication" json:"authentication"`
|
||||
RateLimit RateLimitConfig `yaml:"rateLimit" json:"rateLimit"`
|
||||
}
|
||||
|
||||
type LoggingConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
Level int `yaml:"level" json:"level"`
|
||||
}
|
||||
|
||||
type MiscConfig struct {
|
||||
InAppUpdate bool `yaml:"inAppUpdate" json:"inAppUpdate"`
|
||||
StatisticsRetention int `yaml:"statisticsRetention" json:"statisticsRetention"`
|
||||
Dashboard bool `yaml:"dashboard" json:"dashboard"`
|
||||
ScheduledBlacklistUpdates bool `yaml:"scheduledBlacklistUpdates" json:"scheduledBlacklistUpdates"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
BinaryPath string `yaml:"-" json:"-"`
|
||||
DNS DNSConfig `yaml:"dns" json:"dns"`
|
||||
API APIConfig `yaml:"api" json:"api"`
|
||||
Logging LoggingConfig `yaml:"logging" json:"logging"`
|
||||
Misc MiscConfig `yaml:"misc" json:"misc"`
|
||||
}
|
||||
@@ -3,62 +3,17 @@ package settings
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"goaway/backend/api/ratelimit"
|
||||
"goaway/backend/logging"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
type Status struct {
|
||||
Paused bool `json:"paused"`
|
||||
PausedAt time.Time `json:"pausedAt"`
|
||||
PauseTime int `json:"pauseTime"`
|
||||
}
|
||||
|
||||
type DNSConfig struct {
|
||||
Address string `yaml:"address" json:"address"`
|
||||
Port int `yaml:"port" json:"port"`
|
||||
DoTPort int `yaml:"dotPort" json:"dotPort"`
|
||||
DoHPort int `yaml:"dohPort" json:"dohPort"`
|
||||
CacheTTL int `yaml:"cacheTTL" json:"cacheTTL"`
|
||||
PreferredUpstream string `yaml:"preferredUpstream" json:"preferredUpstream"`
|
||||
Gateway string `yaml:"gateway" json:"gateway"`
|
||||
UpstreamDNS []string `yaml:"upstreamDNS" json:"upstreamDNS"`
|
||||
UDPSize int `yaml:"udpSize" json:"udpSize"`
|
||||
Status Status `yaml:"-" json:"status"`
|
||||
|
||||
TLSCertFile string `yaml:"tlsCertFile" json:"tlsCertFile"`
|
||||
TLSKeyFile string `yaml:"tlsKeyFile" json:"tlsKeyFile"`
|
||||
}
|
||||
|
||||
type APIConfig struct {
|
||||
Port int `yaml:"port" json:"port"`
|
||||
Authentication bool `yaml:"authentication" json:"authentication"`
|
||||
RateLimiterConfig ratelimit.RateLimiterConfig `yaml:"rateLimit" json:"-"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
DNS DNSConfig `yaml:"dns" json:"dns"`
|
||||
API APIConfig `yaml:"api" json:"api"`
|
||||
|
||||
Dashboard bool `yaml:"dashboard" json:"-"`
|
||||
ScheduledBlacklistUpdates bool `yaml:"scheduledBlacklistUpdates" json:"scheduledBlacklistUpdates"`
|
||||
StatisticsRetention int `yaml:"statisticsRetention" json:"statisticsRetention"`
|
||||
LoggingEnabled bool `yaml:"loggingEnabled" json:"loggingEnabled"`
|
||||
LogLevel logging.LogLevel `yaml:"logLevel" json:"logLevel"`
|
||||
InAppUpdate bool `yaml:"inAppUpdate" json:"inAppUpdate"`
|
||||
|
||||
// settings not visible in config file
|
||||
BinaryPath string `yaml:"-" json:"-"`
|
||||
}
|
||||
|
||||
func LoadSettings() (Config, error) {
|
||||
var config Config
|
||||
|
||||
@@ -74,6 +29,7 @@ func LoadSettings() (Config, error) {
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
@@ -81,43 +37,10 @@ func LoadSettings() (Config, error) {
|
||||
return Config{}, fmt.Errorf("could not read settings file: %w", err)
|
||||
}
|
||||
|
||||
type configWithPtr struct {
|
||||
DNS DNSConfig `yaml:"dns" json:"dns"`
|
||||
API APIConfig `yaml:"api" json:"api"`
|
||||
Dashboard *bool `yaml:"dashboard" json:"-"`
|
||||
ScheduledBlacklistUpdates *bool `yaml:"scheduledBlacklistUpdates" json:"scheduledBlacklistUpdates"`
|
||||
StatisticsRetention int `yaml:"statisticsRetention" json:"statisticsRetention"`
|
||||
LoggingEnabled bool `yaml:"loggingEnabled" json:"loggingEnabled"`
|
||||
LogLevel logging.LogLevel `yaml:"logLevel" json:"logLevel"`
|
||||
InAppUpdate bool `yaml:"inAppUpdate" json:"inAppUpdate"`
|
||||
}
|
||||
|
||||
var temp configWithPtr
|
||||
if err := yaml.Unmarshal(data, &temp); err != nil {
|
||||
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||
return Config{}, fmt.Errorf("invalid settings format: %w", err)
|
||||
}
|
||||
|
||||
config.DNS = temp.DNS
|
||||
config.API = temp.API
|
||||
config.StatisticsRetention = temp.StatisticsRetention
|
||||
config.LoggingEnabled = temp.LoggingEnabled
|
||||
config.LogLevel = temp.LogLevel
|
||||
config.InAppUpdate = temp.InAppUpdate
|
||||
|
||||
if temp.Dashboard == nil {
|
||||
// true by default if the Dashboard field was not found in settings.yaml
|
||||
config.Dashboard = true
|
||||
} else {
|
||||
config.Dashboard = *temp.Dashboard
|
||||
}
|
||||
|
||||
if temp.ScheduledBlacklistUpdates == nil {
|
||||
// false by default if the ScheduledBlacklistUpdates field was not found in settings.yaml
|
||||
config.ScheduledBlacklistUpdates = false
|
||||
} else {
|
||||
config.ScheduledBlacklistUpdates = *temp.ScheduledBlacklistUpdates
|
||||
}
|
||||
|
||||
binaryPath, err := os.Executable()
|
||||
if err != nil {
|
||||
log.Warning("Unable to find installed binary path, err: %v", err)
|
||||
@@ -127,35 +50,84 @@ func LoadSettings() (Config, error) {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (config *Config) Save() {
|
||||
data, err := yaml.Marshal(config)
|
||||
if err != nil {
|
||||
log.Error("Could not parse settings %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.WriteFile("./config/settings.yaml", data, 0644); err != nil {
|
||||
log.Error("Could not save settings %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (config *Config) Update(updatedSettings Config) {
|
||||
config.API.Port = updatedSettings.API.Port
|
||||
config.API.Authentication = updatedSettings.API.Authentication
|
||||
config.API.RateLimit = updatedSettings.API.RateLimit
|
||||
|
||||
config.DNS.Address = updatedSettings.DNS.Address
|
||||
config.DNS.Ports = updatedSettings.DNS.Ports
|
||||
config.DNS.UDPSize = updatedSettings.DNS.UDPSize
|
||||
config.DNS.CacheTTL = updatedSettings.DNS.CacheTTL
|
||||
config.DNS.TLS = updatedSettings.DNS.TLS
|
||||
config.DNS.Upstream = updatedSettings.DNS.Upstream
|
||||
|
||||
config.Logging = updatedSettings.Logging
|
||||
config.Misc = updatedSettings.Misc
|
||||
|
||||
log.ToggleLogging(config.Logging.Enabled)
|
||||
log.SetLevel(logging.LogLevel(config.Logging.Level))
|
||||
|
||||
config.Save()
|
||||
}
|
||||
|
||||
func createDefaultSettings(filePath string) (Config, error) {
|
||||
defaultConfig := Config{
|
||||
StatisticsRetention: 7,
|
||||
LoggingEnabled: true,
|
||||
LogLevel: logging.INFO,
|
||||
InAppUpdate: false,
|
||||
DNS: DNSConfig{
|
||||
Address: "0.0.0.0",
|
||||
Gateway: getDefaultGateway(),
|
||||
CacheTTL: 3600,
|
||||
UDPSize: 512,
|
||||
TLS: TLSConfig{
|
||||
Enabled: false,
|
||||
Cert: "",
|
||||
Key: "",
|
||||
},
|
||||
Upstream: UpstreamConfig{
|
||||
Preferred: "8.8.8.8:53",
|
||||
Fallback: []string{
|
||||
"1.1.1.1:53",
|
||||
},
|
||||
},
|
||||
Ports: PortsConfig{
|
||||
TCPUDP: getEnvAsIntWithDefault("DNS_PORT", 53),
|
||||
DoT: getEnvAsIntWithDefault("DOT_PORT", 853),
|
||||
DoH: getEnvAsIntWithDefault("DOH_PORT", 443),
|
||||
},
|
||||
},
|
||||
API: APIConfig{
|
||||
Port: getEnvAsIntWithDefault("WEBSITE_PORT", 8080),
|
||||
Authentication: true,
|
||||
RateLimit: RateLimitConfig{
|
||||
Enabled: true,
|
||||
MaxTries: 5,
|
||||
Window: 5,
|
||||
},
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Enabled: true,
|
||||
Level: int(logging.INFO),
|
||||
},
|
||||
Misc: MiscConfig{
|
||||
InAppUpdate: false,
|
||||
StatisticsRetention: 7,
|
||||
Dashboard: true,
|
||||
ScheduledBlacklistUpdates: true,
|
||||
},
|
||||
}
|
||||
|
||||
defaultConfig.DNS.Address = "0.0.0.0"
|
||||
defaultConfig.DNS.Port = GetEnvAsIntWithDefault("DNS_PORT", 53)
|
||||
defaultConfig.DNS.DoTPort = GetEnvAsIntWithDefault("DOT_PORT", 853)
|
||||
defaultConfig.DNS.DoHPort = GetEnvAsIntWithDefault("DOH_PORT", 443)
|
||||
defaultConfig.DNS.CacheTTL = 3600
|
||||
defaultConfig.DNS.PreferredUpstream = "8.8.8.8:53"
|
||||
defaultConfig.DNS.Gateway = GetDefaultGateway()
|
||||
defaultConfig.DNS.UpstreamDNS = []string{
|
||||
"1.1.1.1:53",
|
||||
"8.8.8.8:53",
|
||||
}
|
||||
defaultConfig.DNS.UDPSize = 512
|
||||
defaultConfig.DNS.TLSCertFile = ""
|
||||
defaultConfig.DNS.TLSKeyFile = ""
|
||||
|
||||
defaultConfig.Dashboard = true
|
||||
defaultConfig.ScheduledBlacklistUpdates = true
|
||||
defaultConfig.API.Port = GetEnvAsIntWithDefault("WEBSITE_PORT", 8080)
|
||||
defaultConfig.API.Authentication = true
|
||||
defaultConfig.API.RateLimiterConfig = ratelimit.RateLimiterConfig{Enabled: true, MaxTries: 5, Window: 5}
|
||||
|
||||
data, err := yaml.Marshal(&defaultConfig)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("failed to marshal default config: %w", err)
|
||||
@@ -174,45 +146,7 @@ func createDefaultSettings(filePath string) (Config, error) {
|
||||
return defaultConfig, nil
|
||||
}
|
||||
|
||||
func (config *Config) Save() {
|
||||
data, err := yaml.Marshal(config)
|
||||
if err != nil {
|
||||
log.Error("Could not parse settings %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.WriteFile("./config/settings.yaml", data, 0644); err != nil {
|
||||
log.Error("Could not save settings %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (config *Config) UpdateSettings(updatedSettings Config) {
|
||||
config.API.Port = updatedSettings.API.Port
|
||||
config.API.Authentication = updatedSettings.API.Authentication
|
||||
|
||||
config.DNS.Address = updatedSettings.DNS.Address
|
||||
config.DNS.Port = updatedSettings.DNS.Port
|
||||
config.DNS.DoTPort = updatedSettings.DNS.DoTPort
|
||||
config.DNS.DoHPort = updatedSettings.DNS.DoHPort
|
||||
config.DNS.UDPSize = updatedSettings.DNS.UDPSize
|
||||
config.DNS.CacheTTL = updatedSettings.DNS.CacheTTL
|
||||
config.DNS.TLSCertFile = updatedSettings.DNS.TLSCertFile
|
||||
config.DNS.TLSKeyFile = updatedSettings.DNS.TLSKeyFile
|
||||
|
||||
config.LogLevel = updatedSettings.LogLevel
|
||||
config.StatisticsRetention = updatedSettings.StatisticsRetention
|
||||
config.LoggingEnabled = updatedSettings.LoggingEnabled
|
||||
|
||||
config.ScheduledBlacklistUpdates = updatedSettings.ScheduledBlacklistUpdates
|
||||
config.InAppUpdate = updatedSettings.InAppUpdate
|
||||
|
||||
log.ToggleLogging(config.LoggingEnabled)
|
||||
log.SetLevel(config.LogLevel)
|
||||
|
||||
config.Save()
|
||||
}
|
||||
|
||||
func GetEnvAsIntWithDefault(envVariable string, defaultValue int) int {
|
||||
func getEnvAsIntWithDefault(envVariable string, defaultValue int) int {
|
||||
val, found := os.LookupEnv(envVariable)
|
||||
if !found {
|
||||
return defaultValue
|
||||
@@ -227,10 +161,10 @@ func GetEnvAsIntWithDefault(envVariable string, defaultValue int) int {
|
||||
}
|
||||
|
||||
func (config *Config) GetCertificate() (tls.Certificate, error) {
|
||||
if config.DNS.TLSCertFile != "" && config.DNS.TLSKeyFile != "" {
|
||||
cert, err := tls.LoadX509KeyPair(config.DNS.TLSCertFile, config.DNS.TLSKeyFile)
|
||||
if config.DNS.TLS.Enabled && config.DNS.TLS.Cert != "" && config.DNS.TLS.Key != "" {
|
||||
cert, err := tls.LoadX509KeyPair(config.DNS.TLS.Cert, config.DNS.TLS.Key)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("failed to load TLS certificate: %s", err)
|
||||
return tls.Certificate{}, fmt.Errorf("failed to load TLS certificate: %w", err)
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
@@ -239,7 +173,7 @@ func (config *Config) GetCertificate() (tls.Certificate, error) {
|
||||
return tls.Certificate{}, nil
|
||||
}
|
||||
|
||||
func GetDefaultGateway() string {
|
||||
func getDefaultGateway() string {
|
||||
conn, err := net.Dial("udp", "8.8.8.8:80")
|
||||
if err != nil {
|
||||
return "192.168.0.1:53"
|
||||
|
||||
@@ -2,18 +2,19 @@ package setup
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"goaway/backend/logging"
|
||||
"goaway/backend/settings"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"goaway/backend/logging"
|
||||
"goaway/backend/settings"
|
||||
|
||||
"github.com/Masterminds/semver"
|
||||
)
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
type SetFlags struct {
|
||||
DnsPort *int
|
||||
DNSPort *int
|
||||
DoTPort *int
|
||||
DoHPort *int
|
||||
WebserverPort *int
|
||||
@@ -37,15 +38,15 @@ func UpdateConfig(config *settings.Config, flags *SetFlags) {
|
||||
fmt.Println("Flag --log-level can't be greater than 3 or below 0.")
|
||||
os.Exit(1)
|
||||
}
|
||||
if flags.DnsPort != nil || os.Getenv("DNS_PORT") != "" {
|
||||
if flags.DNSPort != nil || os.Getenv("DNS_PORT") != "" {
|
||||
if port, found := os.LookupEnv("DNS_PORT"); found {
|
||||
dnsPort, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
log.Fatal("Could not parse DNS_PORT environment variable")
|
||||
}
|
||||
config.DNS.Port = dnsPort
|
||||
config.DNS.Ports.TCPUDP = dnsPort
|
||||
} else {
|
||||
config.DNS.Port = *flags.DnsPort
|
||||
config.DNS.Ports.TCPUDP = *flags.DNSPort
|
||||
}
|
||||
}
|
||||
if flags.DoTPort != nil || os.Getenv("DOT_PORT") != "" {
|
||||
@@ -54,9 +55,9 @@ func UpdateConfig(config *settings.Config, flags *SetFlags) {
|
||||
if err != nil {
|
||||
log.Fatal("Could not parse DOT_PORT environment variable")
|
||||
}
|
||||
config.DNS.DoTPort = dotPort
|
||||
config.DNS.Ports.DoT = dotPort
|
||||
} else {
|
||||
config.DNS.DoTPort = *flags.DoTPort
|
||||
config.DNS.Ports.DoT = *flags.DoTPort
|
||||
}
|
||||
}
|
||||
if flags.DoHPort != nil || os.Getenv("DOH_PORT") != "" {
|
||||
@@ -65,9 +66,9 @@ func UpdateConfig(config *settings.Config, flags *SetFlags) {
|
||||
if err != nil {
|
||||
log.Fatal("Could not parse DOH_PORT environment variable")
|
||||
}
|
||||
config.DNS.DoHPort = dohPort
|
||||
config.DNS.Ports.DoH = dohPort
|
||||
} else {
|
||||
config.DNS.DoHPort = *flags.DoHPort
|
||||
config.DNS.Ports.DoH = *flags.DoHPort
|
||||
}
|
||||
}
|
||||
if flags.WebserverPort != nil || os.Getenv("WEBSITE_PORT") != "" {
|
||||
@@ -82,27 +83,27 @@ func UpdateConfig(config *settings.Config, flags *SetFlags) {
|
||||
}
|
||||
}
|
||||
if flags.StatisticsRetention != nil {
|
||||
config.StatisticsRetention = *flags.StatisticsRetention
|
||||
config.Misc.StatisticsRetention = *flags.StatisticsRetention
|
||||
}
|
||||
if flags.Authentication != nil {
|
||||
config.API.Authentication = *flags.Authentication
|
||||
}
|
||||
if flags.Dashboard != nil {
|
||||
config.Dashboard = *flags.Dashboard
|
||||
config.Misc.Dashboard = *flags.Dashboard
|
||||
}
|
||||
if flags.LoggingEnabled != nil {
|
||||
config.LoggingEnabled = *flags.LoggingEnabled
|
||||
config.Logging.Enabled = *flags.LoggingEnabled
|
||||
}
|
||||
if flags.LogLevel != nil {
|
||||
config.LogLevel = logging.LogLevel(*flags.LogLevel)
|
||||
config.Logging.Level = *flags.LogLevel
|
||||
}
|
||||
if flags.InAppUpdate != nil {
|
||||
config.InAppUpdate = *flags.InAppUpdate
|
||||
config.Misc.InAppUpdate = *flags.InAppUpdate
|
||||
}
|
||||
|
||||
if flags.JSON != nil {
|
||||
log.JSON = *flags.JSON
|
||||
log.SetJson(log.JSON)
|
||||
log.SetJSON(log.JSON)
|
||||
} else {
|
||||
log.Ansi = flags.Ansi == nil || *flags.Ansi
|
||||
log.SetAnsi(log.Ansi)
|
||||
|
||||
@@ -18,15 +18,15 @@ func SelfUpdate(sse sendSSE, binaryPath string) error {
|
||||
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stdout pipe: %v", err)
|
||||
return fmt.Errorf("failed to create stdout pipe: %w", err)
|
||||
}
|
||||
stderrPipe, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stderr pipe: %v", err)
|
||||
return fmt.Errorf("failed to create stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start command: %v", err)
|
||||
return fmt.Errorf("failed to start command: %w", err)
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
@@ -55,7 +55,7 @@ func SelfUpdate(sse sendSSE, binaryPath string) error {
|
||||
return nil
|
||||
case err := <-waitCmd(cmd):
|
||||
if err != nil {
|
||||
return fmt.Errorf("update failed: %v", err)
|
||||
return fmt.Errorf("update failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
6
backend/user/models.go
Normal file
6
backend/user/models.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package user
|
||||
|
||||
type User struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
66
backend/user/repository.go
Normal file
66
backend/user/repository.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"goaway/backend/database"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
Create(user *database.User) error
|
||||
FindByUsername(username string) (*User, error)
|
||||
UpdatePassword(username string, hashedPassword string) error
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewRepository(db *gorm.DB) Repository {
|
||||
return &repository{db: db}
|
||||
}
|
||||
|
||||
func (r *repository) Create(user *database.User) error {
|
||||
result := r.db.Create(user)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return errors.New("user creation failed: no rows affected")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) FindByUsername(username string) (*User, error) {
|
||||
var user database.User
|
||||
err := r.db.Where("username = ?", username).Find(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.Username == "" {
|
||||
return nil, errors.New("user not found")
|
||||
}
|
||||
|
||||
return &User{
|
||||
Username: user.Username,
|
||||
Password: user.Password,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *repository) UpdatePassword(username, hashedPassword string) error {
|
||||
affected, err := gorm.G[database.User](r.db).Where("username = ?", username).Update(context.Background(), "password", hashedPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if affected == 0 {
|
||||
return errors.New("password update failed: no rows affected")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
114
backend/user/service.go
Normal file
114
backend/user/service.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"goaway/backend/database"
|
||||
"goaway/backend/logging"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
repository Repository
|
||||
}
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
func NewService(repo Repository) *Service {
|
||||
return &Service{repository: repo}
|
||||
}
|
||||
|
||||
func (s *Service) CreateUser(username, password string) error {
|
||||
log.Info("Creating a new user with name '%s'", username)
|
||||
|
||||
hashedPassword, err := hashPassword(password)
|
||||
if err != nil {
|
||||
log.Error("Failed to hash password: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
newUser := &database.User{Username: username, Password: hashedPassword}
|
||||
|
||||
if err := s.repository.Create(newUser); err != nil {
|
||||
log.Error("Failed to create user: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("User created successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) Exists(username string) bool {
|
||||
user, err := s.repository.FindByUsername(username)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if user != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Service) Authenticate(username, password string) bool {
|
||||
user, err := s.repository.FindByUsername(username)
|
||||
if err != nil {
|
||||
log.Error("Authentication failed for user '%s': %v", username, err)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
|
||||
log.Debug("Invalid password for user '%s'", username)
|
||||
return false
|
||||
}
|
||||
|
||||
log.Debug("User '%s' authenticated successfully", username)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Service) UpdatePassword(username, newPassword string) error {
|
||||
hashedPassword, err := hashPassword(newPassword)
|
||||
if err != nil {
|
||||
log.Error("Failed to hash new password: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.repository.UpdatePassword(username, hashedPassword); err != nil {
|
||||
log.Error("Failed to update password: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("Password updated successfully for user '%s'", username)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) ValidateCredentials(user User) error {
|
||||
user.Username = strings.TrimSpace(user.Username)
|
||||
user.Password = strings.TrimSpace(user.Password)
|
||||
|
||||
if user.Username == "" || user.Password == "" {
|
||||
return errors.New("username and password cannot be empty")
|
||||
}
|
||||
|
||||
if len(user.Username) > 60 {
|
||||
return errors.New("username too long")
|
||||
}
|
||||
if len(user.Password) > 120 {
|
||||
return errors.New("password too long")
|
||||
}
|
||||
|
||||
for _, r := range user.Username {
|
||||
if r < 32 || r == 127 {
|
||||
return errors.New("username contains invalid characters")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func hashPassword(password string) (string, error) {
|
||||
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
return string(hashed), err
|
||||
}
|
||||
61
backend/whitelist/repository.go
Normal file
61
backend/whitelist/repository.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package whitelist
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"goaway/backend/database"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
AddDomain(domain string) error
|
||||
GetDomains() (map[string]bool, error)
|
||||
RemoveDomain(domain string) error
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewRepository(db *gorm.DB) Repository {
|
||||
return &repository{db: db}
|
||||
}
|
||||
|
||||
func (r *repository) AddDomain(domain string) error {
|
||||
result := r.db.Create(&database.Whitelist{Domain: domain})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("%s is already whitelisted", domain)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) GetDomains() (map[string]bool, error) {
|
||||
var records []database.Whitelist
|
||||
result := r.db.Find(&records)
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
domainMap := make(map[string]bool)
|
||||
for _, record := range records {
|
||||
domainMap[record.Domain] = true
|
||||
}
|
||||
|
||||
return domainMap, nil
|
||||
}
|
||||
|
||||
func (r *repository) RemoveDomain(domain string) error {
|
||||
result := r.db.Delete(&database.Whitelist{}, "domain = ?", domain)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("%s does not exist", domain)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
59
backend/whitelist/service.go
Normal file
59
backend/whitelist/service.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package whitelist
|
||||
|
||||
import "goaway/backend/logging"
|
||||
|
||||
type Service struct {
|
||||
repository Repository
|
||||
Cache map[string]bool
|
||||
}
|
||||
|
||||
var log = logging.GetLogger()
|
||||
|
||||
func NewService(repo Repository) *Service {
|
||||
service := &Service{
|
||||
repository: repo,
|
||||
Cache: map[string]bool{},
|
||||
}
|
||||
|
||||
_, err := service.GetDomains() // Preload cache
|
||||
if err != nil {
|
||||
log.Warning("Could not preload domains cache, %v", err)
|
||||
}
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
func (s *Service) AddDomain(domain string) error {
|
||||
err := s.repository.AddDomain(domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.Cache[domain] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) GetDomains() (map[string]bool, error) {
|
||||
domains, err := s.repository.GetDomains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.Cache = domains
|
||||
return domains, nil
|
||||
}
|
||||
|
||||
func (s *Service) RemoveDomain(domain string) error {
|
||||
err := s.repository.RemoveDomain(domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(s.Cache, domain)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) IsWhitelisted(domain string) bool {
|
||||
_, exists := s.Cache[domain]
|
||||
return exists
|
||||
}
|
||||
@@ -17,5 +17,8 @@
|
||||
"lib": "@/lib",
|
||||
"hooks": "@/hooks"
|
||||
},
|
||||
"iconLibrary": "phosphor-icons/react"
|
||||
"iconLibrary": "phosphor-icons/react",
|
||||
"registries": {
|
||||
"@magicui": "https://magicui.design/r/{name}.json"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,22 +12,8 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"@phosphor-icons/react": "^2.1.10",
|
||||
"@radix-ui/react-checkbox": "^1.3.3",
|
||||
"@radix-ui/react-collapsible": "^1.1.12",
|
||||
"@radix-ui/react-dialog": "^1.1.15",
|
||||
"@radix-ui/react-dropdown-menu": "^2.1.16",
|
||||
"@radix-ui/react-label": "^2.1.7",
|
||||
"@radix-ui/react-popover": "^1.1.15",
|
||||
"@radix-ui/react-scroll-area": "^1.2.10",
|
||||
"@radix-ui/react-select": "^2.2.6",
|
||||
"@radix-ui/react-separator": "^1.1.7",
|
||||
"@radix-ui/react-slider": "^1.3.6",
|
||||
"@radix-ui/react-slot": "^1.2.3",
|
||||
"@radix-ui/react-switch": "^1.2.6",
|
||||
"@radix-ui/react-tabs": "^1.1.13",
|
||||
"@radix-ui/react-toggle": "^1.1.10",
|
||||
"@radix-ui/react-toggle-group": "^1.1.11",
|
||||
"@radix-ui/react-tooltip": "^1.2.8",
|
||||
"@tanstack/react-table": "^8.21.3",
|
||||
"@tsparticles/engine": "^3.9.1",
|
||||
"@tsparticles/preset-stars": "^3.2.0",
|
||||
@@ -37,35 +23,41 @@
|
||||
"clsx": "^2.1.1",
|
||||
"cmdk": "1.1.1",
|
||||
"compare-versions": "^6.1.1",
|
||||
"motion": "^12.23.12",
|
||||
"lucide-react": "^0.546.0",
|
||||
"motion": "^12.23.24",
|
||||
"next-themes": "^0.4.6",
|
||||
"react": "^19.1.1",
|
||||
"react-dom": "^19.1.1",
|
||||
"react-force-graph-2d": "^1.28.0",
|
||||
"react-router-dom": "^7.8.2",
|
||||
"radix-ui": "^1.4.3",
|
||||
"react": "^19.2.0",
|
||||
"react-dom": "^19.2.0",
|
||||
"react-force-graph-2d": "^1.29.0",
|
||||
"react-router-dom": "^7.9.4",
|
||||
"react-timeago": "^8.3.0",
|
||||
"recharts": "^3.1.2",
|
||||
"recharts": "^3.3.0",
|
||||
"sonner": "^2.0.7",
|
||||
"tailwind-merge": "^3.3.1",
|
||||
"tw-animate-css": "^1.3.8"
|
||||
"tw-animate-css": "^1.4.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@commitlint/cli": "^19.8.1",
|
||||
"@commitlint/config-conventional": "^19.8.1",
|
||||
"babel-plugin-react-compiler": "^1.0.0",
|
||||
"@commitlint/cli": "^20.1.0",
|
||||
"@commitlint/config-conventional": "^20.0.0",
|
||||
"@commitlint/types": "^20.0.0",
|
||||
"@eslint/js": "^9.34.0",
|
||||
"@tailwindcss/vite": "^4.1.12",
|
||||
"@types/react": "^19.1.12",
|
||||
"@types/react-dom": "^19.1.9",
|
||||
"@vitejs/plugin-react": "^5.0.2",
|
||||
"eslint": "^9.34.0",
|
||||
"eslint-plugin-react-hooks": "^5.2.0",
|
||||
"eslint-plugin-react-refresh": "^0.4.20",
|
||||
"globals": "^16.3.0",
|
||||
"tailwindcss": "^4.1.12",
|
||||
"typescript": "~5.9.2",
|
||||
"typescript-eslint": "^8.42.0",
|
||||
"vite": "^7.1.4"
|
||||
"@eslint/js": "^9.38.0",
|
||||
"@radix-ui/react-dialog": "^1.1.15",
|
||||
"@radix-ui/react-select": "^2.2.6",
|
||||
"@tailwindcss/vite": "^4.1.15",
|
||||
"@types/canvas-confetti": "^1.9.0",
|
||||
"@types/react": "^19.2.2",
|
||||
"@types/react-dom": "^19.2.2",
|
||||
"@vitejs/plugin-react": "^5.0.4",
|
||||
"eslint": "^9.38.0",
|
||||
"eslint-plugin-react-hooks": "^7.0.0",
|
||||
"eslint-plugin-react-refresh": "^0.4.24",
|
||||
"globals": "^16.4.0",
|
||||
"tailwindcss": "^4.1.15",
|
||||
"typescript": "~5.9.3",
|
||||
"typescript-eslint": "^8.46.2",
|
||||
"vite": "^7.1.11"
|
||||
},
|
||||
"pnpm": {
|
||||
"onlyBuiltDependencies": [
|
||||
|
||||
2975
client/pnpm-lock.yaml
generated
2975
client/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
Binary file not shown.
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 7.5 KiB |
BIN
client/public/gray-icon.png
Normal file
BIN
client/public/gray-icon.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 3.3 MiB |
Binary file not shown.
|
Before Width: | Height: | Size: 27 KiB After Width: | Height: | Size: 3.5 MiB |
@@ -1,18 +1,18 @@
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Dialog, DialogContent, DialogTitle } from "@/components/ui/dialog";
|
||||
import { Tabs, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||
import { ClientEntry } from "@/pages/clients";
|
||||
import { GetRequest } from "@/util";
|
||||
import {
|
||||
BirdIcon,
|
||||
CaretDownIcon,
|
||||
CaretRightIcon,
|
||||
EyeIcon,
|
||||
ClockCounterClockwiseIcon,
|
||||
EyeglassesIcon,
|
||||
IdentificationBadgeIcon,
|
||||
LightningIcon,
|
||||
LineSegmentsIcon,
|
||||
PlusMinusIcon,
|
||||
RowsIcon,
|
||||
ShieldIcon,
|
||||
SparkleIcon
|
||||
SparkleIcon,
|
||||
TargetIcon
|
||||
} from "@phosphor-icons/react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
@@ -79,75 +79,76 @@ export function CardDetails({
|
||||
|
||||
return (
|
||||
<Dialog open onOpenChange={onClose}>
|
||||
<DialogContent className="border-none bg-accent rounded-lg w-full max-w-5xl mx-auto p-0 overflow-hidden max-h-[90vh] flex flex-col">
|
||||
<div className="p-4 sm:p-6 border-b">
|
||||
<DialogTitle>
|
||||
<div className="flex items-center justify-between">
|
||||
<div>
|
||||
<h2 className="text-xl sm:text-2xl font-bold mb-1">
|
||||
{clientEntry.name || "Unnamed Device"}
|
||||
</h2>
|
||||
<div className="flex items-center text-sm">
|
||||
<span className="bg-blue-400 px-2 py-0.5 rounded-full font-medium">
|
||||
{clientEntry.ip}
|
||||
<DialogContent className="border-none bg-accent rounded-lg w-full max-w-6xl overflow-hidden">
|
||||
<DialogTitle>
|
||||
<div className="flex items-center justify-between">
|
||||
<div>
|
||||
<h2 className="text-xl sm:text-2xl font-bold mb-1">
|
||||
{clientEntry.name || "unknown"}
|
||||
</h2>
|
||||
<div className="flex items-center text-sm gap-2">
|
||||
<span className="bg-muted-foreground/20 px-2 py-0.5 rounded-md font-mono text-xs">
|
||||
ip: {clientEntry.ip}
|
||||
</span>
|
||||
{clientEntry.mac && (
|
||||
<span className="bg-muted-foreground/20 px-2 py-0.5 rounded-md font-mono text-xs">
|
||||
mac: {clientEntry.mac}
|
||||
</span>
|
||||
{clientEntry.mac && (
|
||||
<span className="ml-2 flex items-center">
|
||||
<IdentificationBadgeIcon size={14} className="mr-1" />
|
||||
{clientEntry.mac}
|
||||
</span>
|
||||
)}
|
||||
{clientEntry.vendor && (
|
||||
<span className="ml-2 opacity-75">
|
||||
{clientEntry.vendor}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div className="text-right hidden sm:block">
|
||||
<span className="text-xs">Last Activity</span>
|
||||
<div className="text-muted-foreground">
|
||||
{formatTimeAgo(clientEntry.lastSeen)}
|
||||
</div>
|
||||
)}
|
||||
{clientEntry.vendor && (
|
||||
<span className="bg-muted-foreground/20 px-2 py-0.5 rounded-md font-mono text-xs">
|
||||
vendor: {clientEntry.vendor}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</DialogTitle>
|
||||
</div>
|
||||
<div className="text-right hidden sm:block">
|
||||
<span className="text-xs">Last Activity</span>
|
||||
<div className="text-muted-foreground">
|
||||
{formatTimeAgo(clientEntry.lastSeen)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</DialogTitle>
|
||||
|
||||
<div className="flex bg-background border-b">
|
||||
<button
|
||||
className={`px-4 py-2 text-sm font-medium flex items-center ${
|
||||
activeTab === "overview"
|
||||
? "text-blue-400 border-b-2 border-blue-500"
|
||||
: "text-muted-foreground hover:font-bold hover:border-b-2 border-stone-500 cursor-pointer"
|
||||
}`}
|
||||
onClick={() => setActiveTab("overview")}
|
||||
>
|
||||
<BirdIcon size={16} className="mr-2" />
|
||||
Overview
|
||||
</button>
|
||||
<button
|
||||
className={`px-4 py-2 text-sm font-medium flex items-center ${
|
||||
activeTab === "domains"
|
||||
? "text-blue-400 border-b-2 border-blue-500"
|
||||
: "text-muted-foreground hover:font-bold hover:border-b-2 border-stone-500 cursor-pointer"
|
||||
}`}
|
||||
onClick={() => setActiveTab("domains")}
|
||||
>
|
||||
<LineSegmentsIcon size={16} className="mr-2" />
|
||||
Domains
|
||||
</button>
|
||||
</div>
|
||||
<Tabs defaultValue="overview">
|
||||
<TabsList className="bg-transparent space-x-2">
|
||||
<TabsTrigger
|
||||
value="overview"
|
||||
className="border-l-0 !bg-transparent border-t-0 border-r-0 cursor-pointer data-[state=active]:border-b-2 data-[state=active]:!border-b-primary rounded-none"
|
||||
onClick={() => setActiveTab("overview")}
|
||||
>
|
||||
<TargetIcon />
|
||||
Overview
|
||||
</TabsTrigger>
|
||||
<TabsTrigger
|
||||
value="domains"
|
||||
className="border-l-0 !bg-transparent border-t-0 border-r-0 cursor-pointer data-[state=active]:border-b-2 data-[state=active]:!border-b-primary rounded-none"
|
||||
onClick={() => setActiveTab("domains")}
|
||||
>
|
||||
<RowsIcon />
|
||||
Domains
|
||||
</TabsTrigger>
|
||||
<TabsTrigger
|
||||
value="history"
|
||||
className="border-l-0 !bg-transparent border-t-0 border-r-0 cursor-pointer data-[state=active]:border-b-2 data-[state=active]:!border-b-primary rounded-none"
|
||||
onClick={() => setActiveTab("history")}
|
||||
>
|
||||
<ClockCounterClockwiseIcon />
|
||||
History
|
||||
</TabsTrigger>
|
||||
</TabsList>
|
||||
</Tabs>
|
||||
|
||||
{isLoading ? (
|
||||
<div className="flex justify-center items-center p-16">
|
||||
<div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-blue-500" />
|
||||
</div>
|
||||
) : clientDetails ? (
|
||||
<div className="overflow-y-auto p-4 sm:p-6 flex-grow">
|
||||
<div className="overflow-y-auto sm:p-2 grow">
|
||||
{activeTab === "overview" && (
|
||||
<>
|
||||
<div className="grid grid-cols-2 sm:grid-cols-3 lg:grid-cols-6 gap-3 mb-6">
|
||||
<div className="grid grid-cols-2 sm:grid-cols-3 lg:grid-cols-6 gap-2 mb-4">
|
||||
<StatCard
|
||||
icon={<EyeglassesIcon size={18} />}
|
||||
label="Total Requests"
|
||||
@@ -181,17 +182,14 @@ export function CardDetails({
|
||||
<StatCard
|
||||
icon={<CaretDownIcon size={18} />}
|
||||
label="Most Queried"
|
||||
value={clientDetails.mostQueriedDomain.split(".")[0]}
|
||||
value={clientDetails.mostQueriedDomain}
|
||||
color="bg-indigo-600"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="mb-6">
|
||||
<h3 className="text-lg font-semibold mb-3 flex items-center">
|
||||
<EyeIcon size={18} className="mr-2 text-blue-400" />
|
||||
Top Queried Domains
|
||||
</h3>
|
||||
<div className="grid gap-3">
|
||||
<p className="mb-2">Top Queried Domains</p>
|
||||
<div className="grid gap-2">
|
||||
{Object.entries(clientDetails.allDomains)
|
||||
.sort((a, b) => b[1] - a[1])
|
||||
.slice(0, 5)
|
||||
@@ -208,10 +206,10 @@ export function CardDetails({
|
||||
<div className="w-12 text-center font-mono py-1 rounded text-xs font-medium">
|
||||
{count}
|
||||
</div>
|
||||
<div className="ml-3 flex-grow font-medium truncate">
|
||||
<div className="ml-3 grow font-medium truncate">
|
||||
{domain}
|
||||
</div>
|
||||
<div className="w-24 flex-shrink-0">
|
||||
<div className="w-24 shrink-0">
|
||||
<div className="h-2 bg-accent rounded-full w-full">
|
||||
<div
|
||||
className={`h-2 rounded-full ${getProgressColor(
|
||||
@@ -227,49 +225,37 @@ export function CardDetails({
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
<div className="mt-2 text-center">
|
||||
<button
|
||||
className="text-blue-400 hover:text-blue-300 text-sm flex items-center mx-auto"
|
||||
onClick={() => setActiveTab("domains")}
|
||||
>
|
||||
View all domains <CaretRightIcon size={16} />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-1 sm:grid-cols-3 gap-3 mt-4">
|
||||
<ActionButton
|
||||
label="[WIP] View Details"
|
||||
bgClass="bg-blue-500 text-white"
|
||||
/>
|
||||
<ActionButton
|
||||
label="[WIP] Block Device"
|
||||
bgClass="bg-red-500 text-white"
|
||||
/>
|
||||
<ActionButton
|
||||
label="[WIP] Device Settings"
|
||||
bgClass="bg-stone-500 text-white"
|
||||
/>
|
||||
<Button variant="outline" disabled={true}>
|
||||
[WIP] View Details
|
||||
</Button>
|
||||
<Button variant="outline" disabled={true}>
|
||||
[WIP] Block Device
|
||||
</Button>
|
||||
<Button variant="outline" disabled={true}>
|
||||
[WIP] Device Settings
|
||||
</Button>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
{activeTab === "domains" && (
|
||||
<div>
|
||||
<h3 className="text-lg font-semibold mb-3 flex items-center">
|
||||
<EyeIcon size={18} className="mr-2 text-blue-400" />
|
||||
<p className="mb-2">
|
||||
All Queried Domains
|
||||
<span className="ml-2 text-xs bg-accent px-2 py-0.5 rounded-full text-muted-foreground">
|
||||
{Object.keys(clientDetails.allDomains).length} domains
|
||||
</span>
|
||||
</h3>
|
||||
</p>
|
||||
|
||||
<div className="shadow-md border rounded-md overflow-hidden">
|
||||
<div className="flex justify-between items-center py-2 px-3">
|
||||
<div className="w-16 text-xs text-muted-foreground font-medium">
|
||||
Count
|
||||
</div>
|
||||
<div className="flex-grow text-xs text-muted-foreground font-medium">
|
||||
<div className="grow text-xs text-muted-foreground font-medium">
|
||||
Domain
|
||||
</div>
|
||||
<div className="w-24 text-xs text-muted-foreground font-medium">
|
||||
@@ -294,7 +280,7 @@ export function CardDetails({
|
||||
<div className="w-16 font-mono bg-accent py-1 rounded text-center text-xs font-medium">
|
||||
{count}
|
||||
</div>
|
||||
<div className="ml-3 flex-grow font-medium truncate">
|
||||
<div className="ml-3 grow font-medium truncate">
|
||||
{domain}
|
||||
</div>
|
||||
<div className="w-24 text-right text-muted-foreground text-sm">
|
||||
@@ -309,7 +295,7 @@ export function CardDetails({
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
<div className="text-center py-16 flex-grow flex flex-col items-center justify-center">
|
||||
<div className="text-center py-16 grow flex flex-col items-center justify-center">
|
||||
<ShieldIcon size={48} className="mb-4" />
|
||||
<div className="text-lg">No data available for this client</div>
|
||||
<div className="text-sm mt-2 text-muted-foreground">
|
||||
@@ -326,23 +312,12 @@ function StatCard({ icon, label, value, color }) {
|
||||
return (
|
||||
<div className="rounded-sm shadow-md bg-background">
|
||||
<div className={`${color} h-1`}></div>
|
||||
<div className="p-3">
|
||||
<div className="flex items-center text-xs text-muted-foreground mb-1">
|
||||
<span className="mr-1">{icon}</span>
|
||||
{label}
|
||||
<div className="p-2">
|
||||
<div className="flex items-center text-xs text-muted-foreground mb-1 gap-1">
|
||||
{icon} {label}
|
||||
</div>
|
||||
<div className="font-bold text-lg truncate">{value}</div>
|
||||
<div className="font-bold text-sm truncate">{value}</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ActionButton({ label, bgClass }) {
|
||||
return (
|
||||
<button
|
||||
className={`${bgClass} px-4 py-3 rounded-md text-sm font-medium transition-all shadow-md hover:shadow-lg`}
|
||||
>
|
||||
{label}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import { XIcon } from "@phosphor-icons/react";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import ForceGraph2D, { ForceGraphMethods } from "react-force-graph-2d";
|
||||
import { CardDetails } from "./details";
|
||||
import { toast } from "sonner";
|
||||
|
||||
interface ClientEntry {
|
||||
ip: string;
|
||||
@@ -129,8 +130,8 @@ export default function DNSServerVisualizer() {
|
||||
const [viewSettings, setViewSettings] = useState<ViewSettings>({
|
||||
clusterBySubnet: false,
|
||||
hideInactiveClients: false,
|
||||
minNodeSize: 3,
|
||||
maxNodeSize: 12,
|
||||
minNodeSize: 2,
|
||||
maxNodeSize: 10,
|
||||
showLabels: true,
|
||||
activityThresholdMinutes: 60
|
||||
});
|
||||
@@ -235,7 +236,7 @@ export default function DNSServerVisualizer() {
|
||||
setError(
|
||||
err instanceof Error ? err.message : "Failed to fetch clients"
|
||||
);
|
||||
console.error("Error fetching clients:", err);
|
||||
toast.warning("Error fetching clients", { description: `${err}` });
|
||||
}
|
||||
};
|
||||
|
||||
@@ -289,7 +290,9 @@ export default function DNSServerVisualizer() {
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error handling WebSocket message:", error);
|
||||
toast.warning("Error handling WebSocket message", {
|
||||
description: `${error}`
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -319,14 +322,14 @@ export default function DNSServerVisualizer() {
|
||||
id: "dns-server",
|
||||
name: "DNS Server",
|
||||
type: "server",
|
||||
color: "#3b82f6",
|
||||
color: "cornflowerblue",
|
||||
size: viewSettings.maxNodeSize
|
||||
},
|
||||
{
|
||||
id: "upstream",
|
||||
name: "Upstream",
|
||||
type: "server",
|
||||
color: "#008000",
|
||||
color: "teal",
|
||||
size: viewSettings.maxNodeSize
|
||||
}
|
||||
];
|
||||
@@ -557,7 +560,7 @@ export default function DNSServerVisualizer() {
|
||||
|
||||
<div
|
||||
ref={containerRef}
|
||||
className="rounded-xl shadow-md p-4 w-full border-1 dark:bg-accent"
|
||||
className="rounded-xl shadow-md p-4 w-full border dark:bg-accent"
|
||||
>
|
||||
<div className="grid grid-cols-4 gap-2 text-sm mb-4">
|
||||
{[
|
||||
@@ -583,7 +586,7 @@ export default function DNSServerVisualizer() {
|
||||
).length
|
||||
}
|
||||
].map(({ label, plural, value }) => (
|
||||
<div key={label} className="rounded-lg py-0.5 text-center border-1">
|
||||
<div key={label} className="rounded-lg py-0.5 text-center border">
|
||||
<p className="text-sm font-medium">{value}</p>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
{value === 1 ? label : plural}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"use client";
|
||||
|
||||
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { NoContent } from "@/shared";
|
||||
import { GetRequest } from "@/util";
|
||||
import { ArticleIcon, WarningIcon } from "@phosphor-icons/react";
|
||||
import { ArticleIcon } from "@phosphor-icons/react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
|
||||
@@ -87,16 +88,9 @@ export default function Audit() {
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<EmptyState />
|
||||
<NoContent text={"No audits created"} />
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
const EmptyState = () => (
|
||||
<div className="grid place-items-center py-8">
|
||||
<WarningIcon size={32} className="text-destructive mb-2" />
|
||||
<p className="text-sm text-muted-foreground">No audit entries found</p>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { GetRequest } from "@/util";
|
||||
import { NetworkSlashIcon, WarningIcon } from "@phosphor-icons/react";
|
||||
import { NetworkSlashIcon } from "@phosphor-icons/react";
|
||||
import { useEffect, useState, useRef } from "react";
|
||||
import {
|
||||
Bar,
|
||||
@@ -20,6 +20,7 @@ import {
|
||||
NameType,
|
||||
ValueType
|
||||
} from "recharts/types/component/DefaultTooltipContent";
|
||||
import { NoContent } from "@/shared";
|
||||
|
||||
type TopBlockedDomains = {
|
||||
frequency: number;
|
||||
@@ -58,17 +59,6 @@ const CustomTooltip = ({
|
||||
return null;
|
||||
};
|
||||
|
||||
const EmptyState = () => (
|
||||
<div className="flex flex-col items-center justify-center h-full w-full py-10">
|
||||
<div className="mb-4">
|
||||
<WarningIcon size={36} className="text-destructive" />
|
||||
</div>
|
||||
<p className="text-muted-foreground text-sm text-center">
|
||||
Blocked domains will appear here when detected
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
|
||||
const isNewData = (a: TopBlockedDomains[], b: TopBlockedDomains[]): boolean => {
|
||||
if (a.length !== b.length) return false;
|
||||
|
||||
@@ -211,7 +201,7 @@ export default function FrequencyChartBlockedDomains() {
|
||||
</BarChart>
|
||||
</ResponsiveContainer>
|
||||
) : (
|
||||
<EmptyState />
|
||||
<NoContent text={"No domain has been blocked"} />
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
|
||||
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { Tabs, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||
import { NoContent } from "@/shared";
|
||||
import { GetRequest } from "@/util";
|
||||
import { UsersIcon, WarningIcon } from "@phosphor-icons/react";
|
||||
import { UsersIcon } from "@phosphor-icons/react";
|
||||
import { useEffect, useState, useRef } from "react";
|
||||
import {
|
||||
Bar,
|
||||
@@ -58,17 +59,6 @@ const CustomTooltip = ({
|
||||
return null;
|
||||
};
|
||||
|
||||
const EmptyState = () => (
|
||||
<div className="flex flex-col items-center justify-center h-full w-full py-10">
|
||||
<div className="mb-4">
|
||||
<WarningIcon size={36} className="text-destructive" />
|
||||
</div>
|
||||
<p className="text-muted-foreground text-sm text-center">
|
||||
Client data will appear here when requests are detected
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
|
||||
const isNewData = (a: TopBlockedClients[], b: TopBlockedClients[]): boolean => {
|
||||
if (a.length !== b.length) return false;
|
||||
|
||||
@@ -214,7 +204,7 @@ export default function FrequencyChartTopBlockedClients() {
|
||||
</BarChart>
|
||||
</ResponsiveContainer>
|
||||
) : (
|
||||
<EmptyState />
|
||||
<NoContent text={"No client data to show"} />
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
@@ -6,6 +6,7 @@ import {
|
||||
ChartLegendContent,
|
||||
ChartTooltip
|
||||
} from "@/components/ui/chart";
|
||||
import { NoContent } from "@/shared";
|
||||
import { GetRequest } from "@/util";
|
||||
import {
|
||||
ArrowsClockwiseIcon,
|
||||
@@ -93,9 +94,15 @@ export default function ResponseSizeTimeline() {
|
||||
}, [timelineInterval]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchData();
|
||||
const timer = setTimeout(() => {
|
||||
fetchData();
|
||||
}, 0);
|
||||
|
||||
const interval = setInterval(fetchData, 10000);
|
||||
return () => clearInterval(interval);
|
||||
return () => {
|
||||
clearTimeout(timer);
|
||||
clearInterval(interval);
|
||||
};
|
||||
}, [fetchData]);
|
||||
|
||||
const getFilteredData = () => {
|
||||
@@ -207,7 +214,7 @@ export default function ResponseSizeTimeline() {
|
||||
<div className="flex gap-2">
|
||||
{isZoomed && (
|
||||
<Button
|
||||
className="bg-transparent border-1 text-white hover:bg-stone-800"
|
||||
className="bg-transparent border text-white hover:bg-stone-800"
|
||||
onClick={handleZoomOut}
|
||||
>
|
||||
<MagnifyingGlassMinusIcon weight="bold" className="mr-1" />
|
||||
@@ -444,13 +451,8 @@ export default function ResponseSizeTimeline() {
|
||||
</CardContent>
|
||||
</>
|
||||
) : (
|
||||
<CardContent className="flex h-[300px] items-center justify-center">
|
||||
<div className="text-center">
|
||||
<p className="text-lg font-medium">No data available</p>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
No response size data recorded yet
|
||||
</p>
|
||||
</div>
|
||||
<CardContent className="flex h-[220px] items-center justify-center">
|
||||
<NoContent text={"No requests recorded"} />
|
||||
</CardContent>
|
||||
)}
|
||||
</Card>
|
||||
|
||||
@@ -27,11 +27,11 @@ import {
|
||||
ArrowsClockwiseIcon,
|
||||
ChartLineIcon,
|
||||
MagnifyingGlassMinusIcon,
|
||||
MagnifyingGlassPlusIcon,
|
||||
WarningIcon
|
||||
MagnifyingGlassPlusIcon
|
||||
} from "@phosphor-icons/react";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import { Button } from "../../components/ui/button";
|
||||
import { NoContent } from "@/shared";
|
||||
|
||||
const chartConfig = {
|
||||
blocked: {
|
||||
@@ -90,9 +90,14 @@ export default function RequestTimeline() {
|
||||
}, [timelineInterval]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchData();
|
||||
const interval = setInterval(fetchData, 10000);
|
||||
return () => clearInterval(interval);
|
||||
const timeout = window.setTimeout(() => {
|
||||
fetchData();
|
||||
}, 0);
|
||||
const interval = window.setInterval(fetchData, 10000);
|
||||
return () => {
|
||||
window.clearTimeout(timeout);
|
||||
window.clearInterval(interval);
|
||||
};
|
||||
}, [fetchData]);
|
||||
|
||||
const getFilteredData = () => {
|
||||
@@ -196,7 +201,7 @@ export default function RequestTimeline() {
|
||||
<div className="flex gap-2">
|
||||
{isZoomed && (
|
||||
<Button
|
||||
className="bg-transparent border-1 text-white hover:bg-stone-800"
|
||||
className="bg-transparent border text-white hover:bg-stone-800"
|
||||
onClick={handleZoomOut}
|
||||
>
|
||||
<MagnifyingGlassMinusIcon weight="bold" className="mr-1" />
|
||||
@@ -411,15 +416,8 @@ export default function RequestTimeline() {
|
||||
</CardContent>
|
||||
</>
|
||||
) : (
|
||||
<CardContent className="flex h-[300px] items-center justify-center">
|
||||
<div className="flex flex-col items-center justify-center">
|
||||
<div className="mb-4">
|
||||
<WarningIcon size={36} className="text-destructive" />
|
||||
</div>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
No requests recorded yet
|
||||
</p>
|
||||
</div>
|
||||
<CardContent className="flex h-[220px] items-center justify-center">
|
||||
<NoContent text={"No requests recorded"} />
|
||||
</CardContent>
|
||||
)}
|
||||
</Card>
|
||||
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
SelectTrigger,
|
||||
SelectValue
|
||||
} from "@/components/ui/select";
|
||||
import { NoContent } from "@/shared";
|
||||
import { GetRequest } from "@/util";
|
||||
import { SetStateAction, useEffect, useState } from "react";
|
||||
import {
|
||||
@@ -144,12 +145,7 @@ export default function RequestTypeChart() {
|
||||
</CardContent>
|
||||
) : (
|
||||
<CardContent className="flex h-[200px] items-center justify-center">
|
||||
<div className="text-center">
|
||||
<p className="text-sm font-medium">No data available</p>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
No query types have yet been identified
|
||||
</p>
|
||||
</div>
|
||||
<NoContent text={"No query types has been identified"} />
|
||||
</CardContent>
|
||||
)}
|
||||
</Card>
|
||||
|
||||
@@ -20,8 +20,8 @@ async function removeDomain(domain: string) {
|
||||
}
|
||||
}
|
||||
|
||||
export default function BlockedDomainsList({ listName }) {
|
||||
const [domains, setDomains] = useState([]);
|
||||
export default function BlockedDomainsList({ listName }: { listName: string }) {
|
||||
const [domains, setDomains] = useState<string[]>([]);
|
||||
const [searchTerm, setSearchTerm] = useState("");
|
||||
const [loading, setLoading] = useState(true);
|
||||
|
||||
@@ -32,7 +32,7 @@ export function ListCard(
|
||||
...listEntry
|
||||
} = props;
|
||||
|
||||
const formattedDate = new Date(listEntry.lastUpdated * 1000).toLocaleString(
|
||||
const formattedDate = new Date(listEntry.lastUpdated).toLocaleString(
|
||||
"en-US",
|
||||
{
|
||||
month: "short",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user