refactor(checksum): propagate context for HTTP request cancellation

Add context.Context parameter to checksum computation functions
to enable request cancellation, timeout propagation, and better
observability for remote document downloads.
This commit is contained in:
Benjamin
2025-11-23 01:00:05 +01:00
parent ddb44df7d0
commit eecb2565bc
4 changed files with 30 additions and 23 deletions

View File

@@ -89,7 +89,7 @@ func (s *DocumentService) CreateDocument(ctx context.Context, req CreateDocument
// Automatically compute checksum for remote URLs if enabled
if url != "" && s.checksumConfig != nil {
checksumResult := s.computeChecksumForURL(url)
checksumResult := s.computeChecksumForURL(ctx, url)
if checksumResult != nil {
input.Checksum = checksumResult.ChecksumHex
input.ChecksumAlgorithm = checksumResult.Algorithm
@@ -313,7 +313,7 @@ func isBase64Like(s string) bool {
// computeChecksumForURL attempts to compute the checksum for a remote URL
// Returns nil if the checksum cannot be computed (error, too large, etc.)
func (s *DocumentService) computeChecksumForURL(url string) *checksum.Result {
func (s *DocumentService) computeChecksumForURL(ctx context.Context, url string) *checksum.Result {
if s.checksumConfig == nil {
return nil
}
@@ -327,7 +327,7 @@ func (s *DocumentService) computeChecksumForURL(url string) *checksum.Result {
InsecureSkipVerify: s.checksumConfig.InsecureSkipVerify,
}
result, err := checksum.ComputeRemoteChecksum(url, opts)
result, err := checksum.ComputeRemoteChecksum(ctx, url, opts)
if err != nil {
logger.Logger.Warn("Failed to compute checksum for URL",
"url", url,
@@ -426,7 +426,7 @@ func (s *DocumentService) FindOrCreateDocument(ctx context.Context, ref string)
// For URL references, compute checksum before creating
if refType == ReferenceTypeURL && s.checksumConfig != nil {
logger.Logger.Debug("Computing checksum for URL reference", "url", ref)
checksumResult := s.computeChecksumForURL(ref)
checksumResult := s.computeChecksumForURL(ctx, ref)
if checksumResult != nil {
logger.Logger.Info("Automatically computed checksum for URL reference",
"url", ref,

View File

@@ -111,7 +111,7 @@ func (s *SignatureService) CreateSignature(ctx context.Context, request *models.
// Continue without checksum - document metadata is optional
} else if doc != nil && doc.Checksum != "" {
// Verify document hasn't been modified before signing
if err := s.verifyDocumentIntegrity(doc); err != nil {
if err := s.verifyDocumentIntegrity(ctx, doc); err != nil {
logger.Logger.Warn("Document integrity check failed",
"doc_id", request.DocID,
"error", err.Error())
@@ -375,7 +375,7 @@ func (s *SignatureService) RebuildChain(ctx context.Context) error {
}
// verifyDocumentIntegrity checks if the document at the URL hasn't been modified since the checksum was stored
func (s *SignatureService) verifyDocumentIntegrity(doc *models.Document) error {
func (s *SignatureService) verifyDocumentIntegrity(ctx context.Context, doc *models.Document) error {
// Only verify if document has URL and checksum, and checksum config is available
if doc.URL == "" || doc.Checksum == "" || s.checksumConfig == nil {
logger.Logger.Debug("Skipping document integrity check",
@@ -406,7 +406,7 @@ func (s *SignatureService) verifyDocumentIntegrity(doc *models.Document) error {
}
// Compute current checksum
result, err := checksum.ComputeRemoteChecksum(doc.URL, opts)
result, err := checksum.ComputeRemoteChecksum(ctx, doc.URL, opts)
if err != nil {
logger.Logger.Error("Failed to compute checksum for integrity check",
"doc_id", doc.DocID,

View File

@@ -2,6 +2,7 @@
package checksum
import (
"context"
"crypto/sha256"
"crypto/tls"
"encoding/hex"
@@ -58,7 +59,12 @@ func DefaultOptions() ComputeOptions {
// ComputeRemoteChecksum downloads a remote binary file and computes its SHA-256 checksum
// Returns nil if the file cannot be processed (too large, wrong type, network error, SSRF blocked)
func ComputeRemoteChecksum(urlStr string, opts ComputeOptions) (*Result, error) {
// The context is used for request cancellation and timeout propagation.
func ComputeRemoteChecksum(ctx context.Context, urlStr string, opts ComputeOptions) (*Result, error) {
// Check if context is already cancelled
if err := ctx.Err(); err != nil {
return nil, fmt.Errorf("context cancelled before checksum computation: %w", err)
}
// Validate URL scheme (only HTTPS allowed)
if !isValidURL(urlStr) {
logger.Logger.Info("Checksum: URL rejected - not HTTPS", "url", urlStr)
@@ -101,7 +107,7 @@ func ComputeRemoteChecksum(urlStr string, opts ComputeOptions) (*Result, error)
}
// Step 1: HEAD request to check Content-Type and Content-Length
headReq, err := http.NewRequest("HEAD", urlStr, nil)
headReq, err := http.NewRequestWithContext(ctx, "HEAD", urlStr, nil)
if err != nil {
logger.Logger.Warn("Checksum: Failed to create HEAD request", "url", urlStr, "error", err.Error())
return nil, nil
@@ -112,7 +118,7 @@ func ComputeRemoteChecksum(urlStr string, opts ComputeOptions) (*Result, error)
if err != nil {
logger.Logger.Info("Checksum: HEAD request failed", "url", urlStr, "error", err.Error())
// Fallback: try GET with streaming if HEAD not supported
return computeWithStreamedGET(client, urlStr, opts)
return computeWithStreamedGET(ctx, client, urlStr, opts)
}
defer headResp.Body.Close()
@@ -133,11 +139,11 @@ func ComputeRemoteChecksum(urlStr string, opts ComputeOptions) (*Result, error)
// If Content-Length is unknown (0 or -1), fallback to streamed GET
if contentLength <= 0 {
logger.Logger.Debug("Checksum: Content-Length unknown, using streamed GET", "url", urlStr)
return computeWithStreamedGET(client, urlStr, opts)
return computeWithStreamedGET(ctx, client, urlStr, opts)
}
// Step 2: GET request to download and compute checksum
getReq, err := http.NewRequest("GET", urlStr, nil)
getReq, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
if err != nil {
logger.Logger.Warn("Checksum: Failed to create GET request", "url", urlStr, "error", err.Error())
return nil, nil
@@ -161,8 +167,8 @@ func ComputeRemoteChecksum(urlStr string, opts ComputeOptions) (*Result, error)
}
// computeWithStreamedGET performs a GET request and computes checksum with hard size limit
func computeWithStreamedGET(client *http.Client, urlStr string, opts ComputeOptions) (*Result, error) {
getReq, err := http.NewRequest("GET", urlStr, nil)
func computeWithStreamedGET(ctx context.Context, client *http.Client, urlStr string, opts ComputeOptions) (*Result, error) {
getReq, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
if err != nil {
logger.Logger.Warn("Checksum: Failed to create GET request (fallback)", "url", urlStr, "error", err.Error())
return nil, nil

View File

@@ -2,6 +2,7 @@
package checksum
import (
"context"
"fmt"
"net"
"net/http"
@@ -27,7 +28,7 @@ func TestComputeRemoteChecksum_Success(t *testing.T) {
opts := DefaultOptions()
opts.SkipSSRFCheck = true // For testing with httptest
opts.InsecureSkipVerify = true // Accept self-signed certs
result, err := ComputeRemoteChecksum(server.URL, opts)
result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
@@ -56,7 +57,7 @@ func TestComputeRemoteChecksum_TooLarge(t *testing.T) {
opts := DefaultOptions()
opts.SkipSSRFCheck = true
opts.InsecureSkipVerify = true
result, err := ComputeRemoteChecksum(server.URL, opts)
result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
@@ -80,7 +81,7 @@ func TestComputeRemoteChecksum_WrongContentType(t *testing.T) {
opts := DefaultOptions()
opts.SkipSSRFCheck = true
opts.InsecureSkipVerify = true
result, err := ComputeRemoteChecksum(server.URL, opts)
result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
@@ -94,7 +95,7 @@ func TestComputeRemoteChecksum_WrongContentType(t *testing.T) {
func TestComputeRemoteChecksum_HTTPNotHTTPS(t *testing.T) {
// Test HTTP (not HTTPS) - should be rejected
opts := DefaultOptions()
result, err := ComputeRemoteChecksum("http://example.com/file.pdf", opts)
result, err := ComputeRemoteChecksum(context.Background(), "http://example.com/file.pdf", opts)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
@@ -123,7 +124,7 @@ func TestComputeRemoteChecksum_StreamedGETFallback(t *testing.T) {
opts := DefaultOptions()
opts.SkipSSRFCheck = true
opts.InsecureSkipVerify = true
result, err := ComputeRemoteChecksum(server.URL, opts)
result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
@@ -155,7 +156,7 @@ func TestComputeRemoteChecksum_ExceedsSizeDuringStreaming(t *testing.T) {
opts := DefaultOptions()
opts.SkipSSRFCheck = true
opts.InsecureSkipVerify = true
result, err := ComputeRemoteChecksum(server.URL, opts)
result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
@@ -175,7 +176,7 @@ func TestComputeRemoteChecksum_HTTPError(t *testing.T) {
opts := DefaultOptions()
opts.SkipSSRFCheck = true
opts.InsecureSkipVerify = true
result, err := ComputeRemoteChecksum(server.URL, opts)
result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
@@ -197,7 +198,7 @@ func TestComputeRemoteChecksum_TooManyRedirects(t *testing.T) {
opts := DefaultOptions()
opts.SkipSSRFCheck = true
opts.InsecureSkipVerify = true
result, err := ComputeRemoteChecksum(server.URL, opts)
result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
@@ -298,7 +299,7 @@ func TestComputeRemoteChecksum_ImageContentType(t *testing.T) {
opts := DefaultOptions()
opts.SkipSSRFCheck = true
opts.InsecureSkipVerify = true
result, err := ComputeRemoteChecksum(server.URL, opts)
result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts)
if err != nil {
t.Fatalf("Expected no error, got %v", err)