mirror of
https://github.com/pommee/goaway.git
synced 2026-05-06 08:40:41 -05:00
fix: multiple fixes related to graceful shutdown, routine leaks, error channel overflow and more
This commit is contained in:
+7
-2
@@ -9,6 +9,7 @@ import (
|
||||
"goaway/backend/user"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -98,7 +99,9 @@ func (api *API) updatePassword(c *gin.Context) {
|
||||
Message: logMsg,
|
||||
})
|
||||
go func() {
|
||||
_ = api.DNSServer.AlertService.SendToAll(context.Background(), alert.Message{
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = api.DNSServer.AlertService.SendToAll(ctx, alert.Message{
|
||||
Title: "System",
|
||||
Content: logMsg,
|
||||
Severity: SeverityWarning,
|
||||
@@ -135,7 +138,9 @@ func (api *API) createAPIKey(c *gin.Context) {
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = api.DNSServer.AlertService.SendToAll(context.Background(), alert.Message{
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = api.DNSServer.AlertService.SendToAll(ctx, alert.Message{
|
||||
Title: "System",
|
||||
Content: fmt.Sprintf("New API key created with the name '%s'", request.Name),
|
||||
Severity: SeverityWarning,
|
||||
|
||||
@@ -747,46 +747,52 @@ func (s *Service) Vacuum(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) ScheduleAutomaticListUpdates() {
|
||||
func (s *Service) ScheduleAutomaticListUpdates(ctx context.Context) {
|
||||
ticker := time.NewTicker(s.config.UpdateInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
ctx := context.Background()
|
||||
log.Info("Starting automatic list updates...")
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debug("Stopping automatic list updates")
|
||||
return
|
||||
case <-ticker.C:
|
||||
bgCtx := context.Background()
|
||||
log.Info("Starting automatic list updates...")
|
||||
|
||||
for _, source := range s.blocklistURL {
|
||||
if source.Name == "Custom" {
|
||||
continue
|
||||
}
|
||||
log.Info("Checking for updates for blocklist %s from %s", source.Name, source.URL)
|
||||
for _, source := range s.blocklistURL {
|
||||
if source.Name == "Custom" {
|
||||
continue
|
||||
}
|
||||
log.Info("Checking for updates for blocklist %s from %s", source.Name, source.URL)
|
||||
|
||||
availableUpdate, err := s.CheckIfUpdateAvailable(ctx, source.URL, source.Name)
|
||||
if err != nil {
|
||||
log.Warning("Failed to check for updates for %s: %v", source.Name, err)
|
||||
continue
|
||||
availableUpdate, err := s.CheckIfUpdateAvailable(bgCtx, source.URL, source.Name)
|
||||
if err != nil {
|
||||
log.Warning("Failed to check for updates for %s: %v", source.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !availableUpdate.UpdateAvailable {
|
||||
log.Info("No updates available for %s", source.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.RemoveSourceAndDomains(bgCtx, source.Name, source.URL); err != nil {
|
||||
log.Warning("Failed to remove old domains for %s: %v", source.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.FetchAndLoadHosts(bgCtx, source.URL, source.Name); err != nil {
|
||||
log.Warning("Failed to fetch and load hosts for %s: %v", source.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Info("Successfully updated %s with %d new domains", source.Name, len(availableUpdate.DiffAdded))
|
||||
}
|
||||
|
||||
if !availableUpdate.UpdateAvailable {
|
||||
log.Info("No updates available for %s", source.Name)
|
||||
continue
|
||||
if err := s.PopulateCache(bgCtx); err != nil {
|
||||
log.Warning("Failed to populate blocklist cache after auto-update: %v", err)
|
||||
}
|
||||
|
||||
if err := s.RemoveSourceAndDomains(ctx, source.Name, source.URL); err != nil {
|
||||
log.Warning("Failed to remove old domains for %s: %v", source.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.FetchAndLoadHosts(ctx, source.URL, source.Name); err != nil {
|
||||
log.Warning("Failed to fetch and load hosts for %s: %v", source.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Info("Successfully updated %s with %d new domains", source.Name, len(availableUpdate.DiffAdded))
|
||||
}
|
||||
|
||||
if err := s.PopulateCache(ctx); err != nil {
|
||||
log.Warning("Failed to populate blocklist cache after auto-update: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
model "goaway/backend/dns/server/models"
|
||||
"time"
|
||||
@@ -10,13 +11,18 @@ import (
|
||||
|
||||
const batchSize = 1000
|
||||
|
||||
func (s *DNSServer) ProcessLogEntries() {
|
||||
func (s *DNSServer) ProcessLogEntries(ctx context.Context) {
|
||||
var batch []model.RequestLogEntry
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if len(batch) > 0 {
|
||||
s.saveBatch(batch)
|
||||
}
|
||||
return
|
||||
case entry := <-s.logEntryChannel:
|
||||
log.Debug("%s", entry.String())
|
||||
if s.WSQueries != nil {
|
||||
@@ -46,18 +52,24 @@ func (s *DNSServer) saveBatch(entries []model.RequestLogEntry) {
|
||||
}
|
||||
|
||||
// Removes old log entries based on the configured retention period.
|
||||
func (s *DNSServer) ClearOldEntries() {
|
||||
func (s *DNSServer) ClearOldEntries(ctx context.Context) {
|
||||
const (
|
||||
maxRetries = 10
|
||||
retryDelay = 150 * time.Millisecond
|
||||
cleanupInterval = 5 * time.Minute
|
||||
)
|
||||
|
||||
for {
|
||||
requestThreshold := ((60 * 60) * 24) * s.Config.Misc.StatisticsRetention
|
||||
log.Debug("Next cleanup running at %s", time.Now().Add(cleanupInterval).Format(time.DateTime))
|
||||
time.Sleep(cleanupInterval)
|
||||
ticker := time.NewTicker(cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
s.RequestService.DeleteRequestLogsTimebased(s.BlacklistService.Vacuum, requestThreshold, maxRetries, retryDelay)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
requestThreshold := ((60 * 60) * 24) * s.Config.Misc.StatisticsRetention
|
||||
log.Debug("Next cleanup running at %s", time.Now().Add(cleanupInterval).Format(time.DateTime))
|
||||
s.RequestService.DeleteRequestLogsTimebased(s.BlacklistService.Vacuum, requestThreshold, maxRetries, retryDelay)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package jobs
|
||||
|
||||
import (
|
||||
"context"
|
||||
arp "goaway/backend/dns"
|
||||
"goaway/backend/logging"
|
||||
"goaway/backend/services"
|
||||
@@ -10,11 +11,13 @@ var log = logging.GetLogger()
|
||||
|
||||
type BackgroundJobs struct {
|
||||
registry *services.ServiceRegistry
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewBackgroundJobs(registry *services.ServiceRegistry) *BackgroundJobs {
|
||||
return &BackgroundJobs{
|
||||
registry: registry,
|
||||
ctx: registry.ShutdownContext(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,7 +57,7 @@ func (b *BackgroundJobs) startScheduledUpdates(readyChan <-chan struct{}) {
|
||||
<-readyChan
|
||||
if b.registry.Context.Config.Misc.ScheduledBlacklistUpdates {
|
||||
log.Debug("Starting scheduler for automatic list updates...")
|
||||
b.registry.BlacklistService.ScheduleAutomaticListUpdates()
|
||||
b.registry.BlacklistService.ScheduleAutomaticListUpdates(b.ctx)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -63,7 +66,7 @@ func (b *BackgroundJobs) startCacheCleanup(readyChan <-chan struct{}) {
|
||||
go func() {
|
||||
<-readyChan
|
||||
log.Debug("Starting cache cleanup routine...")
|
||||
b.registry.Context.DNSServer.ClearOldEntries()
|
||||
b.registry.Context.DNSServer.ClearOldEntries(b.ctx)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -71,6 +74,6 @@ func (b *BackgroundJobs) startPrefetcher(readyChan <-chan struct{}) {
|
||||
go func() {
|
||||
<-readyChan
|
||||
log.Debug("Starting prefetcher...")
|
||||
b.registry.PrefetchService.Run()
|
||||
b.registry.PrefetchService.Run(b.ctx)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"goaway/backend/api"
|
||||
"goaway/backend/jobs"
|
||||
"goaway/backend/logging"
|
||||
@@ -8,6 +9,7 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
var log = logging.GetLogger()
|
||||
@@ -63,6 +65,49 @@ func (m *Manager) waitForTermination() error {
|
||||
}
|
||||
|
||||
func (m *Manager) shutdown() {
|
||||
// TODO: Add graceful shutdown logic
|
||||
log.Info("Starting graceful shutdown...")
|
||||
|
||||
m.services.Shutdown()
|
||||
m.services.APIServer.IsShuttingDown = true
|
||||
|
||||
if err := m.services.APIServer.Stop(); err != nil {
|
||||
log.Error("Error stopping API server: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
m.services.UDPServer.Shutdown(ctx)
|
||||
log.Info("Stopped UDP server")
|
||||
|
||||
m.services.TCPServer.Shutdown(ctx)
|
||||
log.Info("Stopped TCP server")
|
||||
|
||||
if m.services.DoTServer != nil {
|
||||
m.services.DoTServer.Shutdown(ctx)
|
||||
log.Info("Stopped DNS-over-TLS server")
|
||||
}
|
||||
|
||||
if m.services.DoHServer != nil {
|
||||
if err := m.services.DoHServer.Shutdown(ctx); err != nil && err != context.DeadlineExceeded {
|
||||
log.Error("Error stopping DoH server: %v", err)
|
||||
}
|
||||
log.Info("Stopped DNS-over-HTTPS server")
|
||||
}
|
||||
|
||||
// Wait for all goroutines to finish with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
m.services.WaitGroup().Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
log.Info("All services stopped gracefully")
|
||||
case <-time.After(15 * time.Second):
|
||||
log.Warning("Shutdown timeout exceeded, forcing exit")
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package prefetch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"goaway/backend/database"
|
||||
"goaway/backend/dns/server"
|
||||
@@ -30,13 +31,19 @@ func NewService(repo Repository, dnsServer *server.DNSServer) *Service {
|
||||
return service
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
func (s *Service) Run(ctx context.Context) {
|
||||
ticker := time.NewTicker(500 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
s.checkNewDomains()
|
||||
s.processExpiredEntries()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debug("Stopping prefetch service")
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.checkNewDomains()
|
||||
s.processExpiredEntries()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"goaway/backend/database"
|
||||
@@ -45,7 +46,7 @@ func (ctx *AppContext) initialize() error {
|
||||
}
|
||||
ctx.DNSServer = dnsServer
|
||||
|
||||
go dnsServer.ProcessLogEntries()
|
||||
go dnsServer.ProcessLogEntries(context.Background())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -36,10 +36,12 @@ type ServiceRegistry struct {
|
||||
|
||||
Context *AppContext
|
||||
|
||||
version string
|
||||
date string
|
||||
commit string
|
||||
wg sync.WaitGroup
|
||||
version string
|
||||
date string
|
||||
commit string
|
||||
wg sync.WaitGroup
|
||||
shutdownCtx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
ResolutionService *resolution.Service
|
||||
RequestService *request.Service
|
||||
@@ -57,14 +59,17 @@ type ServiceError struct {
|
||||
}
|
||||
|
||||
func NewServiceRegistry(ctx *AppContext, version, commit, date string, content embed.FS) *ServiceRegistry {
|
||||
shutdownCtx, cancel := context.WithCancel(context.Background())
|
||||
return &ServiceRegistry{
|
||||
Context: ctx,
|
||||
version: version,
|
||||
commit: commit,
|
||||
date: date,
|
||||
content: content,
|
||||
readyChan: make(chan struct{}),
|
||||
errorChan: make(chan ServiceError, 10),
|
||||
Context: ctx,
|
||||
version: version,
|
||||
commit: commit,
|
||||
date: date,
|
||||
content: content,
|
||||
readyChan: make(chan struct{}),
|
||||
errorChan: make(chan ServiceError, 50),
|
||||
shutdownCtx: shutdownCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -230,6 +235,14 @@ func (r *ServiceRegistry) ReadyChannel() <-chan struct{} {
|
||||
return r.readyChan
|
||||
}
|
||||
|
||||
func (r *ServiceRegistry) ShutdownContext() context.Context {
|
||||
return r.shutdownCtx
|
||||
}
|
||||
|
||||
func (r *ServiceRegistry) Shutdown() {
|
||||
r.cancel()
|
||||
}
|
||||
|
||||
func (r *ServiceRegistry) ErrorChannel() <-chan ServiceError {
|
||||
return r.errorChan
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user