mirror of
https://github.com/MizuchiLabs/mantrae.git
synced 2025-12-16 20:05:17 -06:00
91 lines
2.3 KiB
Go
91 lines
2.3 KiB
Go
package middlewares
|
|
|
|
import (
|
|
"context"
|
|
"log/slog"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"connectrpc.com/connect"
|
|
)
|
|
|
|
// statusRecorder wraps an http.ResponseWriter and remembers the status code
// of the response so that it can be reported once the handler has returned.
//
// statusCode holds the recorded status; creators normally seed it with
// http.StatusOK so that handlers which never call WriteHeader are reported
// correctly. wroteHeader is set as soon as a final status has been committed,
// either explicitly through WriteHeader or implicitly through Write.
type statusRecorder struct {
	http.ResponseWriter
	statusCode  int
	wroteHeader bool
}

// WriteHeader forwards code to the wrapped writer and records it when it is
// the first final status of the response.
//
// Only the first final status is kept: later calls are still forwarded
// unchanged, but they no longer overwrite the recorded value, so the recorder
// reports the first status that was committed rather than the last one that
// was attempted.
func (rec *statusRecorder) WriteHeader(code int) {
	// 1xx codes other than 101 are informational and may precede the final
	// status, so they must not be taken for the response status.
	informational := code >= 100 && code < 200 && code != http.StatusSwitchingProtocols
	if !rec.wroteHeader && !informational {
		rec.statusCode = code
		rec.wroteHeader = true
	}
	rec.ResponseWriter.WriteHeader(code)
}

// Write forwards b to the wrapped writer. A Write without a preceding
// WriteHeader implies http.StatusOK, so the status is marked as committed
// here and a later WriteHeader call cannot change the recorded value.
func (rec *statusRecorder) Write(b []byte) (int, error) {
	if !rec.wroteHeader {
		if rec.statusCode == 0 {
			rec.statusCode = http.StatusOK
		}
		rec.wroteHeader = true
	}
	return rec.ResponseWriter.Write(b)
}

// Flush implements http.Flusher by forwarding to the wrapped writer. It is a
// no-op when the wrapped writer does not support flushing.
func (rec *statusRecorder) Flush() {
	if flusher, ok := rec.ResponseWriter.(http.Flusher); ok {
		flusher.Flush()
	}
}

// Unwrap returns the wrapped writer so that http.ResponseController can reach
// optional capabilities that this wrapper does not re-export itself.
func (rec *statusRecorder) Unwrap() http.ResponseWriter {
	return rec.ResponseWriter
}
|
|
|
|
// Logger middleware to log HTTP requests
|
|
func (h *MiddlewareHandler) Logger(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if strings.HasPrefix(r.URL.Path, "/_app/") || r.URL.Path == "/favicon.ico" {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
start := time.Now()
|
|
rec := &statusRecorder{ResponseWriter: w, statusCode: http.StatusOK}
|
|
next.ServeHTTP(rec, r)
|
|
duration := time.Since(start)
|
|
|
|
level := slog.LevelDebug
|
|
switch {
|
|
case rec.statusCode >= 500:
|
|
level = slog.LevelError
|
|
case rec.statusCode >= 400:
|
|
level = slog.LevelWarn
|
|
}
|
|
|
|
slog.Log(r.Context(), level, "http_request",
|
|
slog.String("method", r.Method),
|
|
slog.String("url", r.URL.Path),
|
|
slog.Int("status", rec.statusCode),
|
|
slog.String("protocol", r.Proto),
|
|
slog.Int64("duration_ms", duration.Milliseconds()),
|
|
)
|
|
})
|
|
}
|
|
|
|
func Logging() connect.UnaryInterceptorFunc {
|
|
return func(next connect.UnaryFunc) connect.UnaryFunc {
|
|
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
|
|
if req.Spec().Procedure == "HealthCheck" {
|
|
return next(ctx, req)
|
|
}
|
|
start := time.Now()
|
|
resp, err := next(ctx, req)
|
|
duration := time.Since(start)
|
|
|
|
logger := slog.With(
|
|
slog.String("method", req.Spec().Procedure),
|
|
slog.String("peer", req.Peer().Addr),
|
|
slog.String("protocol", req.Peer().Protocol),
|
|
slog.Int64("duration_ms", duration.Milliseconds()),
|
|
)
|
|
|
|
if err != nil {
|
|
logger.With(slog.String("error", err.Error())).Error("rpc_call")
|
|
} else {
|
|
logger.Debug("rpc_call")
|
|
}
|
|
|
|
return resp, err
|
|
}
|
|
}
|
|
}
|