fix: multiple fixes related to graceful shutdown, goroutine leaks, error channel overflow and more

This commit is contained in:
pommee
2026-05-01 11:13:55 +02:00
parent 7088d2f627
commit e84bbcd73c
8 changed files with 153 additions and 61 deletions
+7 -2
View File
@@ -9,6 +9,7 @@ import (
"goaway/backend/user"
"io"
"net/http"
"time"
"github.com/gin-gonic/gin"
)
@@ -98,7 +99,9 @@ func (api *API) updatePassword(c *gin.Context) {
Message: logMsg,
})
go func() {
_ = api.DNSServer.AlertService.SendToAll(context.Background(), alert.Message{
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = api.DNSServer.AlertService.SendToAll(ctx, alert.Message{
Title: "System",
Content: logMsg,
Severity: SeverityWarning,
@@ -135,7 +138,9 @@ func (api *API) createAPIKey(c *gin.Context) {
}
go func() {
_ = api.DNSServer.AlertService.SendToAll(context.Background(), alert.Message{
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = api.DNSServer.AlertService.SendToAll(ctx, alert.Message{
Title: "System",
Content: fmt.Sprintf("New API key created with the name '%s'", request.Name),
Severity: SeverityWarning,
+38 -32
View File
@@ -747,46 +747,52 @@ func (s *Service) Vacuum(ctx context.Context) {
}
}
func (s *Service) ScheduleAutomaticListUpdates() {
func (s *Service) ScheduleAutomaticListUpdates(ctx context.Context) {
ticker := time.NewTicker(s.config.UpdateInterval)
defer ticker.Stop()
for range ticker.C {
ctx := context.Background()
log.Info("Starting automatic list updates...")
for {
select {
case <-ctx.Done():
log.Debug("Stopping automatic list updates")
return
case <-ticker.C:
bgCtx := context.Background()
log.Info("Starting automatic list updates...")
for _, source := range s.blocklistURL {
if source.Name == "Custom" {
continue
}
log.Info("Checking for updates for blocklist %s from %s", source.Name, source.URL)
for _, source := range s.blocklistURL {
if source.Name == "Custom" {
continue
}
log.Info("Checking for updates for blocklist %s from %s", source.Name, source.URL)
availableUpdate, err := s.CheckIfUpdateAvailable(ctx, source.URL, source.Name)
if err != nil {
log.Warning("Failed to check for updates for %s: %v", source.Name, err)
continue
availableUpdate, err := s.CheckIfUpdateAvailable(bgCtx, source.URL, source.Name)
if err != nil {
log.Warning("Failed to check for updates for %s: %v", source.Name, err)
continue
}
if !availableUpdate.UpdateAvailable {
log.Info("No updates available for %s", source.Name)
continue
}
if err := s.RemoveSourceAndDomains(bgCtx, source.Name, source.URL); err != nil {
log.Warning("Failed to remove old domains for %s: %v", source.Name, err)
continue
}
if err := s.FetchAndLoadHosts(bgCtx, source.URL, source.Name); err != nil {
log.Warning("Failed to fetch and load hosts for %s: %v", source.Name, err)
continue
}
log.Info("Successfully updated %s with %d new domains", source.Name, len(availableUpdate.DiffAdded))
}
if !availableUpdate.UpdateAvailable {
log.Info("No updates available for %s", source.Name)
continue
if err := s.PopulateCache(bgCtx); err != nil {
log.Warning("Failed to populate blocklist cache after auto-update: %v", err)
}
if err := s.RemoveSourceAndDomains(ctx, source.Name, source.URL); err != nil {
log.Warning("Failed to remove old domains for %s: %v", source.Name, err)
continue
}
if err := s.FetchAndLoadHosts(ctx, source.URL, source.Name); err != nil {
log.Warning("Failed to fetch and load hosts for %s: %v", source.Name, err)
continue
}
log.Info("Successfully updated %s with %d new domains", source.Name, len(availableUpdate.DiffAdded))
}
if err := s.PopulateCache(ctx); err != nil {
log.Warning("Failed to populate blocklist cache after auto-update: %v", err)
}
}
}
+19 -7
View File
@@ -1,6 +1,7 @@
package server
import (
"context"
"encoding/json"
model "goaway/backend/dns/server/models"
"time"
@@ -10,13 +11,18 @@ import (
const batchSize = 1000
func (s *DNSServer) ProcessLogEntries() {
func (s *DNSServer) ProcessLogEntries(ctx context.Context) {
var batch []model.RequestLogEntry
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
if len(batch) > 0 {
s.saveBatch(batch)
}
return
case entry := <-s.logEntryChannel:
log.Debug("%s", entry.String())
if s.WSQueries != nil {
@@ -46,18 +52,24 @@ func (s *DNSServer) saveBatch(entries []model.RequestLogEntry) {
}
// Removes old log entries based on the configured retention period.
func (s *DNSServer) ClearOldEntries() {
func (s *DNSServer) ClearOldEntries(ctx context.Context) {
const (
maxRetries = 10
retryDelay = 150 * time.Millisecond
cleanupInterval = 5 * time.Minute
)
for {
requestThreshold := ((60 * 60) * 24) * s.Config.Misc.StatisticsRetention
log.Debug("Next cleanup running at %s", time.Now().Add(cleanupInterval).Format(time.DateTime))
time.Sleep(cleanupInterval)
ticker := time.NewTicker(cleanupInterval)
defer ticker.Stop()
s.RequestService.DeleteRequestLogsTimebased(s.BlacklistService.Vacuum, requestThreshold, maxRetries, retryDelay)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
requestThreshold := ((60 * 60) * 24) * s.Config.Misc.StatisticsRetention
log.Debug("Next cleanup running at %s", time.Now().Add(cleanupInterval).Format(time.DateTime))
s.RequestService.DeleteRequestLogsTimebased(s.BlacklistService.Vacuum, requestThreshold, maxRetries, retryDelay)
}
}
}
+6 -3
View File
@@ -1,6 +1,7 @@
package jobs
import (
"context"
arp "goaway/backend/dns"
"goaway/backend/logging"
"goaway/backend/services"
@@ -10,11 +11,13 @@ var log = logging.GetLogger()
type BackgroundJobs struct {
registry *services.ServiceRegistry
ctx context.Context
}
func NewBackgroundJobs(registry *services.ServiceRegistry) *BackgroundJobs {
return &BackgroundJobs{
registry: registry,
ctx: registry.ShutdownContext(),
}
}
@@ -54,7 +57,7 @@ func (b *BackgroundJobs) startScheduledUpdates(readyChan <-chan struct{}) {
<-readyChan
if b.registry.Context.Config.Misc.ScheduledBlacklistUpdates {
log.Debug("Starting scheduler for automatic list updates...")
b.registry.BlacklistService.ScheduleAutomaticListUpdates()
b.registry.BlacklistService.ScheduleAutomaticListUpdates(b.ctx)
}
}()
}
@@ -63,7 +66,7 @@ func (b *BackgroundJobs) startCacheCleanup(readyChan <-chan struct{}) {
go func() {
<-readyChan
log.Debug("Starting cache cleanup routine...")
b.registry.Context.DNSServer.ClearOldEntries()
b.registry.Context.DNSServer.ClearOldEntries(b.ctx)
}()
}
@@ -71,6 +74,6 @@ func (b *BackgroundJobs) startPrefetcher(readyChan <-chan struct{}) {
go func() {
<-readyChan
log.Debug("Starting prefetcher...")
b.registry.PrefetchService.Run()
b.registry.PrefetchService.Run(b.ctx)
}()
}
+46 -1
View File
@@ -1,6 +1,7 @@
package lifecycle
import (
"context"
"goaway/backend/api"
"goaway/backend/jobs"
"goaway/backend/logging"
@@ -8,6 +9,7 @@ import (
"os"
"os/signal"
"syscall"
"time"
)
var log = logging.GetLogger()
@@ -63,6 +65,49 @@ func (m *Manager) waitForTermination() error {
}
func (m *Manager) shutdown() {
// TODO: Add graceful shutdown logic
log.Info("Starting graceful shutdown...")
m.services.Shutdown()
m.services.APIServer.IsShuttingDown = true
if err := m.services.APIServer.Stop(); err != nil {
log.Error("Error stopping API server: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
m.services.UDPServer.Shutdown(ctx)
log.Info("Stopped UDP server")
m.services.TCPServer.Shutdown(ctx)
log.Info("Stopped TCP server")
if m.services.DoTServer != nil {
m.services.DoTServer.Shutdown(ctx)
log.Info("Stopped DNS-over-TLS server")
}
if m.services.DoHServer != nil {
if err := m.services.DoHServer.Shutdown(ctx); err != nil && err != context.DeadlineExceeded {
log.Error("Error stopping DoH server: %v", err)
}
log.Info("Stopped DNS-over-HTTPS server")
}
// Wait for all goroutines to finish with timeout
done := make(chan struct{})
go func() {
m.services.WaitGroup().Wait()
close(done)
}()
select {
case <-done:
log.Info("All services stopped gracefully")
case <-time.After(15 * time.Second):
log.Warning("Shutdown timeout exceeded, forcing exit")
}
os.Exit(0)
}
+11 -4
View File
@@ -1,6 +1,7 @@
package prefetch
import (
"context"
"fmt"
"goaway/backend/database"
"goaway/backend/dns/server"
@@ -30,13 +31,19 @@ func NewService(repo Repository, dnsServer *server.DNSServer) *Service {
return service
}
func (s *Service) Run() {
func (s *Service) Run(ctx context.Context) {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for range ticker.C {
s.checkNewDomains()
s.processExpiredEntries()
for {
select {
case <-ctx.Done():
log.Debug("Stopping prefetch service")
return
case <-ticker.C:
s.checkNewDomains()
s.processExpiredEntries()
}
}
}
+2 -1
View File
@@ -1,6 +1,7 @@
package services
import (
"context"
"crypto/tls"
"fmt"
"goaway/backend/database"
@@ -45,7 +46,7 @@ func (ctx *AppContext) initialize() error {
}
ctx.DNSServer = dnsServer
go dnsServer.ProcessLogEntries()
go dnsServer.ProcessLogEntries(context.Background())
return nil
}
+24 -11
View File
@@ -36,10 +36,12 @@ type ServiceRegistry struct {
Context *AppContext
version string
date string
commit string
wg sync.WaitGroup
version string
date string
commit string
wg sync.WaitGroup
shutdownCtx context.Context
cancel context.CancelFunc
ResolutionService *resolution.Service
RequestService *request.Service
@@ -57,14 +59,17 @@ type ServiceError struct {
}
func NewServiceRegistry(ctx *AppContext, version, commit, date string, content embed.FS) *ServiceRegistry {
shutdownCtx, cancel := context.WithCancel(context.Background())
return &ServiceRegistry{
Context: ctx,
version: version,
commit: commit,
date: date,
content: content,
readyChan: make(chan struct{}),
errorChan: make(chan ServiceError, 10),
Context: ctx,
version: version,
commit: commit,
date: date,
content: content,
readyChan: make(chan struct{}),
errorChan: make(chan ServiceError, 50),
shutdownCtx: shutdownCtx,
cancel: cancel,
}
}
@@ -230,6 +235,14 @@ func (r *ServiceRegistry) ReadyChannel() <-chan struct{} {
return r.readyChan
}
func (r *ServiceRegistry) ShutdownContext() context.Context {
return r.shutdownCtx
}
func (r *ServiceRegistry) Shutdown() {
r.cancel()
}
func (r *ServiceRegistry) ErrorChannel() <-chan ServiceError {
return r.errorChan
}