diff --git a/internal/library/library.go b/internal/library/library.go index 14ecd760..620f6ea9 100644 --- a/internal/library/library.go +++ b/internal/library/library.go @@ -3,7 +3,6 @@ package library import ( "context" "io/fs" - "os" "strings" "github.com/google/uuid" @@ -18,10 +17,9 @@ type Library struct { cs storage.Storage } -func (l Library) Open(ctx context.Context, id uuid.UUID, flag int) (storage.ReadWriteSeekCloser, error) { - writeMode := flag&(os.O_RDWR|os.O_WRONLY) != 0 +func (l Library) Open(ctx context.Context, id uuid.UUID, write bool) (storage.ReadWriteSeekCloser, error) { var callback func(int, string) error - if writeMode { + if write { callback = func(len int, etag string) error { return l.db.Queries().UpdateResourceContents(ctx, sql.UpdateResourceContentsParams{ ID: id, @@ -30,7 +28,7 @@ func (l Library) Open(ctx context.Context, id uuid.UUID, flag int) (storage.Read }) } } - return l.cs.Open(id, flag, callback) + return l.cs.Open(id, write, callback) } func (l Library) ReadDir(ctx context.Context, id uuid.UUID, includeRoot bool, recursive bool) ([]sql.ReadDirRow, error) { diff --git a/internal/storage/hasher.go b/internal/storage/hasher.go index 84b3c622..34c21bf7 100644 --- a/internal/storage/hasher.go +++ b/internal/storage/hasher.go @@ -1,31 +1,33 @@ package storage import ( + "errors" "hash" + "io" ) type hasher struct { - rwc ReadWriteSeekCloser + dest io.WriteCloser len int sum hash.Hash closeCallback func(int, hash.Hash, error) error } func (c *hasher) Read(p []byte) (n int, err error) { - return c.rwc.Read(p) -} - -func (c *hasher) Seek(offset int64, whence int) (int64, error) { - return c.rwc.Seek(offset, whence) + return 0, errors.New("read not supported") } func (c *hasher) Write(p []byte) (n int, err error) { - n, err = c.rwc.Write(p) + n, err = c.dest.Write(p) c.sum.Write(p) c.len += n return } -func (c *hasher) Close() error { - return c.closeCallback(c.len, c.sum, c.rwc.Close()) +func (c *hasher) Seek(offset int64, whence int) (int64, error) { + return 0, errors.New("seek not supported") +} + +func (c *hasher) Close() error { + return c.closeCallback(c.len, c.sum, c.dest.Close()) } diff --git a/internal/storage/local_storage.go b/internal/storage/local_storage.go index f69b9c0c..1a827459 100644 --- a/internal/storage/local_storage.go +++ b/internal/storage/local_storage.go @@ -24,17 +24,21 @@ func newLocalStorage(root string) (Storage, error) { return localStorage(root), nil } -func (l localStorage) Open(id uuid.UUID, flag int, writeCallback func(int, string) error) (file ReadWriteSeekCloser, err error) { - file, err = os.OpenFile(l.path(id), flag, 0640) - if err != nil || writeCallback == nil { - return +func (l localStorage) Open(id uuid.UUID, write bool, callback func(int, string) error) (ReadWriteSeekCloser, error) { + if !write { + return os.OpenFile(l.path(id), os.O_RDONLY, 0640) } - return &hasher{rwc: file, sum: md5.New(), closeCallback: func(len int, sum hash.Hash, err error) error { + + file, err := os.OpenFile(l.path(id), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0640) + if err != nil || callback == nil { + return file, err + } + return &hasher{dest: file, sum: md5.New(), closeCallback: func(len int, sum hash.Hash, err error) error { if err != nil { return err } etag := hex.EncodeToString(sum.Sum(nil)) - return writeCallback(len, etag) + return callback(len, etag) }}, nil } diff --git a/internal/storage/storage.go b/internal/storage/storage.go index d9907b51..877e1eed 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -5,7 +5,7 @@ import ( ) type Storage interface { - Open(id uuid.UUID, flag int, writeCallback func(int, string) error) (ReadWriteSeekCloser, error) + Open(id uuid.UUID, write bool, callback func(int, string) error) (ReadWriteSeekCloser, error) Delete(id uuid.UUID) error String() string } diff --git a/internal/webdav/adapter.go b/internal/webdav/adapter.go index 5211d217..bd57f0bd 100644 --- a/internal/webdav/adapter.go +++ b/internal/webdav/adapter.go @@ -6,7 +6,6 @@ import ( "io" "io/fs" "net/http" - "os" "strings" webdav "github.com/emersion/go-webdav" @@ -25,7 +24,7 @@ func (a adapter) Open(ctx context.Context, name string) (io.ReadCloser, error) { if err != nil { return nil, err } - return a.lib.Open(ctx, resource.ID, os.O_RDONLY) + return a.lib.Open(ctx, resource.ID, false) } func (a adapter) Stat(ctx context.Context, name string) (*webdav.FileInfo, error) { @@ -89,7 +88,7 @@ func (a adapter) Create(ctx context.Context, name string) (io.WriteCloser, error } } - return a.lib.Open(ctx, id, os.O_CREATE|os.O_RDWR|os.O_TRUNC) + return a.lib.Open(ctx, id, true) } func (a adapter) RemoveAll(ctx context.Context, name string) error { diff --git a/internal/webdav_xnet/adapter.go b/internal/webdav_xnet/adapter.go index b7280ca2..79900d84 100644 --- a/internal/webdav_xnet/adapter.go +++ b/internal/webdav_xnet/adapter.go @@ -61,7 +61,7 @@ func (a adapter) OpenFile(ctx context.Context, name string, flag int, perm os.Fi } var src storage.ReadWriteSeekCloser if !dir { - src, err = a.lib.Open(ctx, resourceId, flag) + src, err = a.lib.Open(ctx, resourceId, flag&(os.O_RDWR|os.O_WRONLY) != 0) if err != nil { return nil, err } diff --git a/internal/webdav_xnet/file.go b/internal/webdav_xnet/file.go index 265627fa..c68491e7 100644 --- a/internal/webdav_xnet/file.go +++ b/internal/webdav_xnet/file.go @@ -18,18 +18,21 @@ type file struct { func (f file) Seek(offset int64, whence int) (int64, error) { return f.src.Seek(offset, whence) } + func (f file) Read(p []byte) (n int, err error) { return f.src.Read(p) } + +func (f file) Write(p []byte) (n int, err error) { + return f.src.Write(p) +} + func (f file) Close() error { - if f.src == nil { + if f.dir { return nil } return f.src.Close() } -func (f file) Write(p []byte) (n int, err error) { - return f.src.Write(p) -} func (f file) Readdir(count int) ([]fs.FileInfo, error) { if !f.dir {