[server][core][db] Streamline db package

This commit is contained in:
Abhishek Shroff
2025-04-04 22:36:17 +05:30
parent b0725140db
commit ba8cb6ecc1
12 changed files with 86 additions and 89 deletions
+3 -3
View File
@@ -29,11 +29,11 @@ func setupSchemaResetCommand() *cobra.Command {
Short: "Reset Database Schema", Short: "Reset Database Schema",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
db.AutoMigrate = false db.AutoMigrate = false
if err := db.Get().DeleteSchema(context.Background()); err != nil { if err := db.DeleteSchema(context.Background()); err != nil {
fmt.Println("unable to delete database schema: " + err.Error()) fmt.Println("unable to delete database schema: " + err.Error())
os.Exit(1) os.Exit(1)
} }
if err := db.Get().Migrate(context.Background(), -1); err != nil { if err := db.Migrate(context.Background(), -1); err != nil {
fmt.Println("unable to migrate database schema: " + err.Error()) fmt.Println("unable to migrate database schema: " + err.Error())
os.Exit(1) os.Exit(1)
} }
@@ -56,7 +56,7 @@ func setupSchemaMigrateCommand() *cobra.Command {
if v, err := strconv.Atoi(args[0]); err != nil { if v, err := strconv.Atoi(args[0]); err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
os.Exit(3) os.Exit(3)
} else if err := db.Get().Migrate(ctx, v); err != nil { } else if err := db.Migrate(ctx, v); err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
os.Exit(4) os.Exit(4)
} }
+1 -1
View File
@@ -92,7 +92,7 @@ func setupUserAddCommand() *cobra.Command {
} }
ctx := context.Background() ctx := context.Background()
err = db.Get().RunInTx(ctx, func(db *db.Handler) error { err = db.Get().RunInTx(ctx, func(db db.Handler) error {
userManager := user.CreateManager(ctx).WithDb(db) userManager := user.CreateManager(ctx).WithDb(db)
var u user.User var u user.User
if user, err := userManager.CreateUser(username, displayName, password, root); err != nil { if user, err := userManager.CreateUser(username, displayName, password, root); err != nil {
+1 -1
View File
@@ -55,7 +55,7 @@ func setupUserModCommand() *cobra.Command {
} }
ctx := context.Background() ctx := context.Background()
err = db.Get().RunInTx(ctx, func(db *db.Handler) error { err = db.Get().RunInTx(ctx, func(db db.Handler) error {
m := user.CreateManager(ctx).WithDb(db) m := user.CreateManager(ctx).WithDb(db)
if displayName != "" { if displayName != "" {
if err := m.UpdateUserDisplayName(u, displayName); err != nil { if err := m.UpdateUserDisplayName(u, displayName); err != nil {
+33 -36
View File
@@ -11,12 +11,43 @@ import (
const DefaultUserUsername = "phylum" const DefaultUserUsername = "phylum"
const DefaultUserDisplayName = "Phylum" const DefaultUserDisplayName = "Phylum"
func (d Handler) BootstrapData(ctx context.Context) error { func BootstrapData(ctx context.Context) error {
const q = "SELECT username FROM users WHERE username = $1::TEXT" const q = "SELECT username FROM users WHERE username = $1::TEXT"
d := Get()
row := d.QueryRow(ctx, q, DefaultUserUsername) row := d.QueryRow(ctx, q, DefaultUserUsername)
if err := row.Scan(nil); err != nil { if err := row.Scan(nil); err != nil {
if errors.Is(err, pgx.ErrNoRows) { if errors.Is(err, pgx.ErrNoRows) {
err = d.populateData(ctx) err = d.RunInTx(ctx, func(d Handler) error {
const createDir = `INSERT INTO resources(id, parent, name, dir, content_length, content_type, content_sha256)
VALUES ($1::UUID, $2::UUID, $3::TEXT, TRUE, 0, '', '')`
// Create root folder
rootID, _ := uuid.NewV7()
if _, err := d.Exec(ctx, createDir, rootID, nil, ""); err != nil {
return err
}
// Create home folder
homeID, _ := uuid.NewV7()
if _, err := d.Exec(ctx, createDir, homeID, rootID, "home"); err != nil {
return err
}
// Create user home folder
userHomeID, _ := uuid.NewV7()
if _, err := d.Exec(ctx, createDir, userHomeID, homeID, DefaultUserUsername); err != nil {
return err
}
const createRootUser = `INSERT INTO users(username, display_name, password_hash, root, home, permissions)
VALUES ($1::TEXT, $2::TEXT, '', $3::UUID, $4::UUID, -1)`
if _, err := d.Exec(ctx, createRootUser, DefaultUserUsername, DefaultUserDisplayName, rootID, userHomeID); err != nil {
return err
}
return nil
})
} }
if err != nil { if err != nil {
return err return err
@@ -25,37 +56,3 @@ func (d Handler) BootstrapData(ctx context.Context) error {
return nil return nil
} }
func (d Handler) populateData(ctx context.Context) (e error) {
return d.RunInTx(ctx, func(d *Handler) error {
const createDir = `INSERT INTO resources(id, parent, name, dir, content_length, content_type, content_sha256)
VALUES ($1::UUID, $2::UUID, $3::TEXT, TRUE, 0, '', '')`
// Create root folder
rootID, _ := uuid.NewV7()
if _, err := d.Exec(ctx, createDir, rootID, nil, ""); err != nil {
return err
}
// Create home folder
homeID, _ := uuid.NewV7()
if _, err := d.Exec(ctx, createDir, homeID, rootID, "home"); err != nil {
return err
}
// Create user home folder
userHomeID, _ := uuid.NewV7()
if _, err := d.Exec(ctx, createDir, userHomeID, homeID, DefaultUserUsername); err != nil {
return err
}
const createRootUser = `INSERT INTO users(username, display_name, password_hash, root, home, permissions)
VALUES ($1::TEXT, $2::TEXT, '', $3::UUID, $4::UUID, -1)`
if _, err := d.Exec(ctx, createRootUser, DefaultUserUsername, DefaultUserDisplayName, rootID, userHomeID); err != nil {
return err
}
return nil
})
}
+13 -17
View File
@@ -9,29 +9,30 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
var h *Handler
var DatabaseURL string var DatabaseURL string
var TraceSQL bool var TraceSQL bool
var AutoMigrate bool var AutoMigrate bool
func Get() *Handler { var pool *pgxpool.Pool
if h == nil {
if handler, err := createHandler(context.Background(), DatabaseURL, TraceSQL, AutoMigrate); err != nil { func Get() Handler {
if pool == nil {
if p, err := createPool(context.Background(), DatabaseURL, TraceSQL, AutoMigrate); err != nil {
logrus.Fatal(err) logrus.Fatal(err)
} else { } else {
h = handler pool = p
} }
} }
return h return Handler{tx: pool}
} }
func Close() { func Close() {
if h != nil { if pool != nil {
h.pool.Close() pool.Close()
h = nil pool = nil
} }
} }
func createHandler(ctx context.Context, dsn string, traceSQL, autoMigrate bool) (*Handler, error) { func createPool(ctx context.Context, dsn string, traceSQL, autoMigrate bool) (*pgxpool.Pool, error) {
config, err := pgxpool.ParseConfig(dsn) config, err := pgxpool.ParseConfig(dsn)
if err != nil { if err != nil {
return nil, errors.New("Unable to parse DSN: " + err.Error()) return nil, errors.New("Unable to parse DSN: " + err.Error())
@@ -48,14 +49,9 @@ func createHandler(ctx context.Context, dsn string, traceSQL, autoMigrate bool)
logrus.Debug("Connected to " + config.ConnConfig.Database + " at " + config.ConnConfig.Host + ":" + fmt.Sprint(config.ConnConfig.Port)) logrus.Debug("Connected to " + config.ConnConfig.Database + " at " + config.ConnConfig.Host + ":" + fmt.Sprint(config.ConnConfig.Port))
h := Handler{ if err := checkVersion(ctx, autoMigrate); err != nil {
tx: pool,
pool: pool,
}
if err := h.checkVersion(ctx, autoMigrate); err != nil {
return nil, err return nil, err
} }
return &h, nil return pool, nil
} }
+8 -6
View File
@@ -5,27 +5,25 @@ import (
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
) )
type Handler struct { type Handler struct {
tx interface { tx interface {
Begin(context.Context) (pgx.Tx, error) Begin(context.Context) (pgx.Tx, error)
SendBatch(context.Context, *pgx.Batch) pgx.BatchResults
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
Query(context.Context, string, ...interface{}) (pgx.Rows, error) Query(context.Context, string, ...interface{}) (pgx.Rows, error)
QueryRow(context.Context, string, ...interface{}) pgx.Row QueryRow(context.Context, string, ...interface{}) pgx.Row
CopyFrom(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error) CopyFrom(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error)
} }
pool *pgxpool.Pool
} }
func (h Handler) RunInTx(ctx context.Context, fn func(*Handler) error) error { func (h Handler) RunInTx(ctx context.Context, fn func(Handler) error) error {
return pgx.BeginFunc(ctx, h.tx, func(tx pgx.Tx) error { return pgx.BeginFunc(ctx, h.tx, func(tx pgx.Tx) error {
h := Handler{ h := Handler{
tx: tx, tx: tx,
pool: h.pool,
} }
return fn(&h) return fn(h)
}) })
} }
@@ -44,3 +42,7 @@ func (h Handler) QueryRow(ctx context.Context, stmt string, args ...interface{})
func (h Handler) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { func (h Handler) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
return h.tx.CopyFrom(ctx, tableName, columnNames, rowSrc) return h.tx.CopyFrom(ctx, tableName, columnNames, rowSrc)
} }
func (h Handler) SendBatch(ctx context.Context, batch *pgx.Batch) pgx.BatchResults {
return h.tx.SendBatch(ctx, batch)
}
+13 -11
View File
@@ -15,8 +15,8 @@ var (
ErrMigrationNoAutoDowngrade = errors.New("will not auto-downgrade schema to prevent data loss") ErrMigrationNoAutoDowngrade = errors.New("will not auto-downgrade schema to prevent data loss")
) )
func (h Handler) checkVersion(ctx context.Context, autoMigrate bool) error { func checkVersion(ctx context.Context, autoMigrate bool) error {
conn, err := h.pool.Acquire(ctx) conn, err := pool.Acquire(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -34,7 +34,7 @@ func (h Handler) checkVersion(ctx context.Context, autoMigrate bool) error {
// Nothing to do // Nothing to do
if currentSchemaVersion == latestSchemaVersion { if currentSchemaVersion == latestSchemaVersion {
return h.BootstrapData(ctx) return BootstrapData(ctx)
} }
if !autoMigrate { if !autoMigrate {
@@ -51,11 +51,12 @@ func (h Handler) checkVersion(ctx context.Context, autoMigrate bool) error {
if err != nil { if err != nil {
return err return err
} }
return h.BootstrapData(ctx) return BootstrapData(ctx)
} }
func (h Handler) Migrate(ctx context.Context, version int) error { func Migrate(ctx context.Context, version int) error {
conn, err := h.pool.Acquire(ctx) Get() // Initialize the pool, just in case
conn, err := pool.Acquire(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -74,7 +75,7 @@ func (h Handler) Migrate(ctx context.Context, version int) error {
version = latestSchemaVersion version = latestSchemaVersion
} else if version == 0 { } else if version == 0 {
logrus.Info("Deleting database schema") logrus.Info("Deleting database schema")
return h.DeleteSchema(ctx) return DeleteSchema(ctx)
} else if version > latestSchemaVersion { } else if version > latestSchemaVersion {
return ErrMigrationTargetTooHigh return ErrMigrationTargetTooHigh
} }
@@ -85,14 +86,15 @@ func (h Handler) Migrate(ctx context.Context, version int) error {
return err return err
} }
if version == latestSchemaVersion { if version == latestSchemaVersion {
return h.BootstrapData(ctx) return BootstrapData(ctx)
} }
return nil return nil
} }
func (h Handler) DeleteSchema(ctx context.Context) error { func DeleteSchema(ctx context.Context) error {
return h.RunInTx(ctx, func(d *Handler) (err error) { h := Get()
user := d.pool.Config().ConnConfig.User user := pool.Config().ConnConfig.User
return h.RunInTx(ctx, func(d Handler) (err error) {
if _, err = d.Exec(ctx, "DROP SCHEMA public CASCADE"); err != nil { if _, err = d.Exec(ctx, "DROP SCHEMA public CASCADE"); err != nil {
return return
} }
+5 -5
View File
@@ -10,7 +10,7 @@ import (
type filesystem struct { type filesystem struct {
ctx context.Context ctx context.Context
db *db.Handler db db.Handler
cs storage.Storage cs storage.Storage
rootID uuid.UUID rootID uuid.UUID
username string username string
@@ -36,11 +36,11 @@ func (f filesystem) withRoot(id uuid.UUID) filesystem {
} }
} }
func (f filesystem) WithDb(db *db.Handler) FileSystem { func (f filesystem) WithDb(db db.Handler) FileSystem {
return f.withDb(db) return f.withDb(db)
} }
func (f filesystem) withDb(db *db.Handler) filesystem { func (f filesystem) withDb(db db.Handler) filesystem {
return filesystem{ return filesystem{
ctx: f.ctx, ctx: f.ctx,
db: db, db: db,
@@ -52,13 +52,13 @@ func (f filesystem) withDb(db *db.Handler) filesystem {
} }
func (f filesystem) RunInTx(fn func(FileSystem) error) error { func (f filesystem) RunInTx(fn func(FileSystem) error) error {
return f.db.RunInTx(f.ctx, func(db *db.Handler) error { return f.db.RunInTx(f.ctx, func(db db.Handler) error {
return fn(f.WithDb(db)) return fn(f.WithDb(db))
}) })
} }
func (f filesystem) runInTx(fn func(filesystem) error) error { func (f filesystem) runInTx(fn func(filesystem) error) error {
return f.db.RunInTx(f.ctx, func(db *db.Handler) error { return f.db.RunInTx(f.ctx, func(db db.Handler) error {
return fn(f.withDb(db)) return fn(f.withDb(db))
}) })
} }
+2 -2
View File
@@ -36,7 +36,7 @@ type FileSystem interface {
// filesystem.go // filesystem.go
RootID() uuid.UUID RootID() uuid.UUID
WithRoot(uuid.UUID) FileSystem WithRoot(uuid.UUID) FileSystem
WithDb(db *db.Handler) FileSystem WithDb(db db.Handler) FileSystem
RunInTx(fn func(FileSystem) error) error RunInTx(fn func(FileSystem) error) error
// create.go // create.go
@@ -116,7 +116,7 @@ func OpenFromPublink(ctx context.Context, name string, password string, path str
return f.ResourceByPath(path) return f.ResourceByPath(path)
} }
func getPublink(d *db.Handler, ctx context.Context, name string) (Publink, error) { func getPublink(d db.Handler, ctx context.Context, name string) (Publink, error) {
q := "SELECT FROM publinks p WHERE name = $1::TEXT AND deleted IS NULL" q := "SELECT FROM publinks p WHERE name = $1::TEXT AND deleted IS NULL"
if rows, err := d.Query(ctx, q, name); err != nil { if rows, err := d.Query(ctx, q, name); err != nil {
return Publink{}, err return Publink{}, err
+3 -3
View File
@@ -25,7 +25,7 @@ type Storage interface {
} }
type storage struct { type storage struct {
db *db.Handler db db.Handler
backends map[string]Backend backends map[string]Backend
defaultBackend Backend defaultBackend Backend
} }
@@ -44,7 +44,7 @@ func Get() Storage {
return s return s
} }
func create(ctx context.Context, db *db.Handler, defaultStorageDir string) (Storage, error) { func create(ctx context.Context, db db.Handler, defaultStorageDir string) (Storage, error) {
if backends, err := restoreBackends(db, ctx); err != nil { if backends, err := restoreBackends(db, ctx); err != nil {
return nil, err return nil, err
} else if defaultBackend, err := newLocalStorage("<default>", defaultStorageDir); err != nil { } else if defaultBackend, err := newLocalStorage("<default>", defaultStorageDir); err != nil {
@@ -144,7 +144,7 @@ func (s storage) ListBackends() map[string]Backend {
return s.backends return s.backends
} }
func restoreBackends(db *db.Handler, ctx context.Context) (map[string]Backend, error) { func restoreBackends(db db.Handler, ctx context.Context) (map[string]Backend, error) {
const q = "SELECT name, driver, params from storage_backends" const q = "SELECT name, driver, params from storage_backends"
if rows, err := db.Query(ctx, q); err != nil { if rows, err := db.Query(ctx, q); err != nil {
return nil, err return nil, err
+3 -3
View File
@@ -7,7 +7,7 @@ import (
) )
type manager struct { type manager struct {
db *db.Handler db db.Handler
ctx context.Context ctx context.Context
} }
@@ -18,11 +18,11 @@ func CreateManager(ctx context.Context) Manager {
} }
} }
func (m manager) WithDb(db *db.Handler) Manager { func (m manager) WithDb(db db.Handler) Manager {
return m.withDb(db) return m.withDb(db)
} }
func (m manager) withDb(db *db.Handler) manager { func (m manager) withDb(db db.Handler) manager {
return manager{ return manager{
ctx: m.ctx, ctx: m.ctx,
db: db, db: db,
+1 -1
View File
@@ -51,7 +51,7 @@ func (u User) OpenFileSystem(ctx context.Context) fs.FileSystem {
type Manager interface { type Manager interface {
// manager.go // manager.go
WithDb(db *db.Handler) Manager WithDb(db db.Handler) Manager
// create.go // create.go
CreateUser(username, displayName, password string, root uuid.UUID) (User, error) CreateUser(username, displayName, password string, root uuid.UUID) (User, error)