mirror of
https://github.com/btouchard/ackify.git
synced 2026-02-09 15:28:28 -06:00
refactor(checksum): propagate context for HTTP request cancellation
Add context.Context parameter to checksum computation functions to enable request cancellation, timeout propagation, and better observability for remote document downloads.
This commit is contained in:
@@ -89,7 +89,7 @@ func (s *DocumentService) CreateDocument(ctx context.Context, req CreateDocument
|
||||
|
||||
// Automatically compute checksum for remote URLs if enabled
|
||||
if url != "" && s.checksumConfig != nil {
|
||||
checksumResult := s.computeChecksumForURL(url)
|
||||
checksumResult := s.computeChecksumForURL(ctx, url)
|
||||
if checksumResult != nil {
|
||||
input.Checksum = checksumResult.ChecksumHex
|
||||
input.ChecksumAlgorithm = checksumResult.Algorithm
|
||||
@@ -313,7 +313,7 @@ func isBase64Like(s string) bool {
|
||||
|
||||
// computeChecksumForURL attempts to compute the checksum for a remote URL
|
||||
// Returns nil if the checksum cannot be computed (error, too large, etc.)
|
||||
func (s *DocumentService) computeChecksumForURL(url string) *checksum.Result {
|
||||
func (s *DocumentService) computeChecksumForURL(ctx context.Context, url string) *checksum.Result {
|
||||
if s.checksumConfig == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -327,7 +327,7 @@ func (s *DocumentService) computeChecksumForURL(url string) *checksum.Result {
|
||||
InsecureSkipVerify: s.checksumConfig.InsecureSkipVerify,
|
||||
}
|
||||
|
||||
result, err := checksum.ComputeRemoteChecksum(url, opts)
|
||||
result, err := checksum.ComputeRemoteChecksum(ctx, url, opts)
|
||||
if err != nil {
|
||||
logger.Logger.Warn("Failed to compute checksum for URL",
|
||||
"url", url,
|
||||
@@ -426,7 +426,7 @@ func (s *DocumentService) FindOrCreateDocument(ctx context.Context, ref string)
|
||||
// For URL references, compute checksum before creating
|
||||
if refType == ReferenceTypeURL && s.checksumConfig != nil {
|
||||
logger.Logger.Debug("Computing checksum for URL reference", "url", ref)
|
||||
checksumResult := s.computeChecksumForURL(ref)
|
||||
checksumResult := s.computeChecksumForURL(ctx, ref)
|
||||
if checksumResult != nil {
|
||||
logger.Logger.Info("Automatically computed checksum for URL reference",
|
||||
"url", ref,
|
||||
|
||||
@@ -111,7 +111,7 @@ func (s *SignatureService) CreateSignature(ctx context.Context, request *models.
|
||||
// Continue without checksum - document metadata is optional
|
||||
} else if doc != nil && doc.Checksum != "" {
|
||||
// Verify document hasn't been modified before signing
|
||||
if err := s.verifyDocumentIntegrity(doc); err != nil {
|
||||
if err := s.verifyDocumentIntegrity(ctx, doc); err != nil {
|
||||
logger.Logger.Warn("Document integrity check failed",
|
||||
"doc_id", request.DocID,
|
||||
"error", err.Error())
|
||||
@@ -375,7 +375,7 @@ func (s *SignatureService) RebuildChain(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// verifyDocumentIntegrity checks if the document at the URL hasn't been modified since the checksum was stored
|
||||
func (s *SignatureService) verifyDocumentIntegrity(doc *models.Document) error {
|
||||
func (s *SignatureService) verifyDocumentIntegrity(ctx context.Context, doc *models.Document) error {
|
||||
// Only verify if document has URL and checksum, and checksum config is available
|
||||
if doc.URL == "" || doc.Checksum == "" || s.checksumConfig == nil {
|
||||
logger.Logger.Debug("Skipping document integrity check",
|
||||
@@ -406,7 +406,7 @@ func (s *SignatureService) verifyDocumentIntegrity(doc *models.Document) error {
|
||||
}
|
||||
|
||||
// Compute current checksum
|
||||
result, err := checksum.ComputeRemoteChecksum(doc.URL, opts)
|
||||
result, err := checksum.ComputeRemoteChecksum(ctx, doc.URL, opts)
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to compute checksum for integrity check",
|
||||
"doc_id", doc.DocID,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package checksum
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"encoding/hex"
|
||||
@@ -58,7 +59,12 @@ func DefaultOptions() ComputeOptions {
|
||||
|
||||
// ComputeRemoteChecksum downloads a remote binary file and computes its SHA-256 checksum
|
||||
// Returns nil if the file cannot be processed (too large, wrong type, network error, SSRF blocked)
|
||||
func ComputeRemoteChecksum(urlStr string, opts ComputeOptions) (*Result, error) {
|
||||
// The context is used for request cancellation and timeout propagation.
|
||||
func ComputeRemoteChecksum(ctx context.Context, urlStr string, opts ComputeOptions) (*Result, error) {
|
||||
// Check if context is already cancelled
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, fmt.Errorf("context cancelled before checksum computation: %w", err)
|
||||
}
|
||||
// Validate URL scheme (only HTTPS allowed)
|
||||
if !isValidURL(urlStr) {
|
||||
logger.Logger.Info("Checksum: URL rejected - not HTTPS", "url", urlStr)
|
||||
@@ -101,7 +107,7 @@ func ComputeRemoteChecksum(urlStr string, opts ComputeOptions) (*Result, error)
|
||||
}
|
||||
|
||||
// Step 1: HEAD request to check Content-Type and Content-Length
|
||||
headReq, err := http.NewRequest("HEAD", urlStr, nil)
|
||||
headReq, err := http.NewRequestWithContext(ctx, "HEAD", urlStr, nil)
|
||||
if err != nil {
|
||||
logger.Logger.Warn("Checksum: Failed to create HEAD request", "url", urlStr, "error", err.Error())
|
||||
return nil, nil
|
||||
@@ -112,7 +118,7 @@ func ComputeRemoteChecksum(urlStr string, opts ComputeOptions) (*Result, error)
|
||||
if err != nil {
|
||||
logger.Logger.Info("Checksum: HEAD request failed", "url", urlStr, "error", err.Error())
|
||||
// Fallback: try GET with streaming if HEAD not supported
|
||||
return computeWithStreamedGET(client, urlStr, opts)
|
||||
return computeWithStreamedGET(ctx, client, urlStr, opts)
|
||||
}
|
||||
defer headResp.Body.Close()
|
||||
|
||||
@@ -133,11 +139,11 @@ func ComputeRemoteChecksum(urlStr string, opts ComputeOptions) (*Result, error)
|
||||
// If Content-Length is unknown (0 or -1), fallback to streamed GET
|
||||
if contentLength <= 0 {
|
||||
logger.Logger.Debug("Checksum: Content-Length unknown, using streamed GET", "url", urlStr)
|
||||
return computeWithStreamedGET(client, urlStr, opts)
|
||||
return computeWithStreamedGET(ctx, client, urlStr, opts)
|
||||
}
|
||||
|
||||
// Step 2: GET request to download and compute checksum
|
||||
getReq, err := http.NewRequest("GET", urlStr, nil)
|
||||
getReq, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
|
||||
if err != nil {
|
||||
logger.Logger.Warn("Checksum: Failed to create GET request", "url", urlStr, "error", err.Error())
|
||||
return nil, nil
|
||||
@@ -161,8 +167,8 @@ func ComputeRemoteChecksum(urlStr string, opts ComputeOptions) (*Result, error)
|
||||
}
|
||||
|
||||
// computeWithStreamedGET performs a GET request and computes checksum with hard size limit
|
||||
func computeWithStreamedGET(client *http.Client, urlStr string, opts ComputeOptions) (*Result, error) {
|
||||
getReq, err := http.NewRequest("GET", urlStr, nil)
|
||||
func computeWithStreamedGET(ctx context.Context, client *http.Client, urlStr string, opts ComputeOptions) (*Result, error) {
|
||||
getReq, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
|
||||
if err != nil {
|
||||
logger.Logger.Warn("Checksum: Failed to create GET request (fallback)", "url", urlStr, "error", err.Error())
|
||||
return nil, nil
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package checksum
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -27,7 +28,7 @@ func TestComputeRemoteChecksum_Success(t *testing.T) {
|
||||
opts := DefaultOptions()
|
||||
opts.SkipSSRFCheck = true // For testing with httptest
|
||||
opts.InsecureSkipVerify = true // Accept self-signed certs
|
||||
result, err := ComputeRemoteChecksum(server.URL, opts)
|
||||
result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
@@ -56,7 +57,7 @@ func TestComputeRemoteChecksum_TooLarge(t *testing.T) {
|
||||
opts := DefaultOptions()
|
||||
opts.SkipSSRFCheck = true
|
||||
opts.InsecureSkipVerify = true
|
||||
result, err := ComputeRemoteChecksum(server.URL, opts)
|
||||
result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
@@ -80,7 +81,7 @@ func TestComputeRemoteChecksum_WrongContentType(t *testing.T) {
|
||||
opts := DefaultOptions()
|
||||
opts.SkipSSRFCheck = true
|
||||
opts.InsecureSkipVerify = true
|
||||
result, err := ComputeRemoteChecksum(server.URL, opts)
|
||||
result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
@@ -94,7 +95,7 @@ func TestComputeRemoteChecksum_WrongContentType(t *testing.T) {
|
||||
func TestComputeRemoteChecksum_HTTPNotHTTPS(t *testing.T) {
|
||||
// Test HTTP (not HTTPS) - should be rejected
|
||||
opts := DefaultOptions()
|
||||
result, err := ComputeRemoteChecksum("http://example.com/file.pdf", opts)
|
||||
result, err := ComputeRemoteChecksum(context.Background(), "http://example.com/file.pdf", opts)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
@@ -123,7 +124,7 @@ func TestComputeRemoteChecksum_StreamedGETFallback(t *testing.T) {
|
||||
opts := DefaultOptions()
|
||||
opts.SkipSSRFCheck = true
|
||||
opts.InsecureSkipVerify = true
|
||||
result, err := ComputeRemoteChecksum(server.URL, opts)
|
||||
result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
@@ -155,7 +156,7 @@ func TestComputeRemoteChecksum_ExceedsSizeDuringStreaming(t *testing.T) {
|
||||
opts := DefaultOptions()
|
||||
opts.SkipSSRFCheck = true
|
||||
opts.InsecureSkipVerify = true
|
||||
result, err := ComputeRemoteChecksum(server.URL, opts)
|
||||
result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
@@ -175,7 +176,7 @@ func TestComputeRemoteChecksum_HTTPError(t *testing.T) {
|
||||
opts := DefaultOptions()
|
||||
opts.SkipSSRFCheck = true
|
||||
opts.InsecureSkipVerify = true
|
||||
result, err := ComputeRemoteChecksum(server.URL, opts)
|
||||
result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
@@ -197,7 +198,7 @@ func TestComputeRemoteChecksum_TooManyRedirects(t *testing.T) {
|
||||
opts := DefaultOptions()
|
||||
opts.SkipSSRFCheck = true
|
||||
opts.InsecureSkipVerify = true
|
||||
result, err := ComputeRemoteChecksum(server.URL, opts)
|
||||
result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
@@ -298,7 +299,7 @@ func TestComputeRemoteChecksum_ImageContentType(t *testing.T) {
|
||||
opts := DefaultOptions()
|
||||
opts.SkipSSRFCheck = true
|
||||
opts.InsecureSkipVerify = true
|
||||
result, err := ComputeRemoteChecksum(server.URL, opts)
|
||||
result, err := ComputeRemoteChecksum(context.Background(), server.URL, opts)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
|
||||
Reference in New Issue
Block a user