diff --git a/server/internal/api/authenticator/authenticator.go b/server/internal/api/authenticator/authenticator.go index cbdb43e4..65e62518 100644 --- a/server/internal/api/authenticator/authenticator.go +++ b/server/internal/api/authenticator/authenticator.go @@ -7,13 +7,13 @@ import ( "codeberg.org/shroff/phylum/server/internal/auth" "codeberg.org/shroff/phylum/server/internal/core" + "codeberg.org/shroff/phylum/server/internal/db" "github.com/gin-gonic/gin" ) var errAuthRequired = core.NewError(http.StatusUnauthorized, "auth_required", "authorization required") const keyUser = "user" -const keyFileSystem = "filesystem" func GetUser(c *gin.Context) core.User { val, ok := c.Get(keyUser) @@ -24,28 +24,24 @@ func GetUser(c *gin.Context) core.User { } func GetFileSystem(c *gin.Context) core.FileSystem { - val, ok := c.Get(keyFileSystem) - if !ok { - return nil - } - return val.(core.FileSystem) + user := GetUser(c) + return user.OpenFileSystem(db.Get(c.Request.Context())) } func Require(c *gin.Context) { - ctx := c.Request.Context() if u, err := extractUserDetails(c); err != nil { panic(err) } else { c.Set(keyUser, u) - c.Set(keyFileSystem, u.OpenFileSystem(ctx)) } } func extractUserDetails(c *gin.Context) (core.User, error) { + db := db.Get(c.Request.Context()) if header := c.Request.Header.Get("Authorization"); header == "" { if cookie, err := c.Request.Cookie("auth_token"); err == nil { token := cookie.Value - if u, err := auth.ReadAccessToken(c.Request.Context(), token); err == nil { + if u, err := auth.ReadAccessToken(db, token); err == nil { return u, nil } else { return core.User{}, err @@ -56,14 +52,14 @@ func extractUserDetails(c *gin.Context) (core.User, error) { return core.User{}, errAuthRequired } else if authHeader, ok := checkAuthHeader(header, "basic"); ok { if email, password, ok := decodeBasicAuth(authHeader); ok { - if u, err := auth.VerifyUserPassword(c.Request.Context(), email, password); err == nil { + if u, err := auth.VerifyUserPassword(db, email, password); err == nil { return u, nil } else { return core.User{}, err } } } else if token, ok := checkAuthHeader(header, "bearer"); ok { - if u, err := auth.ReadAccessToken(c.Request.Context(), token); err == nil { + if u, err := auth.ReadAccessToken(db, token); err == nil { return u, nil } else { return core.User{}, err diff --git a/server/internal/api/v1/my/bookmarks.go b/server/internal/api/v1/my/bookmarks.go index 13eb5bc3..85d6fff9 100644 --- a/server/internal/api/v1/my/bookmarks.go +++ b/server/internal/api/v1/my/bookmarks.go @@ -65,7 +65,11 @@ func handleBookmarksAddRoute(c *gin.Context) { if err != nil { panic(err) } - b, err := core.AddBookmark(db.Get(c.Request.Context()), u, r, params.Name) + var b core.Bookmark + err = db.Get(c.Request.Context()).RunInTx(func(db db.TxHandler) error { + b, err = core.AddBookmark(db, u, r, params.Name) + return err + }) if err != nil { panic(err) } @@ -84,7 +88,9 @@ func handleBookmarksRemoveRoute(c *gin.Context) { c.ShouldBind(¶ms) u := authenticator.GetUser(c) - if err := core.RemoveBookmark(db.Get(c.Request.Context()), u, params.ID); err != nil { + if err := db.Get(c.Request.Context()).RunInTx(func(db db.TxHandler) error { + return core.RemoveBookmark(db, u, params.ID) + }); err != nil { panic(err) } c.JSON(200, gin.H{}) diff --git a/server/internal/api/v1/my/details.go b/server/internal/api/v1/my/details.go index 055c4339..e2d07fb1 100644 --- a/server/internal/api/v1/my/details.go +++ b/server/internal/api/v1/my/details.go @@ -21,7 +21,7 @@ func handleDetailsUpdateRoute(c *gin.Context) { u := authenticator.GetUser(c) - err = db.Get(c.Request.Context()).RunInTx(func(db db.Handler) error { + err = db.Get(c.Request.Context()).RunInTx(func(db db.TxHandler) error { if params.Name != "" { if err := core.UpdateUserName(db, u, params.Name); err != nil { return err diff --git a/server/internal/api/webdav/handler.go b/server/internal/api/webdav/handler.go index dc8b5bb6..1b663e97 100644 --- a/server/internal/api/webdav/handler.go +++ b/server/internal/api/webdav/handler.go @@ -45,11 +45,12 @@ func (h *handler) HandleRequest(c *gin.Context) { var f core.FileSystem if email, pass, ok := c.Request.BasicAuth(); ok { ctx := c.Request.Context() - if u, err := auth.VerifyUserPassword(ctx, email, pass); err == nil { + db := db.Get(ctx) + if u, err := auth.VerifyUserPassword(db, email, pass); err == nil { authSuccess = true root := c.Param("root") if root[0] == '~' { - id, err := core.UserHome(db.Get(c.Request.Context()), root[1:]) + id, err := core.UserHome(db, root[1:]) if err != nil { if errors.Is(err, core.ErrUserNotFound) { c.AbortWithStatus(http.StatusNotFound) @@ -58,9 +59,9 @@ func (h *handler) HandleRequest(c *gin.Context) { panic(err) } } - f = core.OpenFileSystem(ctx, u, id) + f = core.OpenFileSystem(db, u, id) } else if id, err := uuid.Parse(root); err != nil { - f = core.OpenFileSystem(ctx, u, pgtype.UUID{Bytes: id, Valid: true}) + f = core.OpenFileSystem(db, u, pgtype.UUID{Bytes: id, Valid: true}) } else { c.AbortWithStatus(http.StatusNotFound) return diff --git a/server/internal/auth/auth.go b/server/internal/auth/auth.go index 86f829bc..13bd1d5f 100644 --- a/server/internal/auth/auth.go +++ b/server/internal/auth/auth.go @@ -1,7 +1,6 @@ package auth import ( - "context" "crypto/rand" "encoding/base64" "errors" @@ -28,12 +27,11 @@ var accessTokenValidity = pgtype.Interval{ var ErrCredentialsInvalid = errors.New("invalid credentials") -func VerifyUserPassword(ctx context.Context, email, password string) (core.User, error) { - return verifyUserPassword(db.Get(ctx), email, password) +func VerifyUserPassword(db db.Handler, email, password string) (core.User, error) { + return verifyUserPassword(db, email, password) } -func CreateAccessToken(ctx context.Context, email, password string) (core.User, string, error) { - db := db.Get(ctx) +func CreateAccessToken(db db.TxHandler, email, password string) (core.User, string, error) { if user, err := verifyUserPassword(db, email, password); err != nil { return core.User{}, "", err } else if token, err := insertAccessToken(db, user.ID); err != nil { @@ -43,9 +41,9 @@ func CreateAccessToken(ctx context.Context, email, password string) (core.User, } } -func ReadAccessToken(ctx context.Context, accessToken string) (user core.User, err error) { +func ReadAccessToken(db db.Handler, accessToken string) (user core.User, err error) { const q = `SELECT t.expires, u.id, u.email, u.name, u.permissions, u.home FROM access_tokens t JOIN users u ON t.user_id = u.id WHERE t.id = $1; ` - row := db.Get(ctx).QueryRow(q, accessToken) + row := db.QueryRow(q, accessToken) var expires pgtype.Timestamp err = row.Scan(&expires, &user.ID, &user.Email, &user.Name, &user.Permissions, &user.Home) @@ -59,8 +57,7 @@ func ReadAccessToken(ctx context.Context, accessToken string) (user core.User, e return } -func CreateResetToken(ctx context.Context, email string) (core.User, string, error) { - db := db.Get(ctx) +func CreateResetToken(db db.TxHandler, email string) (core.User, string, error) { user, err := core.UserByEmail(db, email) if err != nil { return core.User{}, "", err @@ -74,54 +71,46 @@ func CreateResetToken(ctx context.Context, email string) (core.User, string, err } -func ResetUserPassword(ctx context.Context, email, resetToken, password string) (core.User, string, error) { - var user core.User - var apiToken string - err := db.Get(ctx).RunInTx(func(db db.Handler) error { - var err error - user, err = core.UserByEmail(db, email) - if err != nil { - return err - } +func ResetUserPassword(db db.TxHandler, email, resetToken, password string) (core.User, string, error) { + user, err := core.UserByEmail(db, email) + if err != nil { + return user, "", err + } - // UpdateUserPassword will ensure the password strength - // Not incorrect to do this before token verification because we are in a transaction. - // TODO: Are there perf implications for this in case of malicious actors? - err = updateUserPassword(db, user.ID, password) - if err != nil { - return err - } + // UpdateUserPassword will ensure the password strength + // Not incorrect to do this before token verification because we are in a transaction. + // TODO: Are there perf implications for this in case of malicious actors? + err = updateUserPassword(db, user.ID, password) + if err != nil { + return core.User{}, "", err + } - const q = `DELETE FROM reset_tokens WHERE user_id = @user_id::INT AND token = @token::TEXT RETURNING expires` - args := pgx.NamedArgs{ - "user_id": user.ID, - "token": resetToken, - "expires": time.Now().Add(resetTokenDuration), - } - row := db.QueryRow(q, args) - var expires pgtype.Timestamp - if err := row.Scan(&expires); err != nil { - if errors.Is(err, pgx.ErrNoRows) { - return ErrCredentialsInvalid - } - return err - } - if time.Now().After(expires.Time) { - return ErrCredentialsInvalid + const q = `DELETE FROM reset_tokens WHERE user_id = @user_id::INT AND token = @token::TEXT RETURNING expires` + args := pgx.NamedArgs{ + "user_id": user.ID, + "token": resetToken, + "expires": time.Now().Add(resetTokenDuration), + } + row := db.QueryRow(q, args) + var expires pgtype.Timestamp + if err := row.Scan(&expires); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + err = ErrCredentialsInvalid } + return core.User{}, "", err + } + if time.Now().After(expires.Time) { + return core.User{}, "", ErrCredentialsInvalid + } - apiToken, err = insertAccessToken(db, user.ID) - return err - }) - + apiToken, err := insertAccessToken(db, user.ID) if err != nil { return core.User{}, "", err } return user, apiToken, nil } -func UpdateUserPassword(ctx context.Context, email, password string) error { - db := db.Get(ctx) +func UpdateUserPassword(db db.TxHandler, email, password string) error { user, err := core.UserByEmail(db, email) if err != nil { return err @@ -129,7 +118,7 @@ func UpdateUserPassword(ctx context.Context, email, password string) error { return updateUserPassword(db, user.ID, password) } -func updateUserPassword(db db.Handler, userID int32, password string) error { +func updateUserPassword(db db.TxHandler, userID int32, password string) error { if err := checkPasswordStrength(password); err != nil { return err } @@ -144,7 +133,7 @@ func updateUserPassword(db db.Handler, userID int32, password string) error { return nil } -func insertAccessToken(db db.Handler, userID int32) (string, error) { +func insertAccessToken(db db.TxHandler, userID int32) (string, error) { const q = `INSERT INTO access_tokens(id, expires, user_id) VALUES ($1::TEXT, NOW() + $2::INTERVAL, $3::INT)` token := generateRandomString(apiTokenLength) @@ -155,7 +144,7 @@ func insertAccessToken(db db.Handler, userID int32) (string, error) { } } -func insertResetToken(db db.Handler, userID int32) (string, error) { +func insertResetToken(db db.TxHandler, userID int32) (string, error) { const q = `INSERT INTO reset_tokens(user_id, token, expires) VALUES (@user_id::INT, @token::TEXT, @expires::TIMESTAMP) ON CONFLICT(user_id) DO UPDATE SET token = @token::TEXT, expires = @expires::TIMESTAMP` diff --git a/server/internal/command/admin/user/invite.go b/server/internal/command/admin/user/invite.go index 14bcfec4..a44566ab 100644 --- a/server/internal/command/admin/user/invite.go +++ b/server/internal/command/admin/user/invite.go @@ -23,16 +23,16 @@ func setupInviteCommand() *cobra.Command { name, _ := cmd.Flags().GetString("name") noCreateHome, _ := cmd.Flags().GetBool("no-create-home") - var u core.User - err := db.Get(context.Background()).RunInTx(func(db db.Handler) error { - if user, err := core.CreateUser(db, email, name, noCreateHome); err != nil { + err := db.Get(context.Background()).RunInTx(func(db db.TxHandler) error { + var user core.User + if u, err := core.CreateUser(db, email, name, noCreateHome); err != nil { return err } else { - u = user + user = u } if b, _ := cmd.Flags().GetBool("no-email"); !b { - if err := mail.SendWelcomeEmail(u); err != nil { + if err := mail.SendWelcomeEmail(user); err != nil { fmt.Println("Use --no-email if you want don't want to try sending the welcome email") return errors.New("unable to send welcome email: " + err.Error()) } @@ -46,7 +46,9 @@ func setupInviteCommand() *cobra.Command { }, } cmd.Flags().StringP("name", "n", "", "Name") + // TODO: #flags/#config cmd.Flags().StringP("user_basedir", "b", "", "Base directory for home") + // TODO: #flags/#config cmd.Flags().BoolP("no-create-home", "M", false, "Do not make home directory") cmd.Flags().Bool("no-email", false, "Do not send email") return cmd diff --git a/server/internal/command/admin/user/mod.go b/server/internal/command/admin/user/mod.go index df16854b..a295972f 100644 --- a/server/internal/command/admin/user/mod.go +++ b/server/internal/command/admin/user/mod.go @@ -41,7 +41,7 @@ func setupModCommand() *cobra.Command { } } - err = db.Get(context.Background()).RunInTx(func(db db.Handler) error { + err = db.Get(context.Background()).RunInTx(func(db db.TxHandler) error { if name != "" { if err := core.UpdateUserName(db, u, name); err != nil { return err diff --git a/server/internal/command/admin/user/passwd.go b/server/internal/command/admin/user/passwd.go index aaffddcf..1def5529 100644 --- a/server/internal/command/admin/user/passwd.go +++ b/server/internal/command/admin/user/passwd.go @@ -7,6 +7,7 @@ import ( "syscall" "codeberg.org/shroff/phylum/server/internal/auth" + "codeberg.org/shroff/phylum/server/internal/db" "github.com/spf13/cobra" "golang.org/x/term" ) @@ -47,8 +48,9 @@ func setupPasswdCommand() *cobra.Command { } } - err = auth.UpdateUserPassword(context.Background(), email, password) - if err != nil { + if err := db.Get(context.Background()).RunInTx(func(db db.TxHandler) error { + return auth.UpdateUserPassword(db, email, password) + }); err != nil { fmt.Println("could not change password: " + err.Error()) os.Exit(1) } diff --git a/server/internal/command/admin/user/permissions.go b/server/internal/command/admin/user/permissions.go index 03f4b6fb..41bb1b72 100644 --- a/server/internal/command/admin/user/permissions.go +++ b/server/internal/command/admin/user/permissions.go @@ -18,8 +18,8 @@ func setupGrantCommand() *cobra.Command { Short: "Grant Permissions", Args: cobra.ExactArgs(2), Run: func(cmd *cobra.Command, args []string) { - db := db.Get(context.Background()) - u, err := core.UserByEmail(db, args[0]) + d := db.Get(context.Background()) + u, err := core.UserByEmail(d, args[0]) if err != nil { fmt.Println("unable to find user :" + err.Error()) os.Exit(1) @@ -29,20 +29,21 @@ func setupGrantCommand() *cobra.Command { if strings.HasPrefix(permString, "0x") { var perm int64 perm, err = strconv.ParseInt(permString[2:], 16, 32) - p = int32(perm) + p = core.UserPermissions(perm) } else { var perm int64 perm, err = strconv.ParseInt(permString, 10, 32) - p = int32(perm) + p = core.UserPermissions(perm) } if err != nil { - fmt.Println("unable to parse permissions " + permString) + fmt.Println("failed to parse permission: " + err.Error()) os.Exit(1) } - core.GrantUserPermissions(db, u, p) - if err != nil { - fmt.Println("unable to update permissions:" + err.Error()) + if err := d.RunInTx(func(db db.TxHandler) error { + return core.GrantUserPermissions(db, u, p) + }); err != nil { + fmt.Println("failed to grant permission: " + err.Error()) os.Exit(1) } }, @@ -55,8 +56,8 @@ func setupRevokeCommand() *cobra.Command { Short: "Revoke Permissions", Args: cobra.ExactArgs(2), Run: func(cmd *cobra.Command, args []string) { - db := db.Get(context.Background()) - u, err := core.UserByEmail(db, args[0]) + d := db.Get(context.Background()) + u, err := core.UserByEmail(d, args[0]) if err != nil { fmt.Println("unable to find user :" + err.Error()) os.Exit(1) @@ -66,20 +67,22 @@ func setupRevokeCommand() *cobra.Command { if strings.HasPrefix(permString, "0x") { var perm int64 perm, err = strconv.ParseInt(permString[2:], 16, 32) - p = int32(perm) + p = core.UserPermissions(perm) } else { var perm int64 perm, err = strconv.ParseInt(permString, 10, 32) - p = int32(perm) + p = core.UserPermissions(perm) } if err != nil { - fmt.Println("unable to parse permissions " + permString) + fmt.Println("failed to parse permission: " + err.Error()) os.Exit(1) } - core.RevokeUserPermissions(db, u, p) - if err != nil { - fmt.Println("unable to update permissions:" + err.Error()) + if err := d.RunInTx(func(db db.TxHandler) error { + // TODO: Accept email directly instead of having to separately fetch the user + return core.RevokeUserPermissions(db, u, p) + }); err != nil { + fmt.Println("failed to revoke permission: " + err.Error()) os.Exit(1) } }, diff --git a/server/internal/command/common/common.go b/server/internal/command/common/common.go index 28b5e24d..64da4223 100644 --- a/server/internal/command/common/common.go +++ b/server/internal/command/common/common.go @@ -35,10 +35,11 @@ func User(cmd *cobra.Command) *core.User { func UserFileSystem(cmd *cobra.Command) core.FileSystem { if f == nil { user := User(cmd) + db := db.Get(context.Background()) if user == nil { - f = core.OpenOmniscient(db.Get(context.Background())) + f = core.OpenOmniscient(db) } else { - f = user.OpenFileSystem(context.Background()) + f = user.OpenFileSystem(db) } } return f diff --git a/server/internal/command/user/bookmarks/cmd.go b/server/internal/command/user/bookmarks/cmd.go index 2a1c771a..36736977 100644 --- a/server/internal/command/user/bookmarks/cmd.go +++ b/server/internal/command/user/bookmarks/cmd.go @@ -68,7 +68,9 @@ func setupRemoveCommand() *cobra.Command { os.Exit(1) } - if err := core.RemoveBookmark(db.Get(context.Background()), *u, r.ID()); err != nil { + if err := db.Get(context.Background()).RunInTx(func(db db.TxHandler) error { + return core.RemoveBookmark(db, *u, r.ID()) + }); err != nil { fmt.Println("unable to remove bookmark: " + err.Error()) os.Exit(1) } @@ -98,7 +100,10 @@ func setupAddCommand() *cobra.Command { name = args[1] } - if _, err := core.AddBookmark(db.Get(context.Background()), *u, r, name); err != nil { + if err := db.Get(context.Background()).RunInTx(func(db db.TxHandler) error { + _, err := core.AddBookmark(db, *u, r, name) + return err + }); err != nil { fmt.Println("unable to add bookmark: " + err.Error()) os.Exit(1) } diff --git a/server/internal/core/core.go b/server/internal/core/core.go index 99c08d44..c5979849 100644 --- a/server/internal/core/core.go +++ b/server/internal/core/core.go @@ -18,13 +18,9 @@ func init() { goqu.SetDefaultPrepared(true) } -// filesystem.go type FileSystem interface { - RunInTx(func(FileSystem) error) error - // resource_create.go CreateResourceByPath(path string, id uuid.UUID, dir, createParents bool, conflictResolution ResourceBindConflictResolution) (Resource, error) - CreateResources(args []CreateResourcesParams) (int64, error) // resource_locate.go ResourceByID(uuid.UUID) (Resource, error) diff --git a/server/internal/core/disk_usage.go b/server/internal/core/disk_usage.go index 8afc8982..828f8aaf 100644 --- a/server/internal/core/disk_usage.go +++ b/server/internal/core/disk_usage.go @@ -11,7 +11,7 @@ type DiskUsageInfo struct { Dirs int64 } -func (f filesystem) DiskUsage(r Resource) (DiskUsageInfo, error) { +func (f fileSystem) DiskUsage(r Resource) (DiskUsageInfo, error) { // TODO: #versions This is broken n, q := selectResourceTree(r.id, false, "content_length", "dir") diff --git a/server/internal/core/filesystem.go b/server/internal/core/filesystem.go index c1f8d15f..b2c5a7c8 100644 --- a/server/internal/core/filesystem.go +++ b/server/internal/core/filesystem.go @@ -1,7 +1,6 @@ package core import ( - "context" "errors" "codeberg.org/shroff/phylum/server/internal/db" @@ -11,67 +10,72 @@ import ( "github.com/sirupsen/logrus" ) -type filesystem struct { +type fileSystem struct { db db.Handler user User pathRoot pgtype.UUID } -func OpenFileSystem(ctx context.Context, user User, pathRoot pgtype.UUID) FileSystem { - return filesystem{ - db: db.Get(ctx), +type txFileSystem struct { + fileSystem + db db.TxHandler +} + +func OpenFileSystem(db db.Handler, user User, pathRoot pgtype.UUID) FileSystem { + return fileSystem{ + db: db, user: user, pathRoot: pathRoot, } } -func (u User) OpenFileSystem(ctx context.Context) FileSystem { - return OpenFileSystem(ctx, u, u.Home) +func (u User) OpenFileSystem(db db.Handler) FileSystem { + return OpenFileSystem(db, u, u.Home) } func OpenOmniscient(db db.Handler) FileSystem { return openOmniscient(db) } -func openOmniscient(db db.Handler) filesystem { - return filesystem{ +func openOmniscient(db db.Handler) fileSystem { + return fileSystem{ db: db, user: User{ID: -1, Permissions: -1}, - pathRoot: pgtype.UUID{Bytes: rootID(), Valid: true}, + pathRoot: pgtype.UUID{Bytes: rootID(db), Valid: true}, } } -func (f filesystem) withDb(db db.Handler) filesystem { - return filesystem{ - db: db, - user: f.user, - pathRoot: f.pathRoot, +func openOmniscientTx(db db.TxHandler) txFileSystem { + return txFileSystem{ + fileSystem: fileSystem{ + db: db, + user: User{ID: -1, Permissions: -1}, + pathRoot: pgtype.UUID{Bytes: rootID(db), Valid: true}, + }, + db: db, } } -func (f filesystem) withPathRoot(pathRoot pgtype.UUID) filesystem { - return filesystem{ +func (f fileSystem) runInTx(fn func(f txFileSystem) error) error { + return f.db.RunInTx(func(tx db.TxHandler) error { + return fn(txFileSystem{ + fileSystem: f, + db: tx, + }) + }) +} + +func (f fileSystem) withPathRoot(pathRoot pgtype.UUID) fileSystem { + return fileSystem{ db: f.db, user: f.user, pathRoot: pathRoot, } } -func (f filesystem) RunInTx(fn func(FileSystem) error) error { - return f.db.RunInTx(func(db db.Handler) error { - return fn(f.withDb(db)) - }) -} - -func (f filesystem) runInTx(fn func(filesystem) error) error { - return f.db.RunInTx(func(db db.Handler) error { - return fn(f.withDb(db)) - }) -} - -func rootID() uuid.UUID { +func rootID(db db.Handler) uuid.UUID { if _rootID == uuid.Nil { var err error - _rootID, err = _readRootID(context.Background()) + _rootID, err = _readRootID(db) if err != nil { logrus.Fatal("Could not read root ID: " + err.Error()) } @@ -79,16 +83,15 @@ func rootID() uuid.UUID { return _rootID } -func _readRootID(ctx context.Context) (uuid.UUID, error) { +func _readRootID(d db.Handler) (uuid.UUID, error) { const q = "SELECT id FROM resources WHERE parent IS NULL" - d := db.Get(ctx) row := d.QueryRow(q) var id uuid.UUID if err := row.Scan(&id); err != nil { if errors.Is(err, pgx.ErrNoRows) { const createDir = "INSERT INTO resources(id, name, dir) VALUES ($1::UUID, '', TRUE)" id, _ := uuid.NewV7() - _, err = d.Exec(createDir, id) + _, err = d.ExecNoTx(createDir, id) return id, err } return uuid.Nil, err diff --git a/server/internal/core/filesystem_readonly.go b/server/internal/core/filesystem_readonly.go index 80adc204..f6fce505 100644 --- a/server/internal/core/filesystem_readonly.go +++ b/server/internal/core/filesystem_readonly.go @@ -13,7 +13,7 @@ import ( ) type proxyFileSystemReadOnly struct { - f filesystem + f fileSystem } func (f proxyFileSystemReadOnly) ResourceByPath(path string) (Resource, error) { @@ -61,10 +61,14 @@ func OpenFileSystemFromPublink(ctx context.Context, id string, password string) } const q = "UPDATE publinks SET accessed = accessed + 1 WHERE id = $1" - if _, err := d.Exec(q, link.ID); err != nil { + if err := d.RunInTx(func(db db.TxHandler) error { + _, err := db.Exec(q, link.ID) + return err + }); err != nil { return nil, err } + // TODO: #redundant // TODO: #do not use omniscient return proxyFileSystemReadOnly{f: openOmniscient(d).withPathRoot(pgtype.UUID{Bytes: link.Root, Valid: true})}, nil } diff --git a/server/internal/core/resource_ancestors.go b/server/internal/core/resource_ancestors.go index 74b72d76..1b5beefa 100644 --- a/server/internal/core/resource_ancestors.go +++ b/server/internal/core/resource_ancestors.go @@ -13,7 +13,7 @@ type ResourceAncestor struct { UserPermission Permission } -func (f filesystem) scanResourceAncestor(row pgx.CollectableRow) (ResourceAncestor, error) { +func (f fileSystem) scanResourceAncestor(row pgx.CollectableRow) (ResourceAncestor, error) { var a ResourceAncestor err := row.Scan( &a.ID, @@ -36,7 +36,7 @@ const ancestorsQuery = `WITH RECURSIVE nodes(id, name, parent, userPermission) A ) SELECT id, name, userPermission FROM nodes` -func (f filesystem) GetAncestors(r Resource) ([]ResourceAncestor, error) { +func (f fileSystem) GetAncestors(r Resource) ([]ResourceAncestor, error) { if rows, err := f.db.Query(ancestorsQuery, r.id, f.user.ID, f.user.Permissions&PermissionFilesAll != 0); err != nil { return nil, err } else if a, err := pgx.CollectRows(rows, f.scanResourceAncestor); err != nil { @@ -46,7 +46,7 @@ func (f filesystem) GetAncestors(r Resource) ([]ResourceAncestor, error) { } } -func (f filesystem) GetPath(r Resource) (string, error) { +func (f fileSystem) GetPath(r Resource) (string, error) { if a, err := f.GetAncestors(r); err != nil { return "", err } else { diff --git a/server/internal/core/resource_copy_move.go b/server/internal/core/resource_copy_move.go index 965d946c..663ed0f4 100644 --- a/server/internal/core/resource_copy_move.go +++ b/server/internal/core/resource_copy_move.go @@ -4,13 +4,25 @@ import ( "errors" "strings" + "codeberg.org/shroff/phylum/server/internal/db" "github.com/doug-martin/goqu/v9" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgtype" "github.com/sirupsen/logrus" ) -func (f filesystem) Move(r Resource, target string, conflictResolution ResourceBindConflictResolution) (Resource, bool, error) { +func (f fileSystem) Move(r Resource, target string, conflictResolution ResourceBindConflictResolution) (Resource, bool, error) { + var res Resource + var deleted bool + var err error + err = f.runInTx(func(f txFileSystem) error { + res, deleted, err = f.Move(r, target, conflictResolution) + return err + }) + return res, deleted, err +} + +func (f txFileSystem) Move(r Resource, target string, conflictResolution ResourceBindConflictResolution) (Resource, bool, error) { // Check source directory permissions if r.deleted.Valid { return Resource{}, false, ErrResourceDeleted @@ -58,43 +70,41 @@ func (f filesystem) Move(r Resource, target string, conflictResolution ResourceB var res Resource var deleted = false - return res, deleted, f.runInTx(func(f filesystem) error { - if conflictResolution == ResourceBindConflictResolutionOverwrite || conflictResolution == ResourceBindConflictResolutionDelete { - if id, _, err := f.childResourceIDByName(r.ID(), destName); err != nil { - if !errors.Is(err, ErrResourceNotFound) { - return err - } - } else if err := softDelete(f.db, id); err != nil { - return err - } else { - deleted = true + if conflictResolution == ResourceBindConflictResolutionOverwrite || conflictResolution == ResourceBindConflictResolutionDelete { + if id, _, err := childResourceIDByName(f.db, r.ID(), destName); err != nil { + if !errors.Is(err, ErrResourceNotFound) { + return res, deleted, err } - } - newParentID := pgtype.UUID{ - Bytes: destParent.id, - Valid: true, - } - if r.parentID.Bytes == destParent.id { - newParentID = pgtype.UUID{} - } - if err := f.updateResourceNameParent(r.id, destName, newParentID); err != nil { - return err + } else if err := softDelete(f.db, id); err != nil { + return res, deleted, err } else { - res = r - res.name = destName - if newParentID.Valid { - if err := f.recomputePermissions(r.id); err != nil { - return err - } - res.parentID = newParentID - res.visibleParent = newParentID - } - return nil + deleted = true } - }) + } + newParentID := pgtype.UUID{ + Bytes: destParent.id, + Valid: true, + } + if r.parentID.Bytes == destParent.id { + newParentID = pgtype.UUID{} + } + if err := updateResourceNameParent(f.db, r.id, destName, newParentID); err != nil { + return res, deleted, err + } else { + res = r + res.name = destName + if newParentID.Valid { + if err := recomputePermissions(f.db, r.id); err != nil { + return res, deleted, err + } + res.parentID = newParentID + res.visibleParent = newParentID + } + return res, deleted, nil + } } -func (f filesystem) updateResourceNameParent(id uuid.UUID, name string, parent pgtype.UUID) error { +func updateResourceNameParent(db db.TxHandler, id uuid.UUID, name string, parent pgtype.UUID) error { updates := goqu.Record{ "modified": goqu.L("NOW()"), } @@ -106,7 +116,7 @@ func (f filesystem) updateResourceNameParent(id uuid.UUID, name string, parent p } q, args, _ := pg.Update("resources").Where(goqu.C("id").Eq(id)).Set(updates).ToSQL() - if _, err := f.db.Exec(q, args...); err != nil { + if _, err := db.Exec(q, args...); err != nil { if strings.Contains(err.Error(), "unique_member_resource_name") { return ErrResourceNameConflict } @@ -115,7 +125,18 @@ func (f filesystem) updateResourceNameParent(id uuid.UUID, name string, parent p return nil } -func (f filesystem) Copy(r Resource, target string, id uuid.UUID, recursive bool, conflictResolution ResourceBindConflictResolution) (Resource, bool, error) { +func (f fileSystem) Copy(r Resource, target string, id uuid.UUID, recursive bool, conflictResolution ResourceBindConflictResolution) (Resource, bool, error) { + var res Resource + var deleted bool + var err error + err = f.runInTx(func(f txFileSystem) error { + res, deleted, err = f.Copy(r, target, id, recursive, conflictResolution) + return err + }) + return res, deleted, err +} + +func (f txFileSystem) Copy(r Resource, target string, id uuid.UUID, recursive bool, conflictResolution ResourceBindConflictResolution) (Resource, bool, error) { // Check source directory permissions if err := r.checkPermission(f.user, PermissionWrite); err != nil { return Resource{}, false, err @@ -152,87 +173,81 @@ func (f filesystem) Copy(r Resource, target string, id uuid.UUID, recursive bool var contents []copyParams newIDs := make(map[uuid.UUID]uuid.UUID) - var targetRoot Resource - created := false - deleted := false - err = f.runInTx(func(f filesystem) error { - targetRoot, created, deleted, err = f.createResource( - id, - destParent.id, - destName, - r.dir, - destParent.permissions, - conflictResolution, - ) - // createResource may return an already existing resources, depending on the specified conflictResolution - id = targetRoot.id - if err == nil && r.id == id { - err = ErrResourceCopyTargetSelf - } - if err != nil { - return err - } - - if targetRoot.dir { - newIDs[r.id] = id - } else { - contents = append(contents, copyParams{ - src: r.latestVersionInfo, - destResource: id, - }) - } - - for _, src := range tree { - id, _ := uuid.NewV7() - parent := newIDs[src.parentID.Bytes] - - children = append(children, CreateResourcesParams{ - ID: id, - Parent: parent, - Name: src.name, - Dir: src.dir, - }) - - if src.dir { - newIDs[src.id] = id - } else { - contents = append(contents, copyParams{ - src: src.latestVersionInfo, - destResource: id, - }) - } - } - - if _, err := f.CreateResources(children); err != nil { - return err - } - if err := f.recomputePermissions(id); err != nil { - return err - } - if created { - return f.updateResourceModified(destParent.ID()) - } - return nil - }) - - if err == nil { - func() { - for _, c := range contents { - if err := f.copyContents(c); err != nil { - logrus.Warn("unable to copy " + c.src.ID.String() + " to " + c.destResource.String() + ": " + err.Error()) - } - } - - }() - } else { + targetRoot, created, deleted, err := createResource( + f.db, + id, + destParent.id, + destName, + r.dir, + destParent.permissions, + conflictResolution, + ) + // createResource may return an already existing resources, depending on the specified conflictResolution + id = targetRoot.id + if err == nil && r.id == id { + err = ErrResourceCopyTargetSelf + } + if err != nil { return Resource{}, false, err } - // TODO: #verify Shouldn't be necessary - targetRoot.visibleParent = pgtype.UUID{ - Bytes: destParent.id, - Valid: true, + if targetRoot.dir { + newIDs[r.id] = id + } else { + contents = append(contents, copyParams{ + src: r.latestVersionInfo, + destResource: id, + }) } + + for _, src := range tree { + id, _ := uuid.NewV7() + parent := newIDs[src.parentID.Bytes] + + children = append(children, CreateResourcesParams{ + ID: id, + Parent: parent, + Name: src.name, + Dir: src.dir, + }) + + if src.dir { + newIDs[src.id] = id + } else { + contents = append(contents, copyParams{ + src: src.latestVersionInfo, + destResource: id, + }) + } + } + + if _, err := createResources(f.db, children); err != nil { + return Resource{}, false, err + } + if err := recomputePermissions(f.db, id); err != nil { + return Resource{}, false, err + } + if created { + if err := updateResourceModified(f.db, destParent.ID()); err != nil { + return Resource{}, false, err + } + } + + func() { + // TODO: #jobs + for _, c := range contents { + if err := f.copyContents(c); err != nil { + logrus.Warn("unable to copy " + c.src.ID.String() + " to " + c.destResource.String() + ": " + err.Error()) + } + } + + }() + + // TODO: #verify this shouldn't be necessary + // targetRoot.visibleParent = pgtype.UUID{ + // Bytes: destParent.id, + // Valid: true, + // } return targetRoot, deleted, err } @@ -242,10 +257,10 @@ type copyParams struct { } // TODO: #implement copyContents -func (f filesystem) copyContents(params copyParams) error { +func (f txFileSystem) copyContents(params copyParams) error { versionID, _ := uuid.NewV7() - if err := f.createResourceVersion(params.destResource, versionID, params.src.Size, params.src.MimeType, params.src.SHA256); err != nil { + if err := createResourceVersion(f.db, params.destResource, versionID, params.src.Size, params.src.MimeType, params.src.SHA256); err != nil { return errors.New("failed to create version for " + params.destResource.String() + ": " + err.Error()) } diff --git a/server/internal/core/resource_create.go b/server/internal/core/resource_create.go index 97a8489d..f81ce039 100644 --- a/server/internal/core/resource_create.go +++ b/server/internal/core/resource_create.go @@ -2,15 +2,15 @@ package core import ( "errors" - "fmt" - "path" "strings" + "codeberg.org/shroff/phylum/server/internal/db" "codeberg.org/shroff/phylum/server/internal/storage" "github.com/doug-martin/goqu/v9" "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/sirupsen/logrus" ) type ResourceBindConflictResolution int32 @@ -29,7 +29,17 @@ func CheckResourceNameInvalid(s string) bool { }) } -func (f filesystem) CreateResourceByPath(path string, id uuid.UUID, dir, createParents bool, conflictResolution ResourceBindConflictResolution) (Resource, error) { +func (f fileSystem) CreateResourceByPath(path string, id uuid.UUID, dir, createParents bool, conflictResolution ResourceBindConflictResolution) (Resource, error) { + var res Resource + err := f.runInTx(func(f txFileSystem) error { + var err error + res, err = f.CreateResourceByPath(path, id, dir, createParents, conflictResolution) + return err + }) + return res, err +} + +func (f txFileSystem) CreateResourceByPath(path string, id uuid.UUID, dir, createParents bool, conflictResolution ResourceBindConflictResolution) (Resource, error) { if id == uuid.Nil { id, _ = uuid.NewV7() } @@ -50,7 +60,7 @@ func (f filesystem) CreateResourceByPath(path string, id uuid.UUID, dir, createP return Resource{}, ErrResourcePathInvalid } if root.Valid { - f = f.withPathRoot(root) + f.fileSystem = f.fileSystem.withPathRoot(root) } segments := strings.Split(strings.TrimRight(strings.TrimLeft(path, "/"), "/"), "/") r, err := f.ResourceByID(f.pathRoot.Bytes) @@ -72,7 +82,7 @@ func (f filesystem) CreateResourceByPath(path string, id uuid.UUID, dir, createP return r, err } -func (f filesystem) createMemberResource(r Resource, name string, id uuid.UUID, dir bool, conflictResolution ResourceBindConflictResolution) (Resource, error) { +func (f txFileSystem) createMemberResource(r Resource, name string, id uuid.UUID, dir bool, conflictResolution ResourceBindConflictResolution) (Resource, error) { if r.deleted.Valid { return Resource{}, ErrResourceDeleted } @@ -90,32 +100,26 @@ func (f filesystem) createMemberResource(r Resource, name string, id uuid.UUID, } var res Resource var created bool - err := f.runInTx(func(f filesystem) error { - var err error - if res, created, _, err = f.createResource(id, r.id, name, dir, r.permissions, conflictResolution); err != nil { - if strings.Contains(err.Error(), "unique_member_resource_name") { - return ErrResourceNameConflict - } - return err - } else if created { - if err := f.recomputePermissions(id); err != nil { - return err - } - return f.updateResourceModified(r.id) + var err error + if res, created, _, err = createResource(f.db, id, r.id, name, dir, r.permissions, conflictResolution); err != nil { + if errors.Is(err, ErrResourceIDConflict) { + return resourceByID(f.db, id, f.user.ID) } - return nil - }) - if err == ErrResourceIDConflict { - return f.ResourceByID(id) - } - if err != nil { return Resource{}, err + } else if created { + if err := recomputePermissions(f.db, id); err != nil { + return Resource{}, err + } + if err := updateResourceModified(f.db, r.id); err != nil { + return Resource{}, err + } } return res, nil } -func (f filesystem) createResource( +func createResource( + db db.TxHandler, id uuid.UUID, parent uuid.UUID, name string, @@ -123,79 +127,62 @@ func (f filesystem) createResource( permissions []byte, conflictResolution ResourceBindConflictResolution, ) (res Resource, created, deleted bool, err error) { - err = f.runInTx(func(f filesystem) error { - res, err = f.insertResource( + if name, err = detectNameConflict(db, parent, name, conflictResolution == ResourceBindConflictResolutionRename); err != nil { + // Name conflicts will be handled outside of this if-block + if !errors.Is(err, ErrResourceNameConflict) { + return + } + } else { + // No name conflict. Just insert and move along + res, err = insertResource( + db, id, parent, name, dir, permissions, ) - return err - }) - if err == nil { created = true + // maybe the request already succeeded in the previous attempt but the client didn't receive the response? + if strings.Contains(err.Error(), "resources_pkey") { + err = ErrResourceIDConflict + } return } - if strings.Contains(err.Error(), "unique_member_resource_name") { - switch conflictResolution { - case ResourceBindConflictResolutionError: + + switch conflictResolution { + case ResourceBindConflictResolutionError: + err = ErrResourceNameConflict + case ResourceBindConflictResolutionEnsure: + var rDir bool + _, rDir, err = childResourceIDByName(db, parent, name) + if err == nil && rDir != dir { err = ErrResourceNameConflict - case ResourceBindConflictResolutionEnsure: - var rDir bool - _, rDir, err = f.childResourceIDByName(parent, name) - if err == nil && rDir != dir { - err = ErrResourceNameConflict - } - case ResourceBindConflictResolutionRename: - if name, err = f.detectNameConflict(parent, name, true); err != nil { - return - } else { - res, err = f.insertResource( - id, - parent, - name, - dir, - permissions, - ) - return - } - case ResourceBindConflictResolutionOverwrite: - var rID uuid.UUID - var rDir bool - rID, rDir, err = f.childResourceIDByName(parent, name) - if err == nil { - deleted = true - if rDir == dir { - if dir { - err = f.softDeleteChildren(rID, parent) - } - if err == nil { - // Repurpose existing resource - res, err = f.ResourceByID(rID) - } - } else { - err = softDelete(f.db, res.id) - if err == nil { - res, created, _, err = f.createResource( - id, - parent, - name, - dir, - permissions, - ResourceBindConflictResolutionError, - ) - } + } + case ResourceBindConflictResolutionRename: + logrus.Warn("Rename case reached?!") + // This case is should already be handled above + case ResourceBindConflictResolutionOverwrite: + var rID uuid.UUID + var rDir bool + rID, rDir, err = childResourceIDByName(db, parent, name) + if err == nil { + deleted = true + if rDir == dir { + if dir { + err = softDeleteChildren(db, rID, parent) } - } - case ResourceBindConflictResolutionDelete: - var rID uuid.UUID - rID, _, err = f.childResourceIDByName(parent, name) - if err == nil { - deleted = true - err = softDelete(f.db, rID) if err == nil { - res, created, _, err = f.createResource( + // Repurpose existing resource + res, err = resourceByID(db, rID, -1) + // This is set from the query using the user id, which we passed in as '-1' above. + res.visibleParent = pgtype.UUID{Bytes: parent, Valid: true} + } + } else { + err = softDelete(db, res.id) + if err == nil { + res, created, _, err = createResource( + db, id, parent, name, @@ -206,14 +193,29 @@ func (f filesystem) createResource( } } } - } else if strings.Contains(err.Error(), "resources_pkey") { - // TODO: maybe the request already succeeded in the previous attempt but the client didn't receive the response? - err = ErrResourceIDConflict + case ResourceBindConflictResolutionDelete: + var rID uuid.UUID + rID, _, err = childResourceIDByName(db, parent, name) + if err == nil { + deleted = true + err = softDelete(db, rID) + if err == nil { + res, created, _, err = createResource( + db, + id, + parent, + name, + dir, + permissions, + ResourceBindConflictResolutionError, + ) + } + } } return } -func (f filesystem) insertResource(id, parent uuid.UUID, name string, dir bool, permissions []byte) (Resource, error) { +func insertResource(db db.TxHandler, id, parent uuid.UUID, name string, dir bool, permissions []byte) (Resource, error) { query, args, _ := pg.From("resources"). Insert(). Rows(goqu.Record{ @@ -231,7 +233,7 @@ func (f filesystem) insertResource(id, parent uuid.UUID, name string, dir bool, goqu.L("'{}'::JSONB"), // inherited permissions ). ToSQL() - if rows, err := f.db.Query(query, args...); err != nil { + if rows, err := db.Query(query, args...); err != nil { return Resource{}, err } else { r, err := collectFullResource(rows) @@ -243,13 +245,13 @@ func (f filesystem) insertResource(id, parent uuid.UUID, name string, dir bool, } } -func (f filesystem) updateResourceModified(id uuid.UUID) error { +func updateResourceModified(db db.TxHandler, id uuid.UUID) error { const q = "UPDATE resources SET modified = NOW() WHERE id = $1" - _, err := f.db.Exec(q, id) + _, err := db.Exec(q, id) return err } -func (f filesystem) createResourceVersion(id, versionID uuid.UUID, size int64, mimeType, sha256 string) error { +func createResourceVersion(db db.TxHandler, id, versionID uuid.UUID, size int64, mimeType, sha256 string) error { const q = `INSERT INTO resource_versions(id, resource_id, size, mime_type, sha256, storage) VALUES (@version_id::UUID, @resource_id::UUID, @size::INT, @mime_type::TEXT, @sha256::TEXT, @storage::TEXT)` @@ -261,39 +263,13 @@ func (f filesystem) createResourceVersion(id, versionID uuid.UUID, size int64, m "sha256": sha256, "storage": storage.DefaultBackendName, } - _, err := f.db.Exec(q, args) + _, err := db.Exec(q, args) return err } -func (f filesystem) detectNameConflict(parentID uuid.UUID, name string, autoRename bool) (string, error) { - if _, _, err := f.childResourceIDByName(parentID, name); err != nil { - // No name conflict. Good to go! - if errors.Is(err, ErrResourceNotFound) { - return name, nil - } - return "", err - } else if !autoRename { - return "", ErrResourceNameConflict - } - - ext := path.Ext(name) - basename := name[:len(name)-len(ext)] - counter := 1 - for { - name = fmt.Sprintf("%s (%d)%s", basename, counter, ext) - if _, _, err := f.childResourceIDByName(parentID, name); err == nil { - counter++ - } else if errors.Is(err, ErrResourceNotFound) { - return name, nil - } else { - return "", err - } - } -} - // TODO: Make not public -func (f filesystem) CreateResources(arg []CreateResourcesParams) (int64, error) { - return f.db.CopyFrom([]string{"resources"}, []string{"id", "parent", "name", "dir"}, &iteratorForCreateResources{rows: arg}) +func createResources(db db.TxHandler, arg []CreateResourcesParams) (int64, error) { + return db.CopyFrom([]string{"resources"}, []string{"id", "parent", "name", "dir"}, &iteratorForCreateResources{rows: arg}) } // For bulk insert diff --git a/server/internal/core/resource_delete.go b/server/internal/core/resource_delete.go index c5b0b46c..c81b53e3 100644 --- a/server/internal/core/resource_delete.go +++ b/server/internal/core/resource_delete.go @@ -11,7 +11,17 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -func (f filesystem) Delete(r Resource) (Resource, error) { +func (f fileSystem) Delete(r Resource) (Resource, error) { + var res Resource + err := f.runInTx(func(f txFileSystem) error { + var err error + res, err = f.Delete(r) + return err + }) + return res, err +} + +func (f txFileSystem) Delete(r Resource) (Resource, error) { if !r.parentID.Valid { return Resource{}, ErrInsufficientPermissions } @@ -26,14 +36,10 @@ func (f filesystem) Delete(r Resource) (Resource, error) { return Resource{}, err } - err = f.runInTx(func(f filesystem) error { - if err := softDelete(f.db, r.id); err != nil { - return err - } - - return f.updateResourceModified(r.parentID.Bytes) - }) - if err != nil { + if err := softDelete(f.db, r.id); err != nil { + return Resource{}, err + } + if err := updateResourceModified(f.db, r.parentID.Bytes); err != nil { return Resource{}, err } @@ -44,75 +50,75 @@ func (f filesystem) Delete(r Resource) (Resource, error) { return r, nil } -func softDelete(d db.Handler, id uuid.UUID) error { - return d.RunInTx(func(db db.Handler) error { - n, q := selectResourceTree(id, false) - r := goqu.T("resources") +func softDelete(db db.TxHandler, id uuid.UUID) error { + n, q := selectResourceTree(id, false) + r := goqu.T("resources") - // Set modified and deleted - query, params, _ := q. - From(r). - Where(r.Col("id").Eq(pg.From(n).Select("id"))). - Update(). - Set( - goqu.Record{ - "modified": goqu.L("NOW()"), - "deleted": goqu.L("NOW()"), - }). - ToSQL() + // Set modified and deleted + query, params, _ := q. + From(r). + Where(r.Col("id").Eq(pg.From(n).Select("id"))). + Update(). + Set( + goqu.Record{ + "modified": goqu.L("NOW()"), + "deleted": goqu.L("NOW()"), + }). + ToSQL() - if _, err := db.Exec(query, params...); err != nil { - return err - } - - // Add to trash - query, params, _ = pg.Insert(goqu.T("trash")).Cols("id").Vals(goqu.Vals{id}).ToSQL() - _, err := db.Exec(query, params...) + if _, err := db.Exec(query, params...); err != nil { return err - }) -} - -func (f filesystem) softDeleteChildren(id, parent uuid.UUID) error { - err := f.runInTx(func(f filesystem) error { - n, s := selectResourceTree(id, false) - r := goqu.T("resources") - - // Mark deleted - q, params, _ := s. - From(r). - Where(r.Col("id").Eq(pg.From(n).Select("id"))). - Where(r.Col("id").Neq(id)). - Update(). - Set( - goqu.Record{ - "modified": goqu.L("NOW()"), - "deleted": goqu.L("NOW()"), - }). - ToSQL() - if _, err := f.db.Exec(q, params...); err != nil { - return err - } - - // Add children to trash - insert := pg. - Insert(goqu.T("trash")). - Cols("id"). - FromQuery(pg. - From("resources"). - Select("id"). - Where(goqu.C("parent").Eq(id))) - q, args, _ := insert.ToSQL() - if _, err := f.db.Exec(q, args...); err != nil { - return err - } - - return f.updateResourceModified(parent) - }) + } + // Add to trash + query, params, _ = pg.Insert(goqu.T("trash")).Cols("id").Vals(goqu.Vals{id}).ToSQL() + _, err := db.Exec(query, params...) return err } -func (f filesystem) DeleteForever(r Resource) error { +func softDeleteChildren(db db.TxHandler, id, parent uuid.UUID) error { + n, s := selectResourceTree(id, false) + r := goqu.T("resources") + + // Mark deleted + q, params, _ := s. + From(r). + Where(r.Col("id").Eq(pg.From(n).Select("id"))). + Where(r.Col("id").Neq(id)). + Update(). + Set( + goqu.Record{ + "modified": goqu.L("NOW()"), + "deleted": goqu.L("NOW()"), + }). + ToSQL() + if _, err := db.Exec(q, params...); err != nil { + return err + } + + // Add children to trash + insert := pg. + Insert(goqu.T("trash")). + Cols("id"). + FromQuery(pg. + From("resources"). + Select("id"). + Where(goqu.C("parent").Eq(id))) + q, args, _ := insert.ToSQL() + if _, err := db.Exec(q, args...); err != nil { + return err + } + + return updateResourceModified(db, parent) +} + +func (f fileSystem) DeleteForever(r Resource) error { + return f.runInTx(func(f txFileSystem) error { + return f.DeleteForever(r) + }) +} + +func (f txFileSystem) DeleteForever(r Resource) error { if !r.parentID.Valid { return ErrInsufficientPermissions } @@ -123,18 +129,16 @@ func (f filesystem) DeleteForever(r Resource) error { if err != nil { return err } - return f.runInTx(func(f filesystem) error { - // Select all descendants, including deleted resources - n, q := selectResourceTree(r.id, true) + // Select all descendants, including deleted resources + n, q := selectResourceTree(r.id, true) - if err := f.updateResourceModified(parent.id); err != nil { - return err - // deleteAllVersions needs to be called last, as it will enqueue the delete jobs - } else if err := hardDeleteAllVersions(f.db, q, n); err != nil { - return err - } - return nil - }) + if err := updateResourceModified(f.db, parent.id); err != nil { + return err + // deleteAllVersions needs to be called last, as it will enqueue the delete jobs + } else if err := hardDeleteAllVersions(f.db, q, n); err != nil { + return err + } + return nil } func collectDeletedVersions(rows pgx.Rows) ([]jobs.DeleteContentsArgs, error) { @@ -162,12 +166,21 @@ func collectDeletedVersions(rows pgx.Rows) ([]jobs.DeleteContentsArgs, error) { return result, nil } +func (f fileSystem) RestoreDeleted(r Resource, parentPathOrUUID string, name string, autoRename bool) (res Resource, err error) { + err = f.runInTx(func(f txFileSystem) error { + var err error + res, err = f.RestoreDeleted(r, parentPathOrUUID, name, autoRename) + return err + }) + return res, err +} + // RestoreDeleted restores a previously deleted resources // Checks: // - Parent must not be deleted // - Parent must have write permission // - No name conflict with exiting resource -func (f filesystem) RestoreDeleted(r Resource, parentPathOrUUID string, name string, autoRename bool) (res Resource, err error) { +func (f txFileSystem) RestoreDeleted(r Resource, parentPathOrUUID string, name string, autoRename bool) (res Resource, err error) { // Locate parent var parent Resource if parentPathOrUUID == "" { @@ -199,43 +212,43 @@ func (f filesystem) RestoreDeleted(r Resource, parentPathOrUUID string, name str if name == "" { name = r.name } - if name, err = f.detectNameConflict(parent.id, name, autoRename); err != nil { + if name, err = detectNameConflict(f.db, parent.id, name, autoRename); err != nil { return } - id := r.id - err = f.runInTx(func(f filesystem) error { - q, args, _ := pg.Delete(goqu.T("trash")).Where(goqu.C("id").Eq(r.id)).ToSQL() - if _, err := f.db.Exec(q, args...); err != nil { - return err - } + q, args, _ := pg.Delete(goqu.T("trash")).Where(goqu.C("id").Eq(r.id)).ToSQL() + if _, err = f.db.Exec(q, args...); err != nil { + return + } - if parent.id != r.parentID.Bytes || r.name != name { - if err := f.updateResourceNameParent(id, name, pgtype.UUID{Bytes: parent.id, Valid: true}); err != nil { - return err - } else { - r.name = name - r.parentID = pgtype.UUID{Bytes: parent.id, Valid: true} - r.visibleParent = r.parentID - } + if parent.id != r.parentID.Bytes || r.name != name { + if err = updateResourceNameParent(f.db, r.id, name, pgtype.UUID{Bytes: parent.id, Valid: true}); err != nil { + return + } else { + r.name = name + r.parentID = pgtype.UUID{Bytes: parent.id, Valid: true} + r.visibleParent = r.parentID } - n, s := selectResourceTree(id, false) - r := goqu.T("resources") - query, params, _ := s. - From(r). - Where(r.Col("id").Eq(pg.From(n).Select("id"))). - Update().Set( - goqu.Record{ - "modified": goqu.L("NOW()"), - "deleted": nil, - }).ToSQL() + } + n, s := selectResourceTree(r.id, false) + tR := goqu.T("resources") + query, params, _ := s. + From(r). + Where(tR.Col("id").Eq(pg.From(n).Select("id"))). + Update().Set( + goqu.Record{ + "modified": goqu.L("NOW()"), + "deleted": nil, + }).ToSQL() - if _, err := f.db.Exec(query, params...); err != nil { - return err - } + if _, err = f.db.Exec(query, params...); err != nil { + return + } + + if err = recomputePermissions(f.db, r.id); err != nil { + return + } - return f.recomputePermissions(id) - }) r.deleted = pgtype.Timestamp{} res = r return diff --git a/server/internal/core/resource_locate.go b/server/internal/core/resource_locate.go index 6edf9e1d..d5079103 100644 --- a/server/internal/core/resource_locate.go +++ b/server/internal/core/resource_locate.go @@ -1,42 +1,46 @@ package core import ( + "errors" + "fmt" + "path" "strings" + "codeberg.org/shroff/phylum/server/internal/db" "github.com/doug-martin/goqu/v9" "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" ) -func (f filesystem) ResourceByID(id uuid.UUID) (Resource, error) { - query := fullResourceQuery + "\nWHERE r.id = @id::UUID" - if !f.user.hasPermission(PermissionFilesAll) { - query = query + "\nAND r.permissions[@user_id::INT]::INTEGER <> 0" +func (f fileSystem) ResourceByID(id uuid.UUID) (Resource, error) { + if r, err := resourceByID(f.db, id, f.user.ID); err != nil { + return Resource{}, err + } else if err := r.checkPermission(f.user, PermissionRead); err != nil { + return Resource{}, err + } else { + return r, nil } +} + +func resourceByID(db db.Handler, id uuid.UUID, userID int32) (Resource, error) { + query := fullResourceQuery + "\nWHERE r.id = @id::UUID" args := pgx.NamedArgs{ - "user_id": f.user.ID, + "user_id": userID, "id": id, } - if rows, err := f.db.Query(query, args); err != nil { + if rows, err := db.Query(query, args); err != nil { return Resource{}, err } else { return collectFullResource(rows) } - // } else if r, err := collectFullResource(rows); err != nil { - // return Resource{}, err - // } else if err := r.checkPermission(f.user, PermissionRead); err != nil { - // return Resource{}, err - // } else { - // return r, nil - // } } // ResourceByPathWithRoot returns the resource at a given path from an optionally specified root. // The ": prefix can be used to specify the path root. // Will default to using the filesystem's current path root if one is not specified. // An empty path or "/" will return the root resource. -func (f filesystem) ResourceByPathWithRoot(path string) (Resource, error) { +func (f fileSystem) ResourceByPathWithRoot(path string) (Resource, error) { id, path, err := parseUUIDPrefix(path) if err != nil { return Resource{}, ErrResourceNotFound @@ -47,12 +51,16 @@ func (f filesystem) ResourceByPathWithRoot(path string) (Resource, error) { return f.ResourceByPath(path) } -// ResourceByPath returns the resource at a given path from this filesystem's path root -// An empty path or "/" will return the root resource. -func (f filesystem) ResourceByPath(path string) (Resource, error) { +func (f fileSystem) ResourceByPath(path string) (Resource, error) { if !f.pathRoot.Valid { return Resource{}, ErrResourceNotFound } + return resourceByPath(f.db, f.pathRoot.Bytes, path, f.user) +} + +// ResourceByPath returns the resource at a given path from this filesystem's path root +// An empty path or "/" will return the root resource. +func resourceByPath(db db.Handler, root uuid.UUID, path string, user User) (Resource, error) { nodes := goqu.T("nodes").As("n") r := goqu.T("resources").As("r") p := goqu.T("resources").As("p") @@ -68,7 +76,7 @@ func (f filesystem) ResourceByPath(path string) (Resource, error) { rec := pg. Select(r.Col("id"), r.Col("parent"), goqu.L("array_remove(string_to_array(?::TEXT, '/', NULL), '')", path), goqu.L("0")). From(r). - Where(r.Col("id").Eq(goqu.V(f.pathRoot))). + Where(r.Col("id").Eq(goqu.V(root))). UnionAll(sub) l := goqu.T("publinks").As("l") @@ -76,7 +84,7 @@ func (f filesystem) ResourceByPath(path string) (Resource, error) { q := pg.Select(r.All(), pg.Select(goqu.L(versionsQuery)).From(v).Where(v.Col("resource_id").Eq(r.Col("id"))), pg.Select(goqu.L(publinksQuery)).From(l).Where(l.Col("root").Eq(r.Col("id"))), - pg.Select(goqu.L("CASE WHEN COALESCE(p.permissions[?::INT]::INTEGER, 0) <> 0 THEN p.id ELSE NULL END AS visible_parent", f.user.ID)), + pg.Select(goqu.L("CASE WHEN COALESCE(p.permissions[?::INT]::INTEGER, 0) <> 0 THEN p.id ELSE NULL END AS visible_parent", user.ID)), pg.Select(goqu.L("COALESCE(p.permissions, '{}'::JSONB)")), ). From(r). @@ -85,13 +93,13 @@ func (f filesystem) ResourceByPath(path string) (Resource, error) { Join(nodes, goqu.On(r.Col("id").Eq(nodes.Col("id")))). Where(goqu.L("cardinality(n.search) = n.depth")) - if !f.user.hasPermission(PermissionFilesAll) { - q = q.Where(goqu.L("r.permissions[?::INT]::INTEGER <> 0", f.user.ID)) + if !user.hasPermission(PermissionFilesAll) { + q = q.Where(goqu.L("r.permissions[?::INT]::INTEGER <> 0", user.ID)) } query, args, _ := q.ToSQL() - if rows, err := f.db.Query(query, args...); err != nil { + if rows, err := db.Query(query, args...); err != nil { return Resource{}, err } else { return collectFullResource(rows) @@ -111,7 +119,7 @@ func (f filesystem) ResourceByPath(path string) (Resource, error) { // If no uuid prefix is supplied and the path begins with '/' then r.f is used as the path root // Splits the path to extract its last component as the name and traverses the rest of the path from the root as the parent // If no name is specified then return r.name as the name -func (f filesystem) targetNameParentByPathWithRoot(path string, src Resource) (string, Resource, error) { +func (f fileSystem) targetNameParentByPathWithRoot(path string, src Resource) (string, Resource, error) { id, path, err := parseUUIDPrefix(path) if err != nil { return "", Resource{}, err @@ -141,13 +149,39 @@ func (f filesystem) targetNameParentByPathWithRoot(path string, src Resource) (s return name, parent, nil } -func (f filesystem) childResourceIDByName(parentID uuid.UUID, name string) (uuid.UUID, bool, error) { +func detectNameConflict(db db.Handler, parentID uuid.UUID, name string, autoRename bool) (string, error) { + if _, _, err := childResourceIDByName(db, parentID, name); err != nil { + // No name conflict. Good to go! + if errors.Is(err, ErrResourceNotFound) { + return name, nil + } + return "", err + } else if !autoRename { + return "", ErrResourceNameConflict + } + + ext := path.Ext(name) + basename := name[:len(name)-len(ext)] + counter := 1 + for { + name = fmt.Sprintf("%s (%d)%s", basename, counter, ext) + if _, _, err := childResourceIDByName(db, parentID, name); err == nil { + counter++ + } else if errors.Is(err, ErrResourceNotFound) { + return name, nil + } else { + return "", err + } + } +} + +func childResourceIDByName(db db.Handler, parentID uuid.UUID, name string) (uuid.UUID, bool, error) { const query = "SELECT id, dir FROM resources WHERE parent = @parent::UUID AND name = @name::TEXT AND deleted IS NULL" args := pgx.NamedArgs{ "parent": parentID, "name": name, } - row := f.db.QueryRow(query, args) + row := db.QueryRow(query, args) var id uuid.UUID var dir bool err := row.Scan(&id, &dir) diff --git a/server/internal/core/resource_open.go b/server/internal/core/resource_open.go index d8f20815..1fbd3841 100644 --- a/server/internal/core/resource_open.go +++ b/server/internal/core/resource_open.go @@ -13,7 +13,13 @@ import ( "github.com/jackc/pgx/v5" ) -func (f filesystem) OpenWrite(r Resource, versionID uuid.UUID) (io.WriteCloser, error) { +// TODO: #tx Change to Write(Resource, uuid.UUID, func(io.WriteCloser) error) error +func (f fileSystem) OpenWrite(r Resource, versionID uuid.UUID) (io.WriteCloser, error) { + // TODO: #implement + return nil, nil +} + +func (f txFileSystem) OpenWrite(r Resource, versionID uuid.UUID) (io.WriteCloser, error) { if err := r.checkPermission(f.user, PermissionWrite); err != nil { return nil, err } @@ -28,31 +34,25 @@ func (f filesystem) OpenWrite(r Resource, versionID uuid.UUID) (io.WriteCloser, } else { return computeProps(dest, func(len int, hash hash.Hash, mimeType string) error { sum := hex.EncodeToString(hash.Sum(nil)) - err := f.runInTx(func(f filesystem) error { - if err := f.createResourceVersion(r.id, versionID, int64(len), mimeType, sum); err != nil { - return err - } - if err := f.updateResourceModified(r.id); err != nil { - return err - } - return nil - }) - if err != nil { + if err := createResourceVersion(f.db, r.id, versionID, int64(len), mimeType, sum); err != nil { + return err + } + if err := updateResourceModified(f.db, r.id); err != nil { return err } - jobs.MigrateVersionContents(versionID) - return nil + // TODO: #tx pass in transaction + return jobs.MigrateVersionContents(versionID) }), nil } } -func (f filesystem) ReadDir(r Resource, recursive bool) ([]Resource, error) { +func (f fileSystem) ReadDir(r Resource, recursive bool) ([]Resource, error) { return f.ReadDirDeleted(r, recursive, false) } -func (f filesystem) ReadDirDeleted(r Resource, recursive, includeDeleted bool) ([]Resource, error) { +func (f fileSystem) ReadDirDeleted(r Resource, recursive, includeDeleted bool) ([]Resource, error) { if !r.Dir() { return nil, ErrResourceNotCollection } @@ -76,7 +76,7 @@ func (f filesystem) ReadDirDeleted(r Resource, recursive, includeDeleted bool) ( } } -func (f filesystem) Walk(r Resource, depth int, fn func(Resource, string) error) error { +func (f fileSystem) Walk(r Resource, depth int, fn func(Resource, string) error) error { suffix := "" if r.Dir() { suffix = "/" diff --git a/server/internal/core/resource_permissions.go b/server/internal/core/resource_permissions.go index 2d46f360..20e7f7ea 100644 --- a/server/internal/core/resource_permissions.go +++ b/server/internal/core/resource_permissions.go @@ -1,6 +1,7 @@ package core import ( + "codeberg.org/shroff/phylum/server/internal/db" "github.com/google/uuid" "github.com/jackc/pgx/v5" ) @@ -15,7 +16,16 @@ const ( PermissionSU = Permission(-1) ) -func (f filesystem) UpdatePermissions(r Resource, user User, permission Permission) (Resource, error) { +func (f fileSystem) UpdatePermissions(r Resource, user User, permission Permission) (res Resource, err error) { + err = f.runInTx(func(f txFileSystem) error { + var err error + res, err = f.UpdatePermissions(r, user, permission) + return err + }) + return res, err +} + +func (f txFileSystem) UpdatePermissions(r Resource, user User, permission Permission) (Resource, error) { if r.deleted.Valid { return r, ErrResourceDeleted } @@ -48,20 +58,17 @@ RETURNING grants` } var grants []byte - err := f.runInTx(func(f filesystem) error { - row := f.db.QueryRow(q, pgx.NamedArgs{ - "resource_id": r.id, - "user_id": user.ID, - "permission": permission, - }) - - if err := row.Scan(&grants); err != nil { - return err - } - - return f.recomputePermissions(r.id) + row := f.db.QueryRow(q, pgx.NamedArgs{ + "resource_id": r.id, + "user_id": user.ID, + "permission": permission, }) - if err != nil { + + if err := row.Scan(&grants); err != nil { + return Resource{}, err + } + + if err := recomputePermissions(f.db, r.id); err != nil { return Resource{}, err } @@ -69,7 +76,7 @@ RETURNING grants` return r, nil } -func (f filesystem) recomputePermissions(id uuid.UUID) error { +func recomputePermissions(db db.TxHandler, id uuid.UUID) error { const q = ` WITH RECURSIVE nodes(id, parent, permissions) AS ( SELECT r.id, r.parent, phylum_merge_permission_grants(COALESCE(p.permissions, '{}'::JSONB), r.grants) END @@ -86,6 +93,6 @@ UPDATE resources FROM nodes WHERE resources.id = nodes.id` - _, err := f.db.Exec(q, id) + _, err := db.Exec(q, id) return err } diff --git a/server/internal/core/resource_publink.go b/server/internal/core/resource_publink.go index 5bcea7fd..1aa08322 100644 --- a/server/internal/core/resource_publink.go +++ b/server/internal/core/resource_publink.go @@ -8,7 +8,13 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -func (f filesystem) CreatePublink(r Resource, id, password string, expires pgtype.Timestamp, accessLimit int) error { +func (f fileSystem) CreatePublink(r Resource, id, password string, expires pgtype.Timestamp, accessLimit int) error { + return f.runInTx(func(f txFileSystem) error { + return f.CreatePublink(r, id, password, expires, accessLimit) + }) +} + +func (f txFileSystem) CreatePublink(r Resource, id, password string, expires pgtype.Timestamp, accessLimit int) error { if err := r.checkPermission(f.user, PermissionShare|PermissionRead); err != nil { return err } @@ -36,7 +42,7 @@ func (f filesystem) CreatePublink(r Resource, id, password string, expires pgtyp return err } -func (f filesystem) ListPublinks(r Resource) ([]Publink, error) { +func (f fileSystem) ListPublinks(r Resource) ([]Publink, error) { const q = "SELECT * FROM publinks WHERE root = $1::UUID" if rows, err := f.db.Query(q, r.id); err != nil { return nil, err diff --git a/server/internal/core/resource_versions.go b/server/internal/core/resource_versions.go index fd23a193..9213705b 100644 --- a/server/internal/core/resource_versions.go +++ b/server/internal/core/resource_versions.go @@ -8,7 +8,7 @@ import ( "github.com/jackc/pgx/v5" ) -func (f filesystem) GetVersion(r Resource, versionID uuid.UUID) (Version, error) { +func (f fileSystem) GetVersion(r Resource, versionID uuid.UUID) (Version, error) { if versionID == uuid.Nil { versionID = r.latestVersionInfo.ID } @@ -33,7 +33,7 @@ AND DELETED IS NULL` return v, nil } -func (f filesystem) GetAllVersions(r Resource) ([]Version, error) { +func (f fileSystem) GetAllVersions(r Resource) ([]Version, error) { const q = `SELECT id, created, deleted, size, mime_type, sha256, storage FROM resource_versions WHERE resource_id = $1::UUID ORDER BY created DESC` diff --git a/server/internal/core/search.go b/server/internal/core/search.go index 9313b262..2ad5063a 100644 --- a/server/internal/core/search.go +++ b/server/internal/core/search.go @@ -6,7 +6,7 @@ import ( "github.com/jackc/pgx/v5" ) -func (f filesystem) Search(query string, includeDeleted bool) ([]Resource, error) { +func (f fileSystem) Search(query string, includeDeleted bool) ([]Resource, error) { qb := strings.Builder{} qb.WriteString(fullResourceQuery) qb.WriteString("WHERE f_prepare_search(r.name) %> @query::TEXT") diff --git a/server/internal/core/trash.go b/server/internal/core/trash.go index 90daca47..58a16043 100644 --- a/server/internal/core/trash.go +++ b/server/internal/core/trash.go @@ -16,7 +16,7 @@ import ( "github.com/sirupsen/logrus" ) -func (f filesystem) TrashList(cursor string, n uint) ([]Resource, string, error) { +func (f fileSystem) TrashList(cursor string, n uint) ([]Resource, string, error) { t := goqu.T("trash") r := goqu.T("resources").As("r") p := goqu.T("resources").As("p") @@ -90,7 +90,7 @@ func TrashCompact(ctx context.Context, duration time.Duration) { } } -func (f filesystem) TrashSummary() (int, int, error) { +func (f fileSystem) TrashSummary() (int, int, error) { v := goqu.T("resource_versions").As("v") n, q := f.selectTrash(time.Time{}) @@ -107,12 +107,14 @@ func (f filesystem) TrashSummary() (int, int, error) { return items, size, err } -func (f filesystem) TrashEmpty() error { +func (f fileSystem) TrashEmpty() error { n, q := f.selectTrash(time.Time{}) - return hardDeleteAllVersions(f.db, q, n) + return f.db.RunInTx(func(db db.TxHandler) error { + return hardDeleteAllVersions(db, q, n) + }) } -func (f filesystem) selectTrash(time time.Time) (exp.AliasedExpression, *goqu.SelectDataset) { +func (f fileSystem) selectTrash(time time.Time) (exp.AliasedExpression, *goqu.SelectDataset) { r := goqu.T("resources").As("r") n := goqu.T("nodes").As("n") t := goqu.T("trash").As("t") @@ -140,12 +142,14 @@ func (f filesystem) selectTrash(time time.Time) (exp.AliasedExpression, *goqu.Se return n, q } -func (f filesystem) hardDeleteOldResources(t time.Time) error { +func (f fileSystem) hardDeleteOldResources(t time.Time) error { n, q := f.selectTrash(t) - return hardDeleteAllVersions(f.db, q, n) + return f.db.RunInTx(func(db db.TxHandler) error { + return hardDeleteAllVersions(db, q, n) + }) } -func hardDeleteAllVersions(db db.Handler, q *goqu.SelectDataset, n interface { +func hardDeleteAllVersions(db db.TxHandler, q *goqu.SelectDataset, n interface { exp.Expression Col(interface{}) exp.IdentifierExpression }) error { diff --git a/server/internal/core/user_bookmarks.go b/server/internal/core/user_bookmarks.go index eda6d3d9..e401006e 100644 --- a/server/internal/core/user_bookmarks.go +++ b/server/internal/core/user_bookmarks.go @@ -29,7 +29,7 @@ func scanBookmark(row pgx.CollectableRow) (Bookmark, error) { return p, nil } -func AddBookmark(db db.Handler, u User, resource Resource, name string) (Bookmark, error) { +func AddBookmark(db db.TxHandler, u User, resource Resource, name string) (Bookmark, error) { if name == "" { name = resource.Name() } @@ -56,7 +56,7 @@ RETURNING resource_id, name, dir, created` } } -func RemoveBookmark(db db.Handler, u User, id uuid.UUID) error { +func RemoveBookmark(db db.TxHandler, u User, id uuid.UUID) error { const q = "DELETE FROM bookmarks WHERE user_id = $1::INT AND resource_id = $2::UUID" _, err := db.Exec(q, u.ID, id) return err diff --git a/server/internal/core/user_select.go b/server/internal/core/user_select.go index bfdd389b..193f7ce8 100644 --- a/server/internal/core/user_select.go +++ b/server/internal/core/user_select.go @@ -13,36 +13,31 @@ import ( var ErrUserNotFound = NewError(http.StatusNotFound, "user_not_found", "no such user") -func CreateUser(d db.Handler, email, name string, noCreateHome bool) (User, error) { - var user User - err := d.RunInTx(func(db db.Handler) error { +func CreateUser(db db.TxHandler, email, name string, noCreateHome bool) (User, error) { + f := openOmniscientTx(db) + var homeID pgtype.UUID + var home Resource + if !noCreateHome { var err error - var homeID pgtype.UUID - var home Resource - f := OpenOmniscient(db) - if !noCreateHome { - var err error - homePath := strings.TrimRight(Cfg.BaseDir, "/") + "/" + email - home, err = f.CreateResourceByPath(homePath, uuid.Nil, true, true, ResourceBindConflictResolutionEnsure) - if err != nil { - return err - } - homeID = pgtype.UUID{Bytes: home.ID(), Valid: true} - } - - user, err = insertUser(db, email, name, homeID) + homePath := strings.TrimRight(Cfg.BaseDir, "/") + "/" + email + home, err = f.CreateResourceByPath(homePath, uuid.Nil, true, true, ResourceBindConflictResolutionEnsure) if err != nil { - return err + return User{}, err } + homeID = pgtype.UUID{Bytes: home.ID(), Valid: true} + } - if homeID.Valid { - if _, err := f.UpdatePermissions(home, user, PermissionRead|PermissionWrite|PermissionShare); err != nil { - return err - } + user, err := insertUser(db, email, name, homeID) + if err != nil { + return User{}, err + } + + if homeID.Valid { + if _, err := f.UpdatePermissions(home, user, PermissionRead|PermissionWrite|PermissionShare); err != nil { + return User{}, err } - return err - }) - return user, err + } + return user, nil } func ListUsers(db db.Handler, since int64) ([]User, error) { diff --git a/server/internal/core/user_update.go b/server/internal/core/user_update.go index b6ea6583..fec2ddef 100644 --- a/server/internal/core/user_update.go +++ b/server/internal/core/user_update.go @@ -5,7 +5,7 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -func UpdateUserHome(db db.Handler, user User, home pgtype.UUID) error { +func UpdateUserHome(db db.TxHandler, user User, home pgtype.UUID) error { const q = "UPDATE users SET home = $2::UUID, modified = NOW() WHERE id = $1::INT" if _, err := db.Exec(q, user.ID, home); err != nil { return err @@ -13,7 +13,7 @@ func UpdateUserHome(db db.Handler, user User, home pgtype.UUID) error { return nil } -func UpdateUserName(db db.Handler, user User, name string) error { +func UpdateUserName(db db.TxHandler, user User, name string) error { const q = "UPDATE users SET name = $2::TEXT, modified = NOW() WHERE id = $1::INT" if _, err := db.Exec(q, user.ID, name); err != nil { return err @@ -21,7 +21,7 @@ func UpdateUserName(db db.Handler, user User, name string) error { return nil } -func GrantUserPermissions(db db.Handler, user User, permissions UserPermissions) error { +func GrantUserPermissions(db db.TxHandler, user User, permissions UserPermissions) error { const q = "UPDATE users SET permissions = permissions | $2::INTEGER, modified = NOW() WHERE id = $1::INT" if _, err := db.Exec(q, user.ID, permissions); err != nil { return err @@ -29,7 +29,7 @@ func GrantUserPermissions(db db.Handler, user User, permissions UserPermissions) return nil } -func RevokeUserPermissions(db db.Handler, user User, permissions UserPermissions) error { +func RevokeUserPermissions(db db.TxHandler, user User, permissions UserPermissions) error { const q = "UPDATE users SET permissions = permissions & ~ $2::INTEGER, modified = NOW() WHERE id = $1::INT" if _, err := db.Exec(q, user.ID, permissions); err != nil { return err diff --git a/server/internal/db/db.go b/server/internal/db/db.go index 3aece498..b37aefad 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -27,7 +27,7 @@ func Get(ctx context.Context) Handler { notifier = n } } - return Handler{ + return handler{ ctx: ctx, tx: pool, } @@ -39,7 +39,7 @@ func Pool() *pgxpool.Pool { return pool } -func (d Handler) Notifier() pubsub.Notifier { +func Notifier() pubsub.Notifier { return notifier } diff --git a/server/internal/db/handler.go b/server/internal/db/handler.go index b5b0c21b..08650113 100644 --- a/server/internal/db/handler.go +++ b/server/internal/db/handler.go @@ -7,21 +7,45 @@ import ( "github.com/jackc/pgx/v5/pgconn" ) -type Handler struct { +type Handler interface { + Query(stmt string, args ...interface{}) (pgx.Rows, error) + QueryRow(stmt string, args ...interface{}) pgx.Row + ExecNoTx(stmt string, args ...interface{}) (pgconn.CommandTag, error) + RunInTx(fn func(TxHandler) error) error +} + +type TxHandler interface { + Handler + Exec(stmt string, args ...interface{}) (pgconn.CommandTag, error) + CopyFrom(tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) + SendBatch(batch *pgx.Batch) pgx.BatchResults + Tx() pgx.Tx +} + +type handler struct { ctx context.Context tx interface { Begin(context.Context) (pgx.Tx, error) - SendBatch(context.Context, *pgx.Batch) pgx.BatchResults - Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) Query(context.Context, string, ...interface{}) (pgx.Rows, error) QueryRow(context.Context, string, ...interface{}) pgx.Row - CopyFrom(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error) + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) } } -func (h Handler) RunInTx(fn func(Handler) error) error { +func (h handler) Query(stmt string, args ...interface{}) (pgx.Rows, error) { + return h.tx.Query(h.ctx, stmt, args...) +} + +func (h handler) QueryRow(stmt string, args ...interface{}) pgx.Row { + return h.tx.QueryRow(h.ctx, stmt, args...) +} +func (h handler) ExecNoTx(stmt string, args ...interface{}) (pgconn.CommandTag, error) { + return h.tx.Exec(h.ctx, stmt, args...) +} + +func (h handler) RunInTx(fn func(TxHandler) error) error { return pgx.BeginFunc(h.ctx, h.tx, func(tx pgx.Tx) error { - h := Handler{ + h := txHandler{ ctx: h.ctx, tx: tx, } @@ -29,22 +53,46 @@ func (h Handler) RunInTx(fn func(Handler) error) error { }) } -func (h Handler) Exec(stmt string, args ...interface{}) (pgconn.CommandTag, error) { - return h.tx.Exec(h.ctx, stmt, args...) +type txHandler struct { + ctx context.Context + tx pgx.Tx } -func (h Handler) Query(stmt string, args ...interface{}) (pgx.Rows, error) { +func (h txHandler) Query(stmt string, args ...interface{}) (pgx.Rows, error) { return h.tx.Query(h.ctx, stmt, args...) } -func (h Handler) QueryRow(stmt string, args ...interface{}) pgx.Row { +func (h txHandler) QueryRow(stmt string, args ...interface{}) pgx.Row { return h.tx.QueryRow(h.ctx, stmt, args...) } -func (h Handler) CopyFrom(tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { +func (h txHandler) RunInTx(fn func(TxHandler) error) error { + return pgx.BeginFunc(h.ctx, h.tx, func(tx pgx.Tx) error { + h := txHandler{ + ctx: h.ctx, + tx: tx, + } + return fn(h) + }) +} + +func (h txHandler) Exec(stmt string, args ...interface{}) (pgconn.CommandTag, error) { + return h.tx.Exec(h.ctx, stmt, args...) +} + +// Part of the interface +func (h txHandler) ExecNoTx(stmt string, args ...interface{}) (pgconn.CommandTag, error) { + return h.tx.Exec(h.ctx, stmt, args...) +} + +func (h txHandler) CopyFrom(tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { return h.tx.CopyFrom(h.ctx, tableName, columnNames, rowSrc) } -func (h Handler) SendBatch(batch *pgx.Batch) pgx.BatchResults { +func (h txHandler) SendBatch(batch *pgx.Batch) pgx.BatchResults { return h.tx.SendBatch(h.ctx, batch) } + +func (h txHandler) Tx() pgx.Tx { + return h.tx +} diff --git a/server/internal/db/schema.go b/server/internal/db/schema.go index 99174168..34243b38 100644 --- a/server/internal/db/schema.go +++ b/server/internal/db/schema.go @@ -94,20 +94,20 @@ func Migrate(ctx context.Context, version int) error { func DeleteSchema(ctx context.Context) error { h := Get(ctx) user := pool.Config().ConnConfig.User - return h.RunInTx(func(d Handler) (err error) { - if _, err = d.Exec("DROP SCHEMA public CASCADE"); err != nil { + return h.RunInTx(func(tx TxHandler) (err error) { + if _, err = tx.Exec("DROP SCHEMA public CASCADE"); err != nil { return } - if _, err = d.Exec("CREATE SCHEMA public"); err != nil { + if _, err = tx.Exec("CREATE SCHEMA public"); err != nil { return } - if _, err = d.Exec("GRANT ALL ON SCHEMA public TO " + user); err != nil { + if _, err = tx.Exec("GRANT ALL ON SCHEMA public TO " + user); err != nil { return } - if _, err = d.Exec("GRANT ALL ON SCHEMA public TO public"); err != nil { + if _, err = tx.Exec("GRANT ALL ON SCHEMA public TO public"); err != nil { return } - _, err = d.Exec("COMMENT ON SCHEMA public IS 'standard public schema'") + _, err = tx.Exec("COMMENT ON SCHEMA public IS 'standard public schema'") return }) } diff --git a/server/internal/jobs/delete.go b/server/internal/jobs/delete.go deleted file mode 100644 index 2f377b8b..00000000 --- a/server/internal/jobs/delete.go +++ /dev/null @@ -1,71 +0,0 @@ -package jobs - -import ( - "context" - - "codeberg.org/shroff/phylum/server/internal/db" - "codeberg.org/shroff/phylum/server/internal/storage" - "github.com/google/uuid" - "github.com/jackc/pgx/v5" - "github.com/riverqueue/river" -) - -type DeleteArgs struct { - VersionIDS uuid.UUIDs `json:"version_ids"` -} - -func (DeleteArgs) Kind() string { return "delete" } - -type DeleteWorker struct { - river.WorkerDefaults[DeleteArgs] -} - -func (w *DeleteWorker) Work(ctx context.Context, job *river.Job[DeleteArgs]) error { - return deleteAllVersionContents(db.Get(ctx), job.Args.VersionIDS) -} - -type versionInfo struct { - id uuid.UUID - storage string -} - -func DeleteAllVersionContents(ids uuid.UUIDs) { - client.Insert(context.Background(), DeleteArgs{VersionIDS: ids}, &river.InsertOpts{}) -} - -func deleteAllVersionContents(db db.Handler, ids uuid.UUIDs) error { - const q = `SELECT v.id, v.storage FROM resources r -JOIN resource_versions v ON r.id = v.resource_id -WHERE r.id = ANY ($1::UUID[])` - if rows, err := db.Query(q, ids); err != nil { - return err - } else if versions, err := pgx.CollectRows(rows, scanDeletedVersion); err != nil { - return err - } else { - return deleteVersionContents(versions) - } -} - -func deleteVersionContents(versions []versionInfo) error { - idsPerBackend := make(map[string][]string) - for _, v := range versions { - idsPerBackend[v.storage] = append(idsPerBackend[v.storage], v.id.String()) - } - for k, v := range idsPerBackend { - if backend, err := storage.GetBackend(k); err != nil { - return err - } else { - backend.DeleteAll(v) - } - } - return nil -} - -func scanDeletedVersion(row pgx.CollectableRow) (versionInfo, error) { - var v versionInfo - err := row.Scan( - &v.id, - &v.storage, - ) - return v, err -} diff --git a/server/internal/jobs/jobs.go b/server/internal/jobs/jobs.go index d0a22c8c..3548ac4f 100644 --- a/server/internal/jobs/jobs.go +++ b/server/internal/jobs/jobs.go @@ -31,7 +31,6 @@ func Initialize(ctx context.Context, pool *pgxpool.Pool) error { workers := river.NewWorkers() river.AddWorker(workers, &MigrateWorker{}) - river.AddWorker(workers, &DeleteWorker{}) river.AddWorker(workers, &DeleteContentsWorker{}) if c, err := river.NewClient(riverpgxv5.New(pool), &river.Config{ diff --git a/server/internal/jobs/migrate.go b/server/internal/jobs/migrate.go index b31cdad4..d06c6781 100644 --- a/server/internal/jobs/migrate.go +++ b/server/internal/jobs/migrate.go @@ -25,8 +25,9 @@ func (w *MigrateWorker) Work(ctx context.Context, job *river.Job[MigrateArgs]) e return migrateVersionContents(ctx, job.Args.VersionID) } -func MigrateVersionContents(versionID uuid.UUID) { - client.Insert(context.Background(), MigrateArgs{VersionID: versionID}, &river.InsertOpts{}) +func MigrateVersionContents(versionID uuid.UUID) error { + _, err := client.Insert(context.Background(), MigrateArgs{VersionID: versionID}, &river.InsertOpts{}) + return err } func migrateVersionContents(ctx context.Context, versionID uuid.UUID) error { @@ -84,8 +85,10 @@ UNION ALL return storage.DefaultBackend(), nil } -func updateStorage(db db.Handler, versionID uuid.UUID, storage string) error { +func updateStorage(d db.Handler, versionID uuid.UUID, storage string) error { q := "UPDATE resource_versions SET storage = $2::TEXT WHERE id = $1::UUID" - _, err := db.Exec(q, versionID, storage) - return err + return d.RunInTx(func(db db.TxHandler) error { + _, err := db.Exec(q, versionID, storage) + return err + }) } diff --git a/server/internal/storage/storage.go b/server/internal/storage/storage.go index c84d3004..83539e82 100644 --- a/server/internal/storage/storage.go +++ b/server/internal/storage/storage.go @@ -59,7 +59,7 @@ func Initialize(db db.Handler) error { defaultBackend = b } - go processBackendUpdates(db) + go processBackendUpdates() return nil } @@ -82,7 +82,7 @@ func ListBackends() map[string]Backend { return backends } -func InsertBackend(db db.Handler, name string, driver Driver, params map[string]string) error { +func InsertBackend(d db.Handler, name string, driver Driver, params map[string]string) error { backend, err := driver.Create(name, params) if err != nil { return nil @@ -92,7 +92,10 @@ func InsertBackend(db db.Handler, name string, driver Driver, params map[string] if err != nil { return err } - if _, err := db.Exec(q, name, driver.Name, p); err != nil { + if err := d.RunInTx(func(tx db.TxHandler) error { + _, err := tx.Exec(q, name, driver.Name, p) + return err + }); err != nil { return err } backends[name] = backend @@ -130,7 +133,7 @@ func restoreBackends(db db.Handler) (map[string]Backend, error) { } } -func processBackendUpdates(db db.Handler) { +func processBackendUpdates() { sub := db.Notifier().Listen("backend_updates") for { p := <-sub.NotificationC()