From eecb2565bc6d8ac05b2295ebed711c6787a907ae Mon Sep 17 00:00:00 2001 From: Benjamin Date: Sun, 23 Nov 2025 01:00:05 +0100 Subject: [PATCH] refactor(checksum): propagate context for HTTP request cancellation Add context.Context parameter to checksum computation functions to enable request cancellation, timeout propagation, and better observability for remote document downloads. --- .../application/services/document_service.go | 8 ++++---- .../application/services/signature.go | 6 +++--- backend/pkg/checksum/remote_checksum.go | 20 ++++++++++++------- backend/pkg/checksum/remote_checksum_test.go | 19 +++++++++--------- 4 files changed, 30 insertions(+), 23 deletions(-) diff --git a/backend/internal/application/services/document_service.go b/backend/internal/application/services/document_service.go index 411083c..05791bb 100644 --- a/backend/internal/application/services/document_service.go +++ b/backend/internal/application/services/document_service.go @@ -89,7 +89,7 @@ func (s *DocumentService) CreateDocument(ctx context.Context, req CreateDocument // Automatically compute checksum for remote URLs if enabled if url != "" && s.checksumConfig != nil { - checksumResult := s.computeChecksumForURL(url) + checksumResult := s.computeChecksumForURL(ctx, url) if checksumResult != nil { input.Checksum = checksumResult.ChecksumHex input.ChecksumAlgorithm = checksumResult.Algorithm @@ -313,7 +313,7 @@ func isBase64Like(s string) bool { // computeChecksumForURL attempts to compute the checksum for a remote URL // Returns nil if the checksum cannot be computed (error, too large, etc.) -func (s *DocumentService) computeChecksumForURL(url string) *checksum.Result { +func (s *DocumentService) computeChecksumForURL(ctx context.Context, url string) *checksum.Result { if s.checksumConfig == nil { return nil } @@ -327,7 +327,7 @@ func (s *DocumentService) computeChecksumForURL(url string) *checksum.Result { InsecureSkipVerify: s.checksumConfig.InsecureSkipVerify, } - result, err := checksum.ComputeRemoteChecksum(url, opts) + result, err := checksum.ComputeRemoteChecksum(ctx, url, opts) if err != nil { logger.Logger.Warn("Failed to compute checksum for URL", "url", url, @@ -426,7 +426,7 @@ func (s *DocumentService) FindOrCreateDocument(ctx context.Context, ref string) // For URL references, compute checksum before creating if refType == ReferenceTypeURL && s.checksumConfig != nil { logger.Logger.Debug("Computing checksum for URL reference", "url", ref) - checksumResult := s.computeChecksumForURL(ref) + checksumResult := s.computeChecksumForURL(ctx, ref) if checksumResult != nil { logger.Logger.Info("Automatically computed checksum for URL reference", "url", ref, diff --git a/backend/internal/application/services/signature.go b/backend/internal/application/services/signature.go index c9855ad..05ff463 100644 --- a/backend/internal/application/services/signature.go +++ b/backend/internal/application/services/signature.go @@ -111,7 +111,7 @@ func (s *SignatureService) CreateSignature(ctx context.Context, request *models. // Continue without checksum - document metadata is optional } else if doc != nil && doc.Checksum != "" { // Verify document hasn't been modified before signing - if err := s.verifyDocumentIntegrity(doc); err != nil { + if err := s.verifyDocumentIntegrity(ctx, doc); err != nil { logger.Logger.Warn("Document integrity check failed", "doc_id", request.DocID, "error", err.Error()) @@ -375,7 +375,7 @@ func (s *SignatureService) RebuildChain(ctx context.Context) error { } // verifyDocumentIntegrity checks if the document at the URL hasn't been modified since the checksum was stored -func (s *SignatureService) verifyDocumentIntegrity(doc *models.Document) error { +func (s *SignatureService) verifyDocumentIntegrity(ctx context.Context, doc *models.Document) error { // Only verify if document has URL and checksum, and checksum config is available if doc.URL == "" || doc.Checksum == "" || s.checksumConfig == nil { logger.Logger.Debug("Skipping document integrity check", @@ -406,7 +406,7 @@ func (s *SignatureService) verifyDocumentIntegrity(doc *models.Document) error { } // Compute current checksum - result, err := checksum.ComputeRemoteChecksum(doc.URL, opts) + result, err := checksum.ComputeRemoteChecksum(ctx, doc.URL, opts) if err != nil { logger.Logger.Error("Failed to compute checksum for integrity check", "doc_id", doc.DocID, diff --git a/backend/pkg/checksum/remote_checksum.go b/backend/pkg/checksum/remote_checksum.go index 3635f39..5417590 100644 --- a/backend/pkg/checksum/remote_checksum.go +++ b/backend/pkg/checksum/remote_checksum.go @@ -2,6 +2,7 @@ package checksum import ( + "context" "crypto/sha256" "crypto/tls" "encoding/hex" @@ -58,7 +59,12 @@ func DefaultOptions() ComputeOptions { // ComputeRemoteChecksum downloads a remote binary file and computes its SHA-256 checksum // Returns nil if the file cannot be processed (too large, wrong type, network error, SSRF blocked) -func ComputeRemoteChecksum(urlStr string, opts ComputeOptions) (*Result, error) { +// The context is used for request cancellation and timeout propagation. +func ComputeRemoteChecksum(ctx context.Context, urlStr string, opts ComputeOptions) (*Result, error) { + // Check if context is already cancelled + if err := ctx.Err(); err != nil { + return nil, fmt.Errorf("context cancelled before checksum computation: %w", err) + } // Validate URL scheme (only HTTPS allowed) if !isValidURL(urlStr) { logger.Logger.Info("Checksum: URL rejected - not HTTPS", "url", urlStr) @@ -101,7 +107,7 @@ func ComputeRemoteChecksum(urlStr string, opts ComputeOptions) (*Result, error) } // Step 1: HEAD request to check Content-Type and Content-Length - headReq, err := http.NewRequest("HEAD", urlStr, nil) + headReq, err := http.NewRequestWithContext(ctx, "HEAD", urlStr, nil) if err != nil { logger.Logger.Warn("Checksum: Failed to create HEAD request", "url", urlStr, "error", err.Error()) return nil, nil @@ -112,7 +118,7 @@ func ComputeRemoteChecksum(urlStr string, opts ComputeOptions) (*Result, error) if err != nil { logger.Logger.Info("Checksum: HEAD request failed", "url", urlStr, "error", err.Error()) // Fallback: try GET with streaming if HEAD not supported - return computeWithStreamedGET(client, urlStr, opts) + return computeWithStreamedGET(ctx, client, urlStr, opts) } defer headResp.Body.Close() @@ -133,11 +139,11 @@ func ComputeRemoteChecksum(urlStr string, opts ComputeOptions) (*Result, error) // If Content-Length is unknown (0 or -1), fallback to streamed GET if contentLength <= 0 { logger.Logger.Debug("Checksum: Content-Length unknown, using streamed GET", "url", urlStr) - return computeWithStreamedGET(client, urlStr, opts) + return computeWithStreamedGET(ctx, client, urlStr, opts) } // Step 2: GET request to download and compute checksum - getReq, err := http.NewRequest("GET", urlStr, nil) + getReq, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil) if err != nil { logger.Logger.Warn("Checksum: Failed to create GET request", "url", urlStr, "error", err.Error()) return nil, nil @@ -161,8 +167,8 @@ func ComputeRemoteChecksum(urlStr string, opts ComputeOptions) (*Result, error) } // computeWithStreamedGET performs a GET request and computes checksum with hard size limit -func computeWithStreamedGET(client *http.Client, urlStr string, opts ComputeOptions) (*Result, error) { - getReq, err := http.NewRequest("GET", urlStr, nil) +func computeWithStreamedGET(ctx context.Context, client *http.Client, urlStr string, opts ComputeOptions) (*Result, error) { + getReq, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil) if err != nil { logger.Logger.Warn("Checksum: Failed to create GET request (fallback)", "url", urlStr, "error", err.Error()) return nil, nil diff --git a/backend/pkg/checksum/remote_checksum_test.go b/backend/pkg/checksum/remote_checksum_test.go index 05edea2..6ead43a 100644 --- a/backend/pkg/checksum/remote_checksum_test.go +++ b/backend/pkg/checksum/remote_checksum_test.go @@ -2,6 +2,7 @@ package checksum import ( + "context" "fmt" "net" "net/http" @@ -27,7 +28,7 @@ func TestComputeRemoteChecksum_Success(t *testing.T) { opts := DefaultOptions() opts.SkipSSRFCheck = true // For testing with httptest opts.InsecureSkipVerify = true // Accept self-signed certs - result, err := ComputeRemoteChecksum(server.URL, opts) + result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts) if err != nil { t.Fatalf("Expected no error, got %v", err) @@ -56,7 +57,7 @@ func TestComputeRemoteChecksum_TooLarge(t *testing.T) { opts := DefaultOptions() opts.SkipSSRFCheck = true opts.InsecureSkipVerify = true - result, err := ComputeRemoteChecksum(server.URL, opts) + result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts) if err != nil { t.Fatalf("Expected no error, got %v", err) @@ -80,7 +81,7 @@ func TestComputeRemoteChecksum_WrongContentType(t *testing.T) { opts := DefaultOptions() opts.SkipSSRFCheck = true opts.InsecureSkipVerify = true - result, err := ComputeRemoteChecksum(server.URL, opts) + result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts) if err != nil { t.Fatalf("Expected no error, got %v", err) @@ -94,7 +95,7 @@ func TestComputeRemoteChecksum_WrongContentType(t *testing.T) { func TestComputeRemoteChecksum_HTTPNotHTTPS(t *testing.T) { // Test HTTP (not HTTPS) - should be rejected opts := DefaultOptions() - result, err := ComputeRemoteChecksum("http://example.com/file.pdf", opts) + result, err := ComputeRemoteChecksum(context.Background(), "http://example.com/file.pdf", opts) if err != nil { t.Fatalf("Expected no error, got %v", err) @@ -123,7 +124,7 @@ func TestComputeRemoteChecksum_StreamedGETFallback(t *testing.T) { opts := DefaultOptions() opts.SkipSSRFCheck = true opts.InsecureSkipVerify = true - result, err := ComputeRemoteChecksum(server.URL, opts) + result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts) if err != nil { t.Fatalf("Expected no error, got %v", err) @@ -155,7 +156,7 @@ func TestComputeRemoteChecksum_ExceedsSizeDuringStreaming(t *testing.T) { opts := DefaultOptions() opts.SkipSSRFCheck = true opts.InsecureSkipVerify = true - result, err := ComputeRemoteChecksum(server.URL, opts) + result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts) if err != nil { t.Fatalf("Expected no error, got %v", err) @@ -175,7 +176,7 @@ func TestComputeRemoteChecksum_HTTPError(t *testing.T) { opts := DefaultOptions() opts.SkipSSRFCheck = true opts.InsecureSkipVerify = true - result, err := ComputeRemoteChecksum(server.URL, opts) + result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts) if err != nil { t.Fatalf("Expected no error, got %v", err) @@ -197,7 +198,7 @@ func TestComputeRemoteChecksum_TooManyRedirects(t *testing.T) { opts := DefaultOptions() opts.SkipSSRFCheck = true opts.InsecureSkipVerify = true - result, err := ComputeRemoteChecksum(server.URL, opts) + result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts) if err != nil { t.Fatalf("Expected no error, got %v", err) @@ -298,7 +299,7 @@ func TestComputeRemoteChecksum_ImageContentType(t *testing.T) { opts := DefaultOptions() opts.SkipSSRFCheck = true opts.InsecureSkipVerify = true - result, err := ComputeRemoteChecksum(server.URL, opts) + result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts) if err != nil { t.Fatalf("Expected no error, got %v", err)