diff --git a/backend/api/auth.go b/backend/api/auth.go index fbf1fa8..4604b12 100644 --- a/backend/api/auth.go +++ b/backend/api/auth.go @@ -9,6 +9,7 @@ import ( "goaway/backend/user" "io" "net/http" + "time" "github.com/gin-gonic/gin" ) @@ -98,7 +99,9 @@ func (api *API) updatePassword(c *gin.Context) { Message: logMsg, }) go func() { - _ = api.DNSServer.AlertService.SendToAll(context.Background(), alert.Message{ + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = api.DNSServer.AlertService.SendToAll(ctx, alert.Message{ Title: "System", Content: logMsg, Severity: SeverityWarning, @@ -135,7 +138,9 @@ func (api *API) createAPIKey(c *gin.Context) { } go func() { - _ = api.DNSServer.AlertService.SendToAll(context.Background(), alert.Message{ + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = api.DNSServer.AlertService.SendToAll(ctx, alert.Message{ Title: "System", Content: fmt.Sprintf("New API key created with the name '%s'", request.Name), Severity: SeverityWarning, diff --git a/backend/blacklist/service.go b/backend/blacklist/service.go index eeb24a1..030f0e6 100644 --- a/backend/blacklist/service.go +++ b/backend/blacklist/service.go @@ -747,46 +747,52 @@ func (s *Service) Vacuum(ctx context.Context) { } } -func (s *Service) ScheduleAutomaticListUpdates() { +func (s *Service) ScheduleAutomaticListUpdates(ctx context.Context) { ticker := time.NewTicker(s.config.UpdateInterval) defer ticker.Stop() - for range ticker.C { - ctx := context.Background() - log.Info("Starting automatic list updates...") + for { + select { + case <-ctx.Done(): + log.Debug("Stopping automatic list updates") + return + case <-ticker.C: + bgCtx := context.Background() + log.Info("Starting automatic list updates...") - for _, source := range s.blocklistURL { - if source.Name == "Custom" { - continue - } - log.Info("Checking for updates for blocklist %s from %s", source.Name, source.URL) + for _, source := range s.blocklistURL { + if source.Name == "Custom" { + continue + } + log.Info("Checking for updates for blocklist %s from %s", source.Name, source.URL) - availableUpdate, err := s.CheckIfUpdateAvailable(ctx, source.URL, source.Name) - if err != nil { - log.Warning("Failed to check for updates for %s: %v", source.Name, err) - continue + availableUpdate, err := s.CheckIfUpdateAvailable(bgCtx, source.URL, source.Name) + if err != nil { + log.Warning("Failed to check for updates for %s: %v", source.Name, err) + continue + } + + if !availableUpdate.UpdateAvailable { + log.Info("No updates available for %s", source.Name) + continue + } + + if err := s.RemoveSourceAndDomains(bgCtx, source.Name, source.URL); err != nil { + log.Warning("Failed to remove old domains for %s: %v", source.Name, err) + continue + } + + if err := s.FetchAndLoadHosts(bgCtx, source.URL, source.Name); err != nil { + log.Warning("Failed to fetch and load hosts for %s: %v", source.Name, err) + continue + } + + log.Info("Successfully updated %s with %d new domains", source.Name, len(availableUpdate.DiffAdded)) } - if !availableUpdate.UpdateAvailable { - log.Info("No updates available for %s", source.Name) - continue + if err := s.PopulateCache(bgCtx); err != nil { + log.Warning("Failed to populate blocklist cache after auto-update: %v", err) } - - if err := s.RemoveSourceAndDomains(ctx, source.Name, source.URL); err != nil { - log.Warning("Failed to remove old domains for %s: %v", source.Name, err) - continue - } - - if err := s.FetchAndLoadHosts(ctx, source.URL, source.Name); err != nil { - log.Warning("Failed to fetch and load hosts for %s: %v", source.Name, err) - continue - } - - log.Info("Successfully updated %s with %d new domains", source.Name, len(availableUpdate.DiffAdded)) - } - - if err := s.PopulateCache(ctx); err != nil { - log.Warning("Failed to populate blocklist cache after auto-update: %v", err) } } } diff --git a/backend/dns/server/logs.go b/backend/dns/server/logs.go index a41574c..03f06ed 100644 --- a/backend/dns/server/logs.go +++ b/backend/dns/server/logs.go @@ -1,6 +1,7 @@ package server import ( + "context" "encoding/json" model "goaway/backend/dns/server/models" "time" @@ -10,13 +11,18 @@ import ( const batchSize = 1000 -func (s *DNSServer) ProcessLogEntries() { +func (s *DNSServer) ProcessLogEntries(ctx context.Context) { var batch []model.RequestLogEntry ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() for { select { + case <-ctx.Done(): + if len(batch) > 0 { + s.saveBatch(batch) + } + return case entry := <-s.logEntryChannel: log.Debug("%s", entry.String()) if s.WSQueries != nil { @@ -46,18 +52,24 @@ func (s *DNSServer) saveBatch(entries []model.RequestLogEntry) { } // Removes old log entries based on the configured retention period. -func (s *DNSServer) ClearOldEntries() { +func (s *DNSServer) ClearOldEntries(ctx context.Context) { const ( maxRetries = 10 retryDelay = 150 * time.Millisecond cleanupInterval = 5 * time.Minute ) - for { - requestThreshold := ((60 * 60) * 24) * s.Config.Misc.StatisticsRetention - log.Debug("Next cleanup running at %s", time.Now().Add(cleanupInterval).Format(time.DateTime)) - time.Sleep(cleanupInterval) + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() - s.RequestService.DeleteRequestLogsTimebased(s.BlacklistService.Vacuum, requestThreshold, maxRetries, retryDelay) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + requestThreshold := ((60 * 60) * 24) * s.Config.Misc.StatisticsRetention + log.Debug("Next cleanup running at %s", time.Now().Add(cleanupInterval).Format(time.DateTime)) + s.RequestService.DeleteRequestLogsTimebased(s.BlacklistService.Vacuum, requestThreshold, maxRetries, retryDelay) + } } } diff --git a/backend/jobs/background.go b/backend/jobs/background.go index dffb111..3df47e0 100644 --- a/backend/jobs/background.go +++ b/backend/jobs/background.go @@ -1,6 +1,7 @@ package jobs import ( + "context" arp "goaway/backend/dns" "goaway/backend/logging" "goaway/backend/services" @@ -10,11 +11,13 @@ var log = logging.GetLogger() type BackgroundJobs struct { registry *services.ServiceRegistry + ctx context.Context } func NewBackgroundJobs(registry *services.ServiceRegistry) *BackgroundJobs { return &BackgroundJobs{ registry: registry, + ctx: registry.ShutdownContext(), } } @@ -54,7 +57,7 @@ func (b *BackgroundJobs) startScheduledUpdates(readyChan <-chan struct{}) { <-readyChan if b.registry.Context.Config.Misc.ScheduledBlacklistUpdates { log.Debug("Starting scheduler for automatic list updates...") - b.registry.BlacklistService.ScheduleAutomaticListUpdates() + b.registry.BlacklistService.ScheduleAutomaticListUpdates(b.ctx) } }() } @@ -63,7 +66,7 @@ func (b *BackgroundJobs) startCacheCleanup(readyChan <-chan struct{}) { go func() { <-readyChan log.Debug("Starting cache cleanup routine...") - b.registry.Context.DNSServer.ClearOldEntries() + b.registry.Context.DNSServer.ClearOldEntries(b.ctx) }() } @@ -71,6 +74,6 @@ func (b *BackgroundJobs) startPrefetcher(readyChan <-chan struct{}) { go func() { <-readyChan log.Debug("Starting prefetcher...") - b.registry.PrefetchService.Run() + b.registry.PrefetchService.Run(b.ctx) }() } diff --git a/backend/lifecycle/manager.go b/backend/lifecycle/manager.go index 3355c93..09c67ba 100644 --- a/backend/lifecycle/manager.go +++ b/backend/lifecycle/manager.go @@ -1,6 +1,7 @@ package lifecycle import ( + "context" "goaway/backend/api" "goaway/backend/jobs" "goaway/backend/logging" @@ -8,6 +9,7 @@ import ( "os" "os/signal" "syscall" + "time" ) var log = logging.GetLogger() @@ -63,6 +65,49 @@ func (m *Manager) waitForTermination() error { } func (m *Manager) shutdown() { - // TODO: Add graceful shutdown logic + log.Info("Starting graceful shutdown...") + + m.services.Shutdown() + m.services.APIServer.IsShuttingDown = true + + if err := m.services.APIServer.Stop(); err != nil { + log.Error("Error stopping API server: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + m.services.UDPServer.Shutdown(ctx) + log.Info("Stopped UDP server") + + m.services.TCPServer.Shutdown(ctx) + log.Info("Stopped TCP server") + + if m.services.DoTServer != nil { + m.services.DoTServer.Shutdown(ctx) + log.Info("Stopped DNS-over-TLS server") + } + + if m.services.DoHServer != nil { + if err := m.services.DoHServer.Shutdown(ctx); err != nil && err != context.DeadlineExceeded { + log.Error("Error stopping DoH server: %v", err) + } + log.Info("Stopped DNS-over-HTTPS server") + } + + // Wait for all goroutines to finish with timeout + done := make(chan struct{}) + go func() { + m.services.WaitGroup().Wait() + close(done) + }() + + select { + case <-done: + log.Info("All services stopped gracefully") + case <-time.After(15 * time.Second): + log.Warning("Shutdown timeout exceeded, forcing exit") + } + os.Exit(0) } diff --git a/backend/prefetch/service.go b/backend/prefetch/service.go index 207ef4a..10251e6 100644 --- a/backend/prefetch/service.go +++ b/backend/prefetch/service.go @@ -1,6 +1,7 @@ package prefetch import ( + "context" "fmt" "goaway/backend/database" "goaway/backend/dns/server" @@ -30,13 +31,19 @@ func NewService(repo Repository, dnsServer *server.DNSServer) *Service { return service } -func (s *Service) Run() { +func (s *Service) Run(ctx context.Context) { ticker := time.NewTicker(500 * time.Millisecond) defer ticker.Stop() - for range ticker.C { - s.checkNewDomains() - s.processExpiredEntries() + for { + select { + case <-ctx.Done(): + log.Debug("Stopping prefetch service") + return + case <-ticker.C: + s.checkNewDomains() + s.processExpiredEntries() + } } } diff --git a/backend/services/context.go b/backend/services/context.go index 9729e91..1ac1bee 100644 --- a/backend/services/context.go +++ b/backend/services/context.go @@ -1,6 +1,7 @@ package services import ( + "context" "crypto/tls" "fmt" "goaway/backend/database" @@ -45,7 +46,7 @@ func (ctx *AppContext) initialize() error { } ctx.DNSServer = dnsServer - go dnsServer.ProcessLogEntries() + go dnsServer.ProcessLogEntries(context.Background()) return nil } diff --git a/backend/services/registry.go b/backend/services/registry.go index ca98979..acf12cf 100644 --- a/backend/services/registry.go +++ b/backend/services/registry.go @@ -36,10 +36,12 @@ type ServiceRegistry struct { Context *AppContext - version string - date string - commit string - wg sync.WaitGroup + version string + date string + commit string + wg sync.WaitGroup + shutdownCtx context.Context + cancel context.CancelFunc ResolutionService *resolution.Service RequestService *request.Service @@ -57,14 +59,17 @@ type ServiceError struct { } func NewServiceRegistry(ctx *AppContext, version, commit, date string, content embed.FS) *ServiceRegistry { + shutdownCtx, cancel := context.WithCancel(context.Background()) return &ServiceRegistry{ - Context: ctx, - version: version, - commit: commit, - date: date, - content: content, - readyChan: make(chan struct{}), - errorChan: make(chan ServiceError, 10), + Context: ctx, + version: version, + commit: commit, + date: date, + content: content, + readyChan: make(chan struct{}), + errorChan: make(chan ServiceError, 50), + shutdownCtx: shutdownCtx, + cancel: cancel, } } @@ -230,6 +235,14 @@ func (r *ServiceRegistry) ReadyChannel() <-chan struct{} { return r.readyChan } +func (r *ServiceRegistry) ShutdownContext() context.Context { + return r.shutdownCtx +} + +func (r *ServiceRegistry) Shutdown() { + r.cancel() +} + func (r *ServiceRegistry) ErrorChannel() <-chan ServiceError { return r.errorChan }