From 306cb7a20ef42f8eb5920174fa481c791f39d14b Mon Sep 17 00:00:00 2001 From: yusing Date: Sat, 1 Nov 2025 12:07:22 +0800 Subject: [PATCH] fix(access_logger): fix stdout and path not working at the same time --- internal/acl/config.go | 2 +- internal/entrypoint/entrypoint.go | 2 +- internal/logging/accesslog/access_logger.go | 56 ++-- .../logging/accesslog/access_logger_test.go | 2 +- .../logging/accesslog/back_scanner_test.go | 2 +- internal/logging/accesslog/config.go | 14 +- .../logging/accesslog/file_logger_test.go | 4 +- .../logging/accesslog/multi_access_logger.go | 63 +++++ .../accesslog/multi_access_logger_test.go | 261 ++++++++++++++++++ internal/logging/accesslog/rotate_test.go | 14 +- internal/route/fileserver.go | 2 +- 11 files changed, 383 insertions(+), 39 deletions(-) create mode 100644 internal/logging/accesslog/multi_access_logger.go create mode 100644 internal/logging/accesslog/multi_access_logger_test.go diff --git a/internal/acl/config.go b/internal/acl/config.go index 85ccc738..16ede527 100644 --- a/internal/acl/config.go +++ b/internal/acl/config.go @@ -55,7 +55,7 @@ type config struct { logAllowed bool // will be nil if Log is nil - logger *accesslog.AccessLogger + logger accesslog.AccessLogger // will never tick if Notify.To is empty notifyTicker *time.Ticker diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index c4fafa2d..ecaadd73 100644 --- a/internal/entrypoint/entrypoint.go +++ b/internal/entrypoint/entrypoint.go @@ -19,7 +19,7 @@ import ( type Entrypoint struct { middleware *middleware.Middleware notFoundHandler http.Handler - accessLogger *accesslog.AccessLogger + accessLogger accesslog.AccessLogger findRouteFunc func(host string) types.HTTPRoute } diff --git a/internal/logging/accesslog/access_logger.go b/internal/logging/accesslog/access_logger.go index 41227f4b..911e7a04 100644 --- a/internal/logging/accesslog/access_logger.go +++ b/internal/logging/accesslog/access_logger.go @@ -20,7 +20,18 @@ import ( ) type ( - AccessLogger struct { + AccessLogger interface { + Log(req *http.Request, res *http.Response) + LogError(req *http.Request, err error) + LogACL(info *maxmind.IPInfo, blocked bool) + + Config() *Config + + Flush() + Close() error + } + + accessLogger struct { task *task.Task cfg *Config @@ -52,6 +63,10 @@ type ( Name() string } + AccessLogRotater interface { + Rotate(result *RotateResult) (rotated bool, err error) + } + RequestFormatter interface { // AppendRequestLog appends a log line to line with or without a trailing newline AppendRequestLog(line []byte, req *http.Request, res *http.Response) []byte @@ -80,25 +95,26 @@ const ( var bytesPool = synk.GetUnsizedBytesPool() -func NewAccessLogger(parent task.Parent, cfg AnyConfig) (*AccessLogger, error) { - io, err := cfg.IO() +func NewAccessLogger(parent task.Parent, cfg AnyConfig) (AccessLogger, error) { + writers, err := cfg.Writers() if err != nil { return nil, err } - return NewAccessLoggerWithIO(parent, io, cfg), nil + + return NewMultiAccessLogger(parent, cfg, writers), nil } -func NewMockAccessLogger(parent task.Parent, cfg *RequestLoggerConfig) *AccessLogger { +func NewMockAccessLogger(parent task.Parent, cfg *RequestLoggerConfig) AccessLogger { return NewAccessLoggerWithIO(parent, NewMockFile(true), cfg) } -func NewAccessLoggerWithIO(parent task.Parent, writer Writer, anyCfg AnyConfig) *AccessLogger { +func NewAccessLoggerWithIO(parent task.Parent, writer Writer, anyCfg AnyConfig) AccessLogger { cfg := anyCfg.ToConfig() if cfg.RotateInterval == 0 { cfg.RotateInterval = defaultRotateInterval } - l := &AccessLogger{ + l := &accessLogger{ task: parent.Subtask("accesslog."+writer.Name(), true), cfg: cfg, bufSize: InitialBufferSize, @@ -138,11 +154,11 @@ func NewAccessLoggerWithIO(parent task.Parent, writer Writer, anyCfg AnyConfig) return l } -func (l *AccessLogger) Config() *Config { +func (l *accessLogger) Config() *Config { return l.cfg } -func (l *AccessLogger) shouldLog(req *http.Request, res *http.Response) bool { +func (l *accessLogger) shouldLog(req *http.Request, res *http.Response) bool { if !l.cfg.req.Filters.StatusCodes.CheckKeep(req, res) || !l.cfg.req.Filters.Method.CheckKeep(req, res) || !l.cfg.req.Filters.Headers.CheckKeep(req, res) || @@ -152,7 +168,7 @@ func (l *AccessLogger) shouldLog(req *http.Request, res *http.Response) bool { return true } -func (l *AccessLogger) Log(req *http.Request, res *http.Response) { +func (l *accessLogger) Log(req *http.Request, res *http.Response) { if !l.shouldLog(req, res) { return } @@ -166,11 +182,11 @@ func (l *AccessLogger) Log(req *http.Request, res *http.Response) { bytesPool.Put(line) } -func (l *AccessLogger) LogError(req *http.Request, err error) { +func (l *accessLogger) LogError(req *http.Request, err error) { l.Log(req, &http.Response{StatusCode: http.StatusInternalServerError, Status: err.Error()}) } -func (l *AccessLogger) LogACL(info *maxmind.IPInfo, blocked bool) { +func (l *accessLogger) LogACL(info *maxmind.IPInfo, blocked bool) { line := bytesPool.Get() line = l.AppendACLLog(line, info, blocked) if line[len(line)-1] != '\n' { @@ -180,11 +196,11 @@ func (l *AccessLogger) LogACL(info *maxmind.IPInfo, blocked bool) { bytesPool.Put(line) } -func (l *AccessLogger) ShouldRotate() bool { +func (l *accessLogger) ShouldRotate() bool { return l.supportRotate != nil && l.cfg.Retention.IsValid() } -func (l *AccessLogger) Rotate(result *RotateResult) (rotated bool, err error) { +func (l *accessLogger) Rotate(result *RotateResult) (rotated bool, err error) { if !l.ShouldRotate() { return false, nil } @@ -197,7 +213,7 @@ func (l *AccessLogger) Rotate(result *RotateResult) (rotated bool, err error) { return } -func (l *AccessLogger) handleErr(err error) { +func (l *accessLogger) handleErr(err error) { if l.errRateLimiter.Allow() { gperr.LogError("failed to write access log", err, &l.logger) } else { @@ -206,7 +222,7 @@ func (l *AccessLogger) handleErr(err error) { } } -func (l *AccessLogger) start() { +func (l *accessLogger) start() { defer func() { l.Flush() l.Close() @@ -242,7 +258,7 @@ func (l *AccessLogger) start() { } } -func (l *AccessLogger) Close() error { +func (l *accessLogger) Close() error { l.writeLock.Lock() defer l.writeLock.Unlock() if l.closed { @@ -253,7 +269,7 @@ func (l *AccessLogger) Close() error { return l.writer.Close() } -func (l *AccessLogger) Flush() { +func (l *accessLogger) Flush() { l.writeLock.Lock() defer l.writeLock.Unlock() if l.closed { @@ -262,7 +278,7 @@ func (l *AccessLogger) Flush() { l.writer.Flush() } -func (l *AccessLogger) write(data []byte) { +func (l *accessLogger) write(data []byte) { l.writeLock.Lock() defer l.writeLock.Unlock() if l.closed { @@ -277,7 +293,7 @@ func (l *AccessLogger) write(data []byte) { atomic.AddInt64(&l.writeCount, int64(n)) } -func (l *AccessLogger) adjustBuffer() { +func (l *accessLogger) adjustBuffer() { wps := int(atomic.SwapInt64(&l.writeCount, 0)) / int(bufferAdjustInterval.Seconds()) origBufSize := l.bufSize newBufSize := origBufSize diff --git a/internal/logging/accesslog/access_logger_test.go b/internal/logging/accesslog/access_logger_test.go index eca91336..a4c94a8c 100644 --- a/internal/logging/accesslog/access_logger_test.go +++ b/internal/logging/accesslog/access_logger_test.go @@ -58,7 +58,7 @@ func fmtLog(cfg *RequestLoggerConfig) (ts string, line string) { t := time.Now() logger := NewMockAccessLogger(testTask, cfg) utils.MockTimeNow(t) - buf = logger.AppendRequestLog(buf, req, resp) + buf = logger.(RequestFormatter).AppendRequestLog(buf, req, resp) return t.Format(LogTimeFormat), string(buf) } diff --git a/internal/logging/accesslog/back_scanner_test.go b/internal/logging/accesslog/back_scanner_test.go index bcb3e72f..fb106838 100644 --- a/internal/logging/accesslog/back_scanner_test.go +++ b/internal/logging/accesslog/back_scanner_test.go @@ -149,7 +149,7 @@ func logEntry() []byte { res := httptest.NewRecorder() // server the request srv.Config.Handler.ServeHTTP(res, req) - b := accesslog.AppendRequestLog(nil, req, res.Result()) + b := accesslog.(RequestFormatter).AppendRequestLog(nil, req, res.Result()) if b[len(b)-1] != '\n' { b = append(b, '\n') } diff --git a/internal/logging/accesslog/config.go b/internal/logging/accesslog/config.go index a73c3ae5..efcf0a60 100644 --- a/internal/logging/accesslog/config.go +++ b/internal/logging/accesslog/config.go @@ -32,7 +32,7 @@ type ( } AnyConfig interface { ToConfig() *Config - IO() (Writer, error) + Writers() ([]Writer, error) } Format string @@ -65,16 +65,20 @@ func (cfg *ConfigBase) Validate() gperr.Error { return nil } -// IO returns a writer for the config. -func (cfg *ConfigBase) IO() (Writer, error) { +// Writers returns a list of writers for the config. +func (cfg *ConfigBase) Writers() ([]Writer, error) { + writers := make([]Writer, 0, 2) if cfg.Path != "" { io, err := NewFileIO(cfg.Path) if err != nil { return nil, err } - return io, nil + writers = append(writers, io) } - return NewStdout(), nil + if cfg.Stdout { + writers = append(writers, NewStdout()) + } + return writers, nil } func (cfg *ACLLoggerConfig) ToConfig() *Config { diff --git a/internal/logging/accesslog/file_logger_test.go b/internal/logging/accesslog/file_logger_test.go index 5d00d21a..0b59f600 100644 --- a/internal/logging/accesslog/file_logger_test.go +++ b/internal/logging/accesslog/file_logger_test.go @@ -55,7 +55,7 @@ func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) { loggerCount := runtime.GOMAXPROCS(0) logCountPerLogger := 10 - loggers := make([]*AccessLogger, loggerCount) + loggers := make([]AccessLogger, loggerCount) for i := range loggerCount { loggers[i] = NewAccessLoggerWithIO(parent, file, cfg) @@ -83,7 +83,7 @@ func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) { } } -func concurrentLog(logger *AccessLogger, req *http.Request, resp *http.Response, n int) { +func concurrentLog(logger AccessLogger, req *http.Request, resp *http.Response, n int) { var wg sync.WaitGroup for range n { wg.Go(func() { diff --git a/internal/logging/accesslog/multi_access_logger.go b/internal/logging/accesslog/multi_access_logger.go new file mode 100644 index 00000000..2df6ed9e --- /dev/null +++ b/internal/logging/accesslog/multi_access_logger.go @@ -0,0 +1,63 @@ +package accesslog + +import ( + "net/http" + + maxmind "github.com/yusing/godoxy/internal/maxmind/types" + "github.com/yusing/goutils/task" +) + +type MultiAccessLogger struct { + accessLoggers []AccessLogger +} + +// NewMultiAccessLogger creates a new AccessLogger that writes to multiple writers. +// +// If there is only one writer, it will return a single AccessLogger. +// Otherwise, it will return a MultiAccessLogger that writes to all the writers. +func NewMultiAccessLogger(parent task.Parent, cfg AnyConfig, writers []Writer) AccessLogger { + if len(writers) == 1 { + return NewAccessLoggerWithIO(parent, writers[0], cfg) + } + + accessLoggers := make([]AccessLogger, len(writers)) + for i, writer := range writers { + accessLoggers[i] = NewAccessLoggerWithIO(parent, writer, cfg) + } + return &MultiAccessLogger{accessLoggers} +} + +func (m *MultiAccessLogger) Config() *Config { + return m.accessLoggers[0].Config() +} + +func (m *MultiAccessLogger) Log(req *http.Request, res *http.Response) { + for _, accessLogger := range m.accessLoggers { + accessLogger.Log(req, res) + } +} + +func (m *MultiAccessLogger) LogError(req *http.Request, err error) { + for _, accessLogger := range m.accessLoggers { + accessLogger.LogError(req, err) + } +} + +func (m *MultiAccessLogger) LogACL(info *maxmind.IPInfo, blocked bool) { + for _, accessLogger := range m.accessLoggers { + accessLogger.LogACL(info, blocked) + } +} + +func (m *MultiAccessLogger) Flush() { + for _, accessLogger := range m.accessLoggers { + accessLogger.Flush() + } +} + +func (m *MultiAccessLogger) Close() error { + for _, accessLogger := range m.accessLoggers { + accessLogger.Close() + } + return nil +} diff --git a/internal/logging/accesslog/multi_access_logger_test.go b/internal/logging/accesslog/multi_access_logger_test.go new file mode 100644 index 00000000..9a3c2308 --- /dev/null +++ b/internal/logging/accesslog/multi_access_logger_test.go @@ -0,0 +1,261 @@ +package accesslog + +import ( + "errors" + "net" + "net/http" + "net/url" + "testing" + + maxmind "github.com/yusing/godoxy/internal/maxmind/types" + "github.com/yusing/goutils/task" + expect "github.com/yusing/goutils/testing" +) + +func TestNewMultiAccessLogger(t *testing.T) { + testTask := task.RootTask("test", false) + cfg := DefaultRequestLoggerConfig() + + writers := []Writer{ + NewMockFile(true), + NewMockFile(true), + } + + logger := NewMultiAccessLogger(testTask, cfg, writers) + expect.NotNil(t, logger) +} + +func TestMultiAccessLoggerConfig(t *testing.T) { + testTask := task.RootTask("test", false) + cfg := DefaultRequestLoggerConfig() + cfg.Format = FormatCommon + + writers := []Writer{ + NewMockFile(true), + NewMockFile(true), + } + + logger := NewMultiAccessLogger(testTask, cfg, writers) + retrievedCfg := logger.Config() + + expect.Equal(t, retrievedCfg.req.Format, FormatCommon) +} + +func TestMultiAccessLoggerLog(t *testing.T) { + testTask := task.RootTask("test", false) + cfg := DefaultRequestLoggerConfig() + cfg.Format = FormatCommon + + writer1 := NewMockFile(true) + writer2 := NewMockFile(true) + writers := []Writer{writer1, writer2} + + logger := NewMultiAccessLogger(testTask, cfg, writers) + + testURL, _ := url.Parse("http://example.com/test") + req := &http.Request{ + RemoteAddr: "192.168.1.1", + Method: http.MethodGet, + Proto: "HTTP/1.1", + Host: "example.com", + URL: testURL, + Header: http.Header{ + "User-Agent": []string{"test-agent"}, + }, + } + resp := &http.Response{ + StatusCode: http.StatusOK, + ContentLength: 100, + } + + logger.Log(req, resp) + logger.Flush() + + expect.Equal(t, writer1.NumLines(), 1) + expect.Equal(t, writer2.NumLines(), 1) +} + +func TestMultiAccessLoggerLogError(t *testing.T) { + testTask := task.RootTask("test", false) + cfg := DefaultRequestLoggerConfig() + + writer1 := NewMockFile(true) + writer2 := NewMockFile(true) + writers := []Writer{writer1, writer2} + + logger := NewMultiAccessLogger(testTask, cfg, writers) + + testURL, _ := url.Parse("http://example.com/test") + req := &http.Request{ + RemoteAddr: "192.168.1.1", + Method: http.MethodGet, + URL: testURL, + } + testErr := errors.New("test error") + + logger.LogError(req, testErr) + logger.Flush() + + expect.Equal(t, writer1.NumLines(), 1) + expect.Equal(t, writer2.NumLines(), 1) +} + +func TestMultiAccessLoggerLogACL(t *testing.T) { + testTask := task.RootTask("test", false) + cfg := DefaultACLLoggerConfig() + cfg.LogAllowed = true + + writer1 := NewMockFile(true) + writer2 := NewMockFile(true) + writers := []Writer{writer1, writer2} + + logger := NewMultiAccessLogger(testTask, cfg, writers) + + info := &maxmind.IPInfo{ + IP: net.ParseIP("192.168.1.1"), + Str: "192.168.1.1", + } + + logger.LogACL(info, false) + logger.Flush() + + expect.Equal(t, writer1.NumLines(), 1) + expect.Equal(t, writer2.NumLines(), 1) +} + +func TestMultiAccessLoggerFlush(t *testing.T) { + testTask := task.RootTask("test", false) + cfg := DefaultRequestLoggerConfig() + + writer1 := NewMockFile(true) + writer2 := NewMockFile(true) + writers := []Writer{writer1, writer2} + + logger := NewMultiAccessLogger(testTask, cfg, writers) + + testURL, _ := url.Parse("http://example.com/test") + req := &http.Request{ + RemoteAddr: "192.168.1.1", + Method: http.MethodGet, + URL: testURL, + } + resp := &http.Response{ + StatusCode: http.StatusOK, + } + + logger.Log(req, resp) + logger.Flush() + + expect.Equal(t, writer1.NumLines(), 1) + expect.Equal(t, writer2.NumLines(), 1) +} + +func TestMultiAccessLoggerClose(t *testing.T) { + testTask := task.RootTask("test", false) + cfg := DefaultRequestLoggerConfig() + + writer1 := NewMockFile(true) + writer2 := NewMockFile(true) + writers := []Writer{writer1, writer2} + + logger := NewMultiAccessLogger(testTask, cfg, writers) + + err := logger.Close() + expect.Nil(t, err) +} + +func TestMultiAccessLoggerMultipleLogs(t *testing.T) { + testTask := task.RootTask("test", false) + cfg := DefaultRequestLoggerConfig() + + writer1 := NewMockFile(true) + writer2 := NewMockFile(true) + writers := []Writer{writer1, writer2} + + logger := NewMultiAccessLogger(testTask, cfg, writers) + + testURL, _ := url.Parse("http://example.com/test") + + for range 3 { + req := &http.Request{ + RemoteAddr: "192.168.1.1", + Method: http.MethodGet, + URL: testURL, + } + resp := &http.Response{ + StatusCode: http.StatusOK, + } + logger.Log(req, resp) + } + + logger.Flush() + + expect.Equal(t, writer1.NumLines(), 3) + expect.Equal(t, writer2.NumLines(), 3) +} + +func TestMultiAccessLoggerSingleWriter(t *testing.T) { + testTask := task.RootTask("test", false) + cfg := DefaultRequestLoggerConfig() + + writer := NewMockFile(true) + writers := []Writer{writer} + + logger := NewMultiAccessLogger(testTask, cfg, writers) + expect.NotNil(t, logger) + + testURL, _ := url.Parse("http://example.com/test") + req := &http.Request{ + RemoteAddr: "192.168.1.1", + Method: http.MethodGet, + URL: testURL, + } + resp := &http.Response{ + StatusCode: http.StatusOK, + } + + logger.Log(req, resp) + logger.Flush() + + expect.Equal(t, writer.NumLines(), 1) +} + +func TestMultiAccessLoggerMixedOperations(t *testing.T) { + testTask := task.RootTask("test", false) + cfg := DefaultRequestLoggerConfig() + + writer1 := NewMockFile(true) + writer2 := NewMockFile(true) + writers := []Writer{writer1, writer2} + + logger := NewMultiAccessLogger(testTask, cfg, writers) + + testURL, _ := url.Parse("http://example.com/test") + + req := &http.Request{ + RemoteAddr: "192.168.1.1", + Method: http.MethodGet, + URL: testURL, + } + resp := &http.Response{ + StatusCode: http.StatusOK, + } + + logger.Log(req, resp) + logger.Flush() + + info := &maxmind.IPInfo{ + IP: net.ParseIP("192.168.1.1"), + Str: "192.168.1.1", + } + + cfg2 := DefaultACLLoggerConfig() + cfg2.LogAllowed = true + aclLogger := NewMultiAccessLogger(testTask, cfg2, writers) + aclLogger.LogACL(info, false) + + logger.Flush() + + expect.Equal(t, writer1.NumLines(), 1) + expect.Equal(t, writer2.NumLines(), 1) +} diff --git a/internal/logging/accesslog/rotate_test.go b/internal/logging/accesslog/rotate_test.go index c370b8e3..3a56eec1 100644 --- a/internal/logging/accesslog/rotate_test.go +++ b/internal/logging/accesslog/rotate_test.go @@ -77,7 +77,7 @@ func TestRotateKeepLast(t *testing.T) { logger.Config().Retention = retention var result RotateResult - rotated, err := logger.Rotate(&result) + rotated, err := logger.(AccessLogRotater).Rotate(&result) expect.NoError(t, err) expect.Equal(t, rotated, true) expect.Equal(t, file.NumLines(), int(retention.Last)) @@ -107,7 +107,7 @@ func TestRotateKeepLast(t *testing.T) { utils.MockTimeNow(testTime) var result RotateResult - rotated, err := logger.Rotate(&result) + rotated, err := logger.(AccessLogRotater).Rotate(&result) expect.NoError(t, err) expect.Equal(t, rotated, true) expect.Equal(t, file.NumLines(), int(retention.Days)) @@ -153,7 +153,7 @@ func TestRotateKeepFileSize(t *testing.T) { utils.MockTimeNow(testTime) var result RotateResult - rotated, err := logger.Rotate(&result) + rotated, err := logger.(AccessLogRotater).Rotate(&result) expect.NoError(t, err) // file should be untouched as 100KB > 10 lines * bytes per line @@ -185,7 +185,7 @@ func TestRotateKeepFileSize(t *testing.T) { utils.MockTimeNow(testTime) var result RotateResult - rotated, err := logger.Rotate(&result) + rotated, err := logger.(AccessLogRotater).Rotate(&result) expect.NoError(t, err) expect.Equal(t, rotated, true) expect.Equal(t, result.NumBytesKeep, int64(retention.KeepSize)) @@ -221,7 +221,7 @@ func TestRotateSkipInvalidTime(t *testing.T) { logger.Config().Retention = retention var result RotateResult - rotated, err := logger.Rotate(&result) + rotated, err := logger.(AccessLogRotater).Rotate(&result) expect.NoError(t, err) expect.Equal(t, rotated, true) // should read one invalid line after every valid line @@ -260,7 +260,7 @@ func BenchmarkRotate(b *testing.B) { _, _ = file.Write(content) b.StartTimer() var result RotateResult - _, _ = logger.Rotate(&result) + _, _ = logger.(AccessLogRotater).Rotate(&result) } }) } @@ -297,7 +297,7 @@ func BenchmarkRotateWithInvalidTime(b *testing.B) { _, _ = file.Write(content) b.StartTimer() var result RotateResult - _, _ = logger.Rotate(&result) + _, _ = logger.(AccessLogRotater).Rotate(&result) } }) } diff --git a/internal/route/fileserver.go b/internal/route/fileserver.go index 51da6cec..3b020474 100644 --- a/internal/route/fileserver.go +++ b/internal/route/fileserver.go @@ -20,7 +20,7 @@ type ( middleware *middleware.Middleware handler http.Handler - accessLogger *accesslog.AccessLogger + accessLogger accesslog.AccessLogger } )