From e2e734ec9a6541f8161315a646eb4103a3db1cef Mon Sep 17 00:00:00 2001 From: Abhishek Shroff Date: Sat, 16 Mar 2024 23:38:56 +0530 Subject: [PATCH] Include and tweak http.ServeContent to get rid of Seeker --- internal/handler_webdav/adapter.go | 2 +- internal/handler_webdav/resource_info.go | 4 +- internal/library/library.go | 25 +- internal/storage/local_storage.go | 15 +- internal/storage/storage.go | 5 +- internal/webdav/file.go | 4 +- internal/webdav/serve_resource.go | 398 +++++++++++++++++++++++ internal/webdav/webdav.go | 9 +- 8 files changed, 433 insertions(+), 29 deletions(-) create mode 100644 internal/webdav/serve_resource.go diff --git a/internal/handler_webdav/adapter.go b/internal/handler_webdav/adapter.go index c4f4845a..660745bd 100644 --- a/internal/handler_webdav/adapter.go +++ b/internal/handler_webdav/adapter.go @@ -59,7 +59,7 @@ func (a adapter) OpenWrite(ctx context.Context, name string) (io.WriteCloser, er } else if resource.Dir { return nil, errors.New("cannot open collection for write") } - return a.lib.Open(ctx, resourceId, true) + return a.lib.OpenWrite(ctx, resourceId) } func (a adapter) RemoveAll(ctx context.Context, name string) error { diff --git a/internal/handler_webdav/resource_info.go b/internal/handler_webdav/resource_info.go index 63d5a7ab..ad985994 100644 --- a/internal/handler_webdav/resource_info.go +++ b/internal/handler_webdav/resource_info.go @@ -43,8 +43,8 @@ func (ri resourceInfo) ContentType() string { } return "application/octet-stream" } -func (ri resourceInfo) OpenRead(ctx context.Context) (io.ReadSeekCloser, error) { - return ri.lib.Open(ctx, ri.resourceID, false) +func (ri resourceInfo) OpenRead(ctx context.Context, start, length int64) (io.ReadCloser, error) { + return ri.lib.OpenRead(ctx, ri.resourceID, start, length) } func (ri resourceInfo) Readdir(ctx context.Context) ([]webdav.ResourceInfo, error) { if !ri.collection { diff --git a/internal/library/library.go b/internal/library/library.go index 620f6ea9..9b8d3699 100644 --- a/internal/library/library.go +++ b/internal/library/library.go @@ -2,6 +2,7 @@ package library import ( "context" + "io" "io/fs" "strings" @@ -17,18 +18,18 @@ type Library struct { cs storage.Storage } -func (l Library) Open(ctx context.Context, id uuid.UUID, write bool) (storage.ReadWriteSeekCloser, error) { - var callback func(int, string) error - if write { - callback = func(len int, etag string) error { - return l.db.Queries().UpdateResourceContents(ctx, sql.UpdateResourceContentsParams{ - ID: id, - Size: pgtype.Int4{Int32: int32(len), Valid: true}, - Etag: pgtype.Text{String: etag, Valid: true}, - }) - } - } - return l.cs.Open(id, write, callback) +func (l Library) OpenRead(ctx context.Context, id uuid.UUID, start, length int64) (io.ReadCloser, error) { + return l.cs.OpenRead(id, start, length) +} + +func (l Library) OpenWrite(ctx context.Context, id uuid.UUID) (io.WriteCloser, error) { + return l.cs.OpenWrite(id, func(len int, etag string) error { + return l.db.Queries().UpdateResourceContents(ctx, sql.UpdateResourceContentsParams{ + ID: id, + Size: pgtype.Int4{Int32: int32(len), Valid: true}, + Etag: pgtype.Text{String: etag, Valid: true}, + }) + }) } func (l Library) ReadDir(ctx context.Context, id uuid.UUID, includeRoot bool, recursive bool) ([]sql.ReadDirRow, error) { diff --git a/internal/storage/local_storage.go b/internal/storage/local_storage.go index 1a827459..b72ba971 100644 --- a/internal/storage/local_storage.go +++ b/internal/storage/local_storage.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "errors" "hash" + "io" "os" "path/filepath" @@ -24,11 +25,19 @@ func newLocalStorage(root string) (Storage, error) { return localStorage(root), nil } -func (l localStorage) Open(id uuid.UUID, write bool, callback func(int, string) error) (ReadWriteSeekCloser, error) { - if !write { - return os.OpenFile(l.path(id), os.O_RDONLY, 0640) +func (l localStorage) OpenRead(id uuid.UUID, start, length int64) (io.ReadCloser, error) { + file, err := os.OpenFile(l.path(id), os.O_RDONLY, 0640) + if err != nil { + return nil, err } + _, err = file.Seek(start, io.SeekStart) + if err != nil { + return nil, err + } + return file, nil +} +func (l localStorage) OpenWrite(id uuid.UUID, callback func(int, string) error) (io.WriteCloser, error) { file, err := os.OpenFile(l.path(id), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0640) if err != nil || callback == nil { return file, err diff --git a/internal/storage/storage.go b/internal/storage/storage.go index 877e1eed..61b17d10 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -1,11 +1,14 @@ package storage import ( + "io" + "github.com/google/uuid" ) type Storage interface { - Open(id uuid.UUID, write bool, callback func(int, string) error) (ReadWriteSeekCloser, error) + OpenRead(id uuid.UUID, start, length int64) (io.ReadCloser, error) + OpenWrite(id uuid.UUID, callback func(int, string) error) (io.WriteCloser, error) Delete(id uuid.UUID) error String() string } diff --git a/internal/webdav/file.go b/internal/webdav/file.go index 8a232268..09c9441d 100644 --- a/internal/webdav/file.go +++ b/internal/webdav/file.go @@ -32,7 +32,7 @@ type ResourceInfo interface { IsDir() bool // abbreviation for Mode().IsDir() ETag() string // entity tag for efficient caching ContentType() string // content type - OpenRead(ctx context.Context) (io.ReadSeekCloser, error) + OpenRead(ctx context.Context, start, length int64) (io.ReadCloser, error) Readdir(ctx context.Context) ([]ResourceInfo, error) } @@ -124,7 +124,7 @@ func copyFiles(ctx context.Context, fs FileSystem, src, dst string, overwrite bo } } else { - srcFile, err := srcStat.OpenRead(ctx) + srcFile, err := srcStat.OpenRead(ctx, 0, -1) if err != nil { if os.IsNotExist(err) { return http.StatusNotFound, err diff --git a/internal/webdav/serve_resource.go b/internal/webdav/serve_resource.go new file mode 100644 index 00000000..387c8d2b --- /dev/null +++ b/internal/webdav/serve_resource.go @@ -0,0 +1,398 @@ +package webdav + +import ( + "errors" + "fmt" + "io" + "net/http" + "net/textproto" + "strconv" + "strings" + "time" +) + +func serveResource(w http.ResponseWriter, r *http.Request, ri ResourceInfo) { + w.Header().Set("Etag", ri.ETag()) + w.Header().Set("Last-Modified", ri.ModTime().Format(http.TimeFormat)) + w.Header().Set("Content-Type", ri.ContentType()) + + done, rangeReq := checkPreconditions(w, r, ri) + if done { + return + } + + code := http.StatusOK + sendSize := ri.Size() + ranges, err := parseRange(rangeReq, ri.Size()) + if err != nil { + w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", ri.Size())) + http.Error(w, err.Error(), http.StatusRequestedRangeNotSatisfiable) + return + } + + var reader io.ReadCloser + if len(ranges) == 1 { + // RFC 7233, Section 4.1: + // "If a single part is being transferred, the server + // generating the 206 response MUST generate a + // Content-Range header field, describing what range + // of the selected representation is enclosed, and a + // payload consisting of the range. + // ... + // A server MUST NOT generate a multipart response to + // a request for a single range, since a client that + // does not request multiple parts might not support + // multipart responses." + ra := ranges[0] + sendSize = ra.length + code = http.StatusPartialContent + w.Header().Set("Content-Range", ra.contentRange(ri.Size())) + reader, err = ri.OpenRead(r.Context(), ra.start, ra.length) + } else { + reader, err = ri.OpenRead(r.Context(), 0, -1) + } + + w.Header().Set("Accept-Ranges", "bytes") + w.Header().Set("Content-Length", strconv.FormatInt(sendSize, 10)) + w.WriteHeader(code) + + if r.Method != "HEAD" { + io.CopyN(w, reader, sendSize) + } +} + +// httpRange specifies the byte range to be sent to the client. +type httpRange struct { + start, length int64 +} + +func (r httpRange) contentRange(size int64) string { + return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size) +} + +func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader { + return textproto.MIMEHeader{ + "Content-Range": {r.contentRange(size)}, + "Content-Type": {contentType}, + } +} + +// parseRange parses a Range header string as per RFC 7233. +// errNoOverlap is returned if none of the ranges overlap. +func parseRange(s string, size int64) ([]httpRange, error) { + if s == "" { + return nil, nil // header not present + } + const b = "bytes=" + if !strings.HasPrefix(s, b) { + return nil, errors.New("invalid range") + } + var ranges []httpRange + noOverlap := false + for _, ra := range strings.Split(s[len(b):], ",") { + ra = textproto.TrimString(ra) + if ra == "" { + continue + } + start, end, ok := strings.Cut(ra, "-") + if !ok { + return nil, errors.New("invalid range") + } + start, end = textproto.TrimString(start), textproto.TrimString(end) + var r httpRange + if start == "" { + // If no start is specified, end specifies the + // range start relative to the end of the file, + // and we are dealing with + // which has to be a non-negative integer as per + // RFC 7233 Section 2.1 "Byte-Ranges". + if end == "" || end[0] == '-' { + return nil, errors.New("invalid range") + } + i, err := strconv.ParseInt(end, 10, 64) + if i < 0 || err != nil { + return nil, errors.New("invalid range") + } + if i > size { + i = size + } + r.start = size - i + r.length = size - r.start + } else { + i, err := strconv.ParseInt(start, 10, 64) + if err != nil || i < 0 { + return nil, errors.New("invalid range") + } + if i >= size { + // If the range begins after the size of the content, + // then it does not overlap. + noOverlap = true + continue + } + r.start = i + if end == "" { + // If no end is specified, range extends to end of the file. + r.length = size - r.start + } else { + i, err := strconv.ParseInt(end, 10, 64) + if err != nil || r.start > i { + return nil, errors.New("invalid range") + } + if i >= size { + i = size - 1 + } + r.length = i - r.start + 1 + } + } + ranges = append(ranges, r) + } + if noOverlap && len(ranges) == 0 { + if size == 0 { + // Some clients add a Range header to all requests to + // limit the size of the response. If the file is empty, + // ignore the range header + return nil, nil + } + return nil, errors.New("invalid range: failed to overlap") + } + return ranges, nil +} + +// checkPreconditions evaluates request preconditions and reports whether a precondition +// resulted in sending StatusNotModified or StatusPreconditionFailed. +func checkPreconditions(w http.ResponseWriter, r *http.Request, ri ResourceInfo) (done bool, rangeHeader string) { + // This function carefully follows RFC 7232 section 6. + ch := checkIfMatch(r, ri) + if ch == condNone { + ch = checkIfUnmodifiedSince(r, ri) + } + if ch == condFalse { + w.WriteHeader(http.StatusPreconditionFailed) + return true, "" + } + switch checkIfNoneMatch(r, ri) { + case condFalse: + if r.Method == "GET" || r.Method == "HEAD" { + writeNotModified(w) + return true, "" + } else { + w.WriteHeader(http.StatusPreconditionFailed) + return true, "" + } + case condNone: + if checkIfModifiedSince(r, ri) == condFalse { + writeNotModified(w) + return true, "" + } + } + + rangeHeader = r.Header.Get("Range") + if rangeHeader != "" && checkIfRange(r, ri) == condFalse { + rangeHeader = "" + } + return false, rangeHeader +} + +// scanETag determines if a syntactically valid ETag is present at s. If so, +// the ETag and remaining text after consuming ETag is returned. Otherwise, +// it returns "", "". +func scanETag(s string) (etag string, remain string) { + s = textproto.TrimString(s) + start := 0 + if strings.HasPrefix(s, "W/") { + start = 2 + } + if len(s[start:]) < 2 || s[start] != '"' { + return "", "" + } + // ETag is either W/"text" or "text". + // See RFC 7232 2.3. + for i := start + 1; i < len(s); i++ { + c := s[i] + switch { + // Character values allowed in ETags. + case c == 0x21 || c >= 0x23 && c <= 0x7E || c >= 0x80: + case c == '"': + return s[:i+1], s[i+1:] + default: + return "", "" + } + } + return "", "" +} + +// etagStrongMatch reports whether a and b match using strong ETag comparison. +// Assumes a and b are valid ETags. +func etagStrongMatch(a, b string) bool { + return a == b && a != "" && a[0] == '"' +} + +// etagWeakMatch reports whether a and b match using weak ETag comparison. +// Assumes a and b are valid ETags. +func etagWeakMatch(a, b string) bool { + return strings.TrimPrefix(a, "W/") == strings.TrimPrefix(b, "W/") +} + +// condResult is the result of an HTTP request precondition check. +// See https://tools.ietf.org/html/rfc7232 section 3. +type condResult int + +const ( + condNone condResult = iota + condTrue + condFalse +) + +func checkIfMatch(r *http.Request, ri ResourceInfo) condResult { + im := r.Header.Get("If-Match") + if im == "" { + return condNone + } + for { + im = textproto.TrimString(im) + if len(im) == 0 { + break + } + if im[0] == ',' { + im = im[1:] + continue + } + if im[0] == '*' { + return condTrue + } + etag, remain := scanETag(im) + if etag == "" { + break + } + if etagStrongMatch(etag, ri.ETag()) { + return condTrue + } + im = remain + } + + return condFalse +} + +func checkIfUnmodifiedSince(r *http.Request, ri ResourceInfo) condResult { + ius := r.Header.Get("If-Unmodified-Since") + if ius == "" || isZeroTime(ri.ModTime()) { + return condNone + } + t, err := http.ParseTime(ius) + if err != nil { + return condNone + } + + // The Last-Modified header truncates sub-second precision so + // the modtime needs to be truncated too. + modtime := ri.ModTime().Truncate(time.Second) + if ret := modtime.Compare(t); ret <= 0 { + return condTrue + } + return condFalse +} + +func checkIfNoneMatch(r *http.Request, ri ResourceInfo) condResult { + inm := r.Header.Get("If-None-Match") + if inm == "" { + return condNone + } + buf := inm + for { + buf = textproto.TrimString(buf) + if len(buf) == 0 { + break + } + if buf[0] == ',' { + buf = buf[1:] + continue + } + if buf[0] == '*' { + return condFalse + } + etag, remain := scanETag(buf) + if etag == "" { + break + } + if etagWeakMatch(etag, ri.ETag()) { + return condFalse + } + buf = remain + } + return condTrue +} + +func checkIfModifiedSince(r *http.Request, ri ResourceInfo) condResult { + if r.Method != "GET" && r.Method != "HEAD" { + return condNone + } + ims := r.Header.Get("If-Modified-Since") + if ims == "" || isZeroTime(ri.ModTime()) { + return condNone + } + t, err := http.ParseTime(ims) + if err != nil { + return condNone + } + // The Last-Modified header truncates sub-second precision so + // the modtime needs to be truncated too. + modtime := ri.ModTime().Truncate(time.Second) + if ret := modtime.Compare(t); ret <= 0 { + return condFalse + } + return condTrue +} + +func checkIfRange(r *http.Request, ri ResourceInfo) condResult { + if r.Method != "GET" && r.Method != "HEAD" { + return condNone + } + ir := r.Header.Get("If-Range") + if ir == "" { + return condNone + } + etag, _ := scanETag(ir) + if etag != "" { + if etagStrongMatch(etag, ri.ETag()) { + return condTrue + } else { + return condFalse + } + } + // The If-Range value is typically the ETag value, but it may also be + // the modtime date. See golang.org/issue/8367. + if ri.ModTime().IsZero() { + return condFalse + } + t, err := http.ParseTime(ir) + if err != nil { + return condFalse + } + if t.Unix() == ri.ModTime().Unix() { + return condTrue + } + return condFalse +} + +var unixEpochTime = time.Unix(0, 0) + +// isZeroTime reports whether t is obviously unspecified (either zero or Unix()=0). +func isZeroTime(t time.Time) bool { + return t.IsZero() || t.Equal(unixEpochTime) +} + +func writeNotModified(w http.ResponseWriter) { + // RFC 7232 section 4.1: + // a sender SHOULD NOT generate representation metadata other than the + // above listed fields unless said metadata exists for the purpose of + // guiding cache updates (e.g., Last-Modified might be useful if the + // response does not have an ETag field). + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") + delete(h, "Content-Encoding") + if h.Get("Etag") != "" { + delete(h, "Last-Modified") + } + w.WriteHeader(http.StatusNotModified) +} diff --git a/internal/webdav/webdav.go b/internal/webdav/webdav.go index 368d85ff..2587706a 100644 --- a/internal/webdav/webdav.go +++ b/internal/webdav/webdav.go @@ -207,14 +207,7 @@ func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request) (sta return http.StatusMethodNotAllowed, nil } - f, err := fi.OpenRead(ctx) - if err != nil { - return http.StatusNotFound, err - } - defer f.Close() - w.Header().Set("ETag", fi.ETag()) - w.Header().Set("Content-Type", fi.ContentType()) - http.ServeContent(w, r, reqPath, fi.ModTime(), f) + serveResource(w, r, fi) return 0, nil }