diff --git a/server/internal/command/schema/cmd.go b/server/internal/command/schema/cmd.go index 07297d9d..039f0284 100644 --- a/server/internal/command/schema/cmd.go +++ b/server/internal/command/schema/cmd.go @@ -29,11 +29,11 @@ func setupSchemaResetCommand() *cobra.Command { Short: "Reset Database Schema", Run: func(cmd *cobra.Command, args []string) { db.AutoMigrate = false - if err := db.Get().DeleteSchema(context.Background()); err != nil { + if err := db.DeleteSchema(context.Background()); err != nil { fmt.Println("unable to delete database schema: " + err.Error()) os.Exit(1) } - if err := db.Get().Migrate(context.Background(), -1); err != nil { + if err := db.Migrate(context.Background(), -1); err != nil { fmt.Println("unable to migrate database schema: " + err.Error()) os.Exit(1) } @@ -56,7 +56,7 @@ func setupSchemaMigrateCommand() *cobra.Command { if v, err := strconv.Atoi(args[0]); err != nil { fmt.Println(err.Error()) os.Exit(3) - } else if err := db.Get().Migrate(ctx, v); err != nil { + } else if err := db.Migrate(ctx, v); err != nil { fmt.Println(err.Error()) os.Exit(4) } diff --git a/server/internal/command/user/add.go b/server/internal/command/user/add.go index 81af46b4..1f551700 100644 --- a/server/internal/command/user/add.go +++ b/server/internal/command/user/add.go @@ -92,7 +92,7 @@ func setupUserAddCommand() *cobra.Command { } ctx := context.Background() - err = db.Get().RunInTx(ctx, func(db *db.Handler) error { + err = db.Get().RunInTx(ctx, func(db db.Handler) error { userManager := user.CreateManager(ctx).WithDb(db) var u user.User if user, err := userManager.CreateUser(username, displayName, password, root); err != nil { diff --git a/server/internal/command/user/mod.go b/server/internal/command/user/mod.go index 2c75dece..df21799d 100644 --- a/server/internal/command/user/mod.go +++ b/server/internal/command/user/mod.go @@ -55,7 +55,7 @@ func setupUserModCommand() *cobra.Command { } ctx := context.Background() - err = db.Get().RunInTx(ctx, func(db *db.Handler) error { + err = db.Get().RunInTx(ctx, func(db db.Handler) error { m := user.CreateManager(ctx).WithDb(db) if displayName != "" { if err := m.UpdateUserDisplayName(u, displayName); err != nil { diff --git a/server/internal/core/db/bootstrap.go b/server/internal/core/db/bootstrap.go index 36bdc81f..7f208acd 100644 --- a/server/internal/core/db/bootstrap.go +++ b/server/internal/core/db/bootstrap.go @@ -11,12 +11,43 @@ import ( const DefaultUserUsername = "phylum" const DefaultUserDisplayName = "Phylum" -func (d Handler) BootstrapData(ctx context.Context) error { +func BootstrapData(ctx context.Context) error { const q = "SELECT username FROM users WHERE username = $1::TEXT" + d := Get() row := d.QueryRow(ctx, q, DefaultUserUsername) if err := row.Scan(nil); err != nil { if errors.Is(err, pgx.ErrNoRows) { - err = d.populateData(ctx) + err = d.RunInTx(ctx, func(d Handler) error { + const createDir = `INSERT INTO resources(id, parent, name, dir, content_length, content_type, content_sha256) +VALUES ($1::UUID, $2::UUID, $3::TEXT, TRUE, 0, '', '')` + + // Create root folder + rootID, _ := uuid.NewV7() + if _, err := d.Exec(ctx, createDir, rootID, nil, ""); err != nil { + return err + } + + // Create home folder + homeID, _ := uuid.NewV7() + if _, err := d.Exec(ctx, createDir, homeID, rootID, "home"); err != nil { + return err + } + + // Create user home folder + userHomeID, _ := uuid.NewV7() + if _, err := d.Exec(ctx, createDir, userHomeID, homeID, DefaultUserUsername); err != nil { + return err + } + + const createRootUser = `INSERT INTO users(username, display_name, password_hash, root, home, permissions) +VALUES ($1::TEXT, $2::TEXT, '', $3::UUID, $4::UUID, -1)` + + if _, err := d.Exec(ctx, createRootUser, DefaultUserUsername, DefaultUserDisplayName, rootID, userHomeID); err != nil { + return err + } + return nil + + }) } if err != nil { return err @@ -25,37 +56,3 @@ func (d Handler) BootstrapData(ctx context.Context) error { return nil } - -func (d Handler) populateData(ctx context.Context) (e error) { - return d.RunInTx(ctx, func(d *Handler) error { - const createDir = `INSERT INTO resources(id, parent, name, dir, content_length, content_type, content_sha256) -VALUES ($1::UUID, $2::UUID, $3::TEXT, TRUE, 0, '', '')` - - // Create root folder - rootID, _ := uuid.NewV7() - if _, err := d.Exec(ctx, createDir, rootID, nil, ""); err != nil { - return err - } - - // Create home folder - homeID, _ := uuid.NewV7() - if _, err := d.Exec(ctx, createDir, homeID, rootID, "home"); err != nil { - return err - } - - // Create user home folder - userHomeID, _ := uuid.NewV7() - if _, err := d.Exec(ctx, createDir, userHomeID, homeID, DefaultUserUsername); err != nil { - return err - } - - const createRootUser = `INSERT INTO users(username, display_name, password_hash, root, home, permissions) -VALUES ($1::TEXT, $2::TEXT, '', $3::UUID, $4::UUID, -1)` - - if _, err := d.Exec(ctx, createRootUser, DefaultUserUsername, DefaultUserDisplayName, rootID, userHomeID); err != nil { - return err - } - return nil - - }) -} diff --git a/server/internal/core/db/db.go b/server/internal/core/db/db.go index 0ee12c67..6388bdec 100644 --- a/server/internal/core/db/db.go +++ b/server/internal/core/db/db.go @@ -9,29 +9,30 @@ import ( "github.com/sirupsen/logrus" ) -var h *Handler var DatabaseURL string var TraceSQL bool var AutoMigrate bool -func Get() *Handler { - if h == nil { - if handler, err := createHandler(context.Background(), DatabaseURL, TraceSQL, AutoMigrate); err != nil { +var pool *pgxpool.Pool + +func Get() Handler { + if pool == nil { + if p, err := createPool(context.Background(), DatabaseURL, TraceSQL, AutoMigrate); err != nil { logrus.Fatal(err) } else { - h = handler + pool = p } } - return h + return Handler{tx: pool} } func Close() { - if h != nil { - h.pool.Close() - h = nil + if pool != nil { + pool.Close() + pool = nil } } -func createHandler(ctx context.Context, dsn string, traceSQL, autoMigrate bool) (*Handler, error) { +func createPool(ctx context.Context, dsn string, traceSQL, autoMigrate bool) (*pgxpool.Pool, error) { config, err := pgxpool.ParseConfig(dsn) if err != nil { return nil, errors.New("Unable to parse DSN: " + err.Error()) @@ -48,14 +49,9 @@ func createHandler(ctx context.Context, dsn string, traceSQL, autoMigrate bool) logrus.Debug("Connected to " + config.ConnConfig.Database + " at " + config.ConnConfig.Host + ":" + fmt.Sprint(config.ConnConfig.Port)) - h := Handler{ - tx: pool, - pool: pool, - } - - if err := h.checkVersion(ctx, autoMigrate); err != nil { + if err := checkVersion(ctx, autoMigrate); err != nil { return nil, err } - return &h, nil + return pool, nil } diff --git a/server/internal/core/db/handler.go b/server/internal/core/db/handler.go index c7c73e17..446a11b0 100644 --- a/server/internal/core/db/handler.go +++ b/server/internal/core/db/handler.go @@ -5,27 +5,25 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgxpool" ) type Handler struct { tx interface { Begin(context.Context) (pgx.Tx, error) + SendBatch(context.Context, *pgx.Batch) pgx.BatchResults Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) Query(context.Context, string, ...interface{}) (pgx.Rows, error) QueryRow(context.Context, string, ...interface{}) pgx.Row CopyFrom(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error) } - pool *pgxpool.Pool } -func (h Handler) RunInTx(ctx context.Context, fn func(*Handler) error) error { +func (h Handler) RunInTx(ctx context.Context, fn func(Handler) error) error { return pgx.BeginFunc(ctx, h.tx, func(tx pgx.Tx) error { h := Handler{ - tx: tx, - pool: h.pool, + tx: tx, } - return fn(&h) + return fn(h) }) } @@ -44,3 +42,7 @@ func (h Handler) QueryRow(ctx context.Context, stmt string, args ...interface{}) func (h Handler) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { return h.tx.CopyFrom(ctx, tableName, columnNames, rowSrc) } + +func (h Handler) SendBatch(ctx context.Context, batch *pgx.Batch) pgx.BatchResults { + return h.tx.SendBatch(ctx, batch) +} diff --git a/server/internal/core/db/schema.go b/server/internal/core/db/schema.go index 85582824..fee5de71 100644 --- a/server/internal/core/db/schema.go +++ b/server/internal/core/db/schema.go @@ -15,8 +15,8 @@ var ( ErrMigrationNoAutoDowngrade = errors.New("will not auto-downgrade schema to prevent data loss") ) -func (h Handler) checkVersion(ctx context.Context, autoMigrate bool) error { - conn, err := h.pool.Acquire(ctx) +func checkVersion(ctx context.Context, autoMigrate bool) error { + conn, err := pool.Acquire(ctx) if err != nil { return err } @@ -34,7 +34,7 @@ func (h Handler) checkVersion(ctx context.Context, autoMigrate bool) error { // Nothing to do if currentSchemaVersion == latestSchemaVersion { - return h.BootstrapData(ctx) + return BootstrapData(ctx) } if !autoMigrate { @@ -51,11 +51,12 @@ func (h Handler) checkVersion(ctx context.Context, autoMigrate bool) error { if err != nil { return err } - return h.BootstrapData(ctx) + return BootstrapData(ctx) } -func (h Handler) Migrate(ctx context.Context, version int) error { - conn, err := h.pool.Acquire(ctx) +func Migrate(ctx context.Context, version int) error { + Get() // Initialize the pool, just in case + conn, err := pool.Acquire(ctx) if err != nil { return err } @@ -74,7 +75,7 @@ func (h Handler) Migrate(ctx context.Context, version int) error { version = latestSchemaVersion } else if version == 0 { logrus.Info("Deleting database schema") - return h.DeleteSchema(ctx) + return DeleteSchema(ctx) } else if version > latestSchemaVersion { return ErrMigrationTargetTooHigh } @@ -85,14 +86,15 @@ func (h Handler) Migrate(ctx context.Context, version int) error { return err } if version == latestSchemaVersion { - return h.BootstrapData(ctx) + return BootstrapData(ctx) } return nil } -func (h Handler) DeleteSchema(ctx context.Context) error { - return h.RunInTx(ctx, func(d *Handler) (err error) { - user := d.pool.Config().ConnConfig.User +func DeleteSchema(ctx context.Context) error { + h := Get() + user := pool.Config().ConnConfig.User + return h.RunInTx(ctx, func(d Handler) (err error) { if _, err = d.Exec(ctx, "DROP SCHEMA public CASCADE"); err != nil { return } diff --git a/server/internal/core/fs/filesystem.go b/server/internal/core/fs/filesystem.go index d697d520..95657d6b 100644 --- a/server/internal/core/fs/filesystem.go +++ b/server/internal/core/fs/filesystem.go @@ -10,7 +10,7 @@ import ( type filesystem struct { ctx context.Context - db *db.Handler + db db.Handler cs storage.Storage rootID uuid.UUID username string @@ -36,11 +36,11 @@ func (f filesystem) withRoot(id uuid.UUID) filesystem { } } -func (f filesystem) WithDb(db *db.Handler) FileSystem { +func (f filesystem) WithDb(db db.Handler) FileSystem { return f.withDb(db) } -func (f filesystem) withDb(db *db.Handler) filesystem { +func (f filesystem) withDb(db db.Handler) filesystem { return filesystem{ ctx: f.ctx, db: db, @@ -52,13 +52,13 @@ func (f filesystem) withDb(db *db.Handler) filesystem { } func (f filesystem) RunInTx(fn func(FileSystem) error) error { - return f.db.RunInTx(f.ctx, func(db *db.Handler) error { + return f.db.RunInTx(f.ctx, func(db db.Handler) error { return fn(f.WithDb(db)) }) } func (f filesystem) runInTx(fn func(filesystem) error) error { - return f.db.RunInTx(f.ctx, func(db *db.Handler) error { + return f.db.RunInTx(f.ctx, func(db db.Handler) error { return fn(f.withDb(db)) }) } diff --git a/server/internal/core/fs/fs.go b/server/internal/core/fs/fs.go index 4f42cc02..972d9e40 100644 --- a/server/internal/core/fs/fs.go +++ b/server/internal/core/fs/fs.go @@ -36,7 +36,7 @@ type FileSystem interface { // filesystem.go RootID() uuid.UUID WithRoot(uuid.UUID) FileSystem - WithDb(db *db.Handler) FileSystem + WithDb(db db.Handler) FileSystem RunInTx(fn func(FileSystem) error) error // create.go @@ -116,7 +116,7 @@ func OpenFromPublink(ctx context.Context, name string, password string, path str return f.ResourceByPath(path) } -func getPublink(d *db.Handler, ctx context.Context, name string) (Publink, error) { +func getPublink(d db.Handler, ctx context.Context, name string) (Publink, error) { q := "SELECT FROM publinks p WHERE name = $1::TEXT AND deleted IS NULL" if rows, err := d.Query(ctx, q, name); err != nil { return Publink{}, err diff --git a/server/internal/core/storage/storage.go b/server/internal/core/storage/storage.go index a6011c28..c19d0f1f 100644 --- a/server/internal/core/storage/storage.go +++ b/server/internal/core/storage/storage.go @@ -25,7 +25,7 @@ type Storage interface { } type storage struct { - db *db.Handler + db db.Handler backends map[string]Backend defaultBackend Backend } @@ -44,7 +44,7 @@ func Get() Storage { return s } -func create(ctx context.Context, db *db.Handler, defaultStorageDir string) (Storage, error) { +func create(ctx context.Context, db db.Handler, defaultStorageDir string) (Storage, error) { if backends, err := restoreBackends(db, ctx); err != nil { return nil, err } else if defaultBackend, err := newLocalStorage("", defaultStorageDir); err != nil { @@ -144,7 +144,7 @@ func (s storage) ListBackends() map[string]Backend { return s.backends } -func restoreBackends(db *db.Handler, ctx context.Context) (map[string]Backend, error) { +func restoreBackends(db db.Handler, ctx context.Context) (map[string]Backend, error) { const q = "SELECT name, driver, params from storage_backends" if rows, err := db.Query(ctx, q); err != nil { return nil, err diff --git a/server/internal/core/user/manager.go b/server/internal/core/user/manager.go index 12a8b192..db94a68a 100644 --- a/server/internal/core/user/manager.go +++ b/server/internal/core/user/manager.go @@ -7,7 +7,7 @@ import ( ) type manager struct { - db *db.Handler + db db.Handler ctx context.Context } @@ -18,11 +18,11 @@ func CreateManager(ctx context.Context) Manager { } } -func (m manager) WithDb(db *db.Handler) Manager { +func (m manager) WithDb(db db.Handler) Manager { return m.withDb(db) } -func (m manager) withDb(db *db.Handler) manager { +func (m manager) withDb(db db.Handler) manager { return manager{ ctx: m.ctx, db: db, diff --git a/server/internal/core/user/user.go b/server/internal/core/user/user.go index d9bd8f3f..9c7a43b7 100644 --- a/server/internal/core/user/user.go +++ b/server/internal/core/user/user.go @@ -51,7 +51,7 @@ func (u User) OpenFileSystem(ctx context.Context) fs.FileSystem { type Manager interface { // manager.go - WithDb(db *db.Handler) Manager + WithDb(db db.Handler) Manager // create.go CreateUser(username, displayName, password string, root uuid.UUID) (User, error)