[server][core][db] Streamline db package

This commit is contained in:
Abhishek Shroff
2025-04-04 22:36:17 +05:30
parent b0725140db
commit ba8cb6ecc1
12 changed files with 86 additions and 89 deletions

View File

@@ -29,11 +29,11 @@ func setupSchemaResetCommand() *cobra.Command {
Short: "Reset Database Schema",
Run: func(cmd *cobra.Command, args []string) {
db.AutoMigrate = false
if err := db.Get().DeleteSchema(context.Background()); err != nil {
if err := db.DeleteSchema(context.Background()); err != nil {
fmt.Println("unable to delete database schema: " + err.Error())
os.Exit(1)
}
if err := db.Get().Migrate(context.Background(), -1); err != nil {
if err := db.Migrate(context.Background(), -1); err != nil {
fmt.Println("unable to migrate database schema: " + err.Error())
os.Exit(1)
}
@@ -56,7 +56,7 @@ func setupSchemaMigrateCommand() *cobra.Command {
if v, err := strconv.Atoi(args[0]); err != nil {
fmt.Println(err.Error())
os.Exit(3)
} else if err := db.Get().Migrate(ctx, v); err != nil {
} else if err := db.Migrate(ctx, v); err != nil {
fmt.Println(err.Error())
os.Exit(4)
}

View File

@@ -92,7 +92,7 @@ func setupUserAddCommand() *cobra.Command {
}
ctx := context.Background()
err = db.Get().RunInTx(ctx, func(db *db.Handler) error {
err = db.Get().RunInTx(ctx, func(db db.Handler) error {
userManager := user.CreateManager(ctx).WithDb(db)
var u user.User
if user, err := userManager.CreateUser(username, displayName, password, root); err != nil {

View File

@@ -55,7 +55,7 @@ func setupUserModCommand() *cobra.Command {
}
ctx := context.Background()
err = db.Get().RunInTx(ctx, func(db *db.Handler) error {
err = db.Get().RunInTx(ctx, func(db db.Handler) error {
m := user.CreateManager(ctx).WithDb(db)
if displayName != "" {
if err := m.UpdateUserDisplayName(u, displayName); err != nil {

View File

@@ -11,12 +11,43 @@ import (
const DefaultUserUsername = "phylum"
const DefaultUserDisplayName = "Phylum"
func (d Handler) BootstrapData(ctx context.Context) error {
func BootstrapData(ctx context.Context) error {
const q = "SELECT username FROM users WHERE username = $1::TEXT"
d := Get()
row := d.QueryRow(ctx, q, DefaultUserUsername)
if err := row.Scan(nil); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
err = d.populateData(ctx)
err = d.RunInTx(ctx, func(d Handler) error {
const createDir = `INSERT INTO resources(id, parent, name, dir, content_length, content_type, content_sha256)
VALUES ($1::UUID, $2::UUID, $3::TEXT, TRUE, 0, '', '')`
// Create root folder
rootID, _ := uuid.NewV7()
if _, err := d.Exec(ctx, createDir, rootID, nil, ""); err != nil {
return err
}
// Create home folder
homeID, _ := uuid.NewV7()
if _, err := d.Exec(ctx, createDir, homeID, rootID, "home"); err != nil {
return err
}
// Create user home folder
userHomeID, _ := uuid.NewV7()
if _, err := d.Exec(ctx, createDir, userHomeID, homeID, DefaultUserUsername); err != nil {
return err
}
const createRootUser = `INSERT INTO users(username, display_name, password_hash, root, home, permissions)
VALUES ($1::TEXT, $2::TEXT, '', $3::UUID, $4::UUID, -1)`
if _, err := d.Exec(ctx, createRootUser, DefaultUserUsername, DefaultUserDisplayName, rootID, userHomeID); err != nil {
return err
}
return nil
})
}
if err != nil {
return err
@@ -25,37 +56,3 @@ func (d Handler) BootstrapData(ctx context.Context) error {
return nil
}
func (d Handler) populateData(ctx context.Context) (e error) {
return d.RunInTx(ctx, func(d *Handler) error {
const createDir = `INSERT INTO resources(id, parent, name, dir, content_length, content_type, content_sha256)
VALUES ($1::UUID, $2::UUID, $3::TEXT, TRUE, 0, '', '')`
// Create root folder
rootID, _ := uuid.NewV7()
if _, err := d.Exec(ctx, createDir, rootID, nil, ""); err != nil {
return err
}
// Create home folder
homeID, _ := uuid.NewV7()
if _, err := d.Exec(ctx, createDir, homeID, rootID, "home"); err != nil {
return err
}
// Create user home folder
userHomeID, _ := uuid.NewV7()
if _, err := d.Exec(ctx, createDir, userHomeID, homeID, DefaultUserUsername); err != nil {
return err
}
const createRootUser = `INSERT INTO users(username, display_name, password_hash, root, home, permissions)
VALUES ($1::TEXT, $2::TEXT, '', $3::UUID, $4::UUID, -1)`
if _, err := d.Exec(ctx, createRootUser, DefaultUserUsername, DefaultUserDisplayName, rootID, userHomeID); err != nil {
return err
}
return nil
})
}

View File

@@ -9,29 +9,30 @@ import (
"github.com/sirupsen/logrus"
)
var h *Handler
var DatabaseURL string
var TraceSQL bool
var AutoMigrate bool
func Get() *Handler {
if h == nil {
if handler, err := createHandler(context.Background(), DatabaseURL, TraceSQL, AutoMigrate); err != nil {
var pool *pgxpool.Pool
func Get() Handler {
if pool == nil {
if p, err := createPool(context.Background(), DatabaseURL, TraceSQL, AutoMigrate); err != nil {
logrus.Fatal(err)
} else {
h = handler
pool = p
}
}
return h
return Handler{tx: pool}
}
func Close() {
if h != nil {
h.pool.Close()
h = nil
if pool != nil {
pool.Close()
pool = nil
}
}
func createHandler(ctx context.Context, dsn string, traceSQL, autoMigrate bool) (*Handler, error) {
func createPool(ctx context.Context, dsn string, traceSQL, autoMigrate bool) (*pgxpool.Pool, error) {
config, err := pgxpool.ParseConfig(dsn)
if err != nil {
return nil, errors.New("Unable to parse DSN: " + err.Error())
@@ -48,14 +49,9 @@ func createHandler(ctx context.Context, dsn string, traceSQL, autoMigrate bool)
logrus.Debug("Connected to " + config.ConnConfig.Database + " at " + config.ConnConfig.Host + ":" + fmt.Sprint(config.ConnConfig.Port))
h := Handler{
tx: pool,
pool: pool,
}
if err := h.checkVersion(ctx, autoMigrate); err != nil {
if err := checkVersion(ctx, autoMigrate); err != nil {
return nil, err
}
return &h, nil
return pool, nil
}

View File

@@ -5,27 +5,25 @@ import (
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
)
type Handler struct {
tx interface {
Begin(context.Context) (pgx.Tx, error)
SendBatch(context.Context, *pgx.Batch) pgx.BatchResults
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
QueryRow(context.Context, string, ...interface{}) pgx.Row
CopyFrom(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error)
}
pool *pgxpool.Pool
}
func (h Handler) RunInTx(ctx context.Context, fn func(*Handler) error) error {
func (h Handler) RunInTx(ctx context.Context, fn func(Handler) error) error {
return pgx.BeginFunc(ctx, h.tx, func(tx pgx.Tx) error {
h := Handler{
tx: tx,
pool: h.pool,
tx: tx,
}
return fn(&h)
return fn(h)
})
}
@@ -44,3 +42,7 @@ func (h Handler) QueryRow(ctx context.Context, stmt string, args ...interface{})
func (h Handler) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
return h.tx.CopyFrom(ctx, tableName, columnNames, rowSrc)
}
func (h Handler) SendBatch(ctx context.Context, batch *pgx.Batch) pgx.BatchResults {
return h.tx.SendBatch(ctx, batch)
}

View File

@@ -15,8 +15,8 @@ var (
ErrMigrationNoAutoDowngrade = errors.New("will not auto-downgrade schema to prevent data loss")
)
func (h Handler) checkVersion(ctx context.Context, autoMigrate bool) error {
conn, err := h.pool.Acquire(ctx)
func checkVersion(ctx context.Context, autoMigrate bool) error {
conn, err := pool.Acquire(ctx)
if err != nil {
return err
}
@@ -34,7 +34,7 @@ func (h Handler) checkVersion(ctx context.Context, autoMigrate bool) error {
// Nothing to do
if currentSchemaVersion == latestSchemaVersion {
return h.BootstrapData(ctx)
return BootstrapData(ctx)
}
if !autoMigrate {
@@ -51,11 +51,12 @@ func (h Handler) checkVersion(ctx context.Context, autoMigrate bool) error {
if err != nil {
return err
}
return h.BootstrapData(ctx)
return BootstrapData(ctx)
}
func (h Handler) Migrate(ctx context.Context, version int) error {
conn, err := h.pool.Acquire(ctx)
func Migrate(ctx context.Context, version int) error {
Get() // Initialize the pool, just in case
conn, err := pool.Acquire(ctx)
if err != nil {
return err
}
@@ -74,7 +75,7 @@ func (h Handler) Migrate(ctx context.Context, version int) error {
version = latestSchemaVersion
} else if version == 0 {
logrus.Info("Deleting database schema")
return h.DeleteSchema(ctx)
return DeleteSchema(ctx)
} else if version > latestSchemaVersion {
return ErrMigrationTargetTooHigh
}
@@ -85,14 +86,15 @@ func (h Handler) Migrate(ctx context.Context, version int) error {
return err
}
if version == latestSchemaVersion {
return h.BootstrapData(ctx)
return BootstrapData(ctx)
}
return nil
}
func (h Handler) DeleteSchema(ctx context.Context) error {
return h.RunInTx(ctx, func(d *Handler) (err error) {
user := d.pool.Config().ConnConfig.User
func DeleteSchema(ctx context.Context) error {
h := Get()
user := pool.Config().ConnConfig.User
return h.RunInTx(ctx, func(d Handler) (err error) {
if _, err = d.Exec(ctx, "DROP SCHEMA public CASCADE"); err != nil {
return
}

View File

@@ -10,7 +10,7 @@ import (
type filesystem struct {
ctx context.Context
db *db.Handler
db db.Handler
cs storage.Storage
rootID uuid.UUID
username string
@@ -36,11 +36,11 @@ func (f filesystem) withRoot(id uuid.UUID) filesystem {
}
}
func (f filesystem) WithDb(db *db.Handler) FileSystem {
func (f filesystem) WithDb(db db.Handler) FileSystem {
return f.withDb(db)
}
func (f filesystem) withDb(db *db.Handler) filesystem {
func (f filesystem) withDb(db db.Handler) filesystem {
return filesystem{
ctx: f.ctx,
db: db,
@@ -52,13 +52,13 @@ func (f filesystem) withDb(db *db.Handler) filesystem {
}
func (f filesystem) RunInTx(fn func(FileSystem) error) error {
return f.db.RunInTx(f.ctx, func(db *db.Handler) error {
return f.db.RunInTx(f.ctx, func(db db.Handler) error {
return fn(f.WithDb(db))
})
}
func (f filesystem) runInTx(fn func(filesystem) error) error {
return f.db.RunInTx(f.ctx, func(db *db.Handler) error {
return f.db.RunInTx(f.ctx, func(db db.Handler) error {
return fn(f.withDb(db))
})
}

View File

@@ -36,7 +36,7 @@ type FileSystem interface {
// filesystem.go
RootID() uuid.UUID
WithRoot(uuid.UUID) FileSystem
WithDb(db *db.Handler) FileSystem
WithDb(db db.Handler) FileSystem
RunInTx(fn func(FileSystem) error) error
// create.go
@@ -116,7 +116,7 @@ func OpenFromPublink(ctx context.Context, name string, password string, path str
return f.ResourceByPath(path)
}
func getPublink(d *db.Handler, ctx context.Context, name string) (Publink, error) {
func getPublink(d db.Handler, ctx context.Context, name string) (Publink, error) {
q := "SELECT FROM publinks p WHERE name = $1::TEXT AND deleted IS NULL"
if rows, err := d.Query(ctx, q, name); err != nil {
return Publink{}, err

View File

@@ -25,7 +25,7 @@ type Storage interface {
}
type storage struct {
db *db.Handler
db db.Handler
backends map[string]Backend
defaultBackend Backend
}
@@ -44,7 +44,7 @@ func Get() Storage {
return s
}
func create(ctx context.Context, db *db.Handler, defaultStorageDir string) (Storage, error) {
func create(ctx context.Context, db db.Handler, defaultStorageDir string) (Storage, error) {
if backends, err := restoreBackends(db, ctx); err != nil {
return nil, err
} else if defaultBackend, err := newLocalStorage("<default>", defaultStorageDir); err != nil {
@@ -144,7 +144,7 @@ func (s storage) ListBackends() map[string]Backend {
return s.backends
}
func restoreBackends(db *db.Handler, ctx context.Context) (map[string]Backend, error) {
func restoreBackends(db db.Handler, ctx context.Context) (map[string]Backend, error) {
const q = "SELECT name, driver, params from storage_backends"
if rows, err := db.Query(ctx, q); err != nil {
return nil, err

View File

@@ -7,7 +7,7 @@ import (
)
type manager struct {
db *db.Handler
db db.Handler
ctx context.Context
}
@@ -18,11 +18,11 @@ func CreateManager(ctx context.Context) Manager {
}
}
func (m manager) WithDb(db *db.Handler) Manager {
func (m manager) WithDb(db db.Handler) Manager {
return m.withDb(db)
}
func (m manager) withDb(db *db.Handler) manager {
func (m manager) withDb(db db.Handler) manager {
return manager{
ctx: m.ctx,
db: db,

View File

@@ -51,7 +51,7 @@ func (u User) OpenFileSystem(ctx context.Context) fs.FileSystem {
type Manager interface {
// manager.go
WithDb(db *db.Handler) Manager
WithDb(db db.Handler) Manager
// create.go
CreateUser(username, displayName, password string, root uuid.UUID) (User, error)