diff --git a/server/internal/api/auth/auth_bearer.go b/server/internal/api/auth/auth_bearer.go index 115574b4..1551c5fd 100644 --- a/server/internal/api/auth/auth_bearer.go +++ b/server/internal/api/auth/auth_bearer.go @@ -15,6 +15,7 @@ const errCodeTokenInvalid = "token_invalid" func CreateBearerAuthHandler(a *core.App) func(c *gin.Context) { return func(c *gin.Context) { + ctx := c.Request.Context() authHeader := c.GetHeader("Authorization") if authHeader == "" { panic(errors.Err{Status: 401, Code: errCodeAuthRequred}) @@ -27,14 +28,14 @@ func CreateBearerAuthHandler(a *core.App) func(c *gin.Context) { panic(errors.Err{Status: 401, Code: errCodeAuthRequred}) } - userID, err := a.VerifyAccessToken(authParts[1]) + userID, err := a.VerifyAccessToken(ctx, authParts[1]) if err != nil { if errors.Is(err, core.ErrTokenExpired) || errors.Is(err, core.ErrTokenInvalid) { panic(errors.Err{Status: 401, Code: errCodeTokenInvalid}) } panic(err) } - if fs, err := a.OpenFileSystem(c.Request.Context(), userID); err != nil { + if fs, err := a.OpenFileSystem(ctx, userID); err != nil { logrus.Warn(err) c.AbortWithStatus(http.StatusInternalServerError) } else { diff --git a/server/internal/api/routes/auth.go b/server/internal/api/routes/auth.go index 1b0bb4e6..92ab51b4 100644 --- a/server/internal/api/routes/auth.go +++ b/server/internal/api/routes/auth.go @@ -24,7 +24,7 @@ func createLoginRouteHandler(a *core.App) func(c *gin.Context) { panic(errors.New(http.StatusBadRequest, "missing_password", "")) } - if token, err := a.CreateAccessToken(username, password); err != nil { + if token, err := a.CreateAccessToken(c.Request.Context(), username, password); err != nil { if errors.Is(err, core.ErrCredentialsInvalid) { panic(errors.New(http.StatusUnauthorized, "credentials_invalid", "")) } diff --git a/server/internal/command/appcmd/appcmd.go b/server/internal/command/appcmd/appcmd.go index 27b9843d..c8ac94fe 100644 --- a/server/internal/command/appcmd/appcmd.go +++ b/server/internal/command/appcmd/appcmd.go @@ -1,6 +1,8 @@ package appcmd import ( + "context" + "github.com/shroff/phylum/server/internal/core" "github.com/shroff/phylum/server/internal/db" "github.com/shroff/phylum/server/internal/storage" @@ -19,15 +21,15 @@ func SetupCommand(db **db.DbHandler, debug bool) *cobra.Command { for ; c.Parent() != nil; c = c.Parent() { } c.PersistentPreRun(cmd, args) - if err := (*db).CheckVersion(!viper.GetBool("no_auto_migrate")); err != nil { + if err := (*db).CheckVersion(context.Background(), !viper.GetBool("no_auto_migrate")); err != nil { logrus.Fatal(err) } var err error - if cs, err = storage.Open(*db, viper.GetString("content-dir")); err != nil { + if cs, err = storage.Open(*db, context.Background(), viper.GetString("content-dir")); err != nil { logrus.Fatal(err) } else { - if err := core.Create(*db, cs, debug); err != nil { + if err := core.Create(context.Background(), *db, cs, debug); err != nil { logrus.Fatal(err) } } diff --git a/server/internal/command/appcmd/storage.go b/server/internal/command/appcmd/storage.go index 7946a881..35a586ce 100644 --- a/server/internal/command/appcmd/storage.go +++ b/server/internal/command/appcmd/storage.go @@ -2,6 +2,7 @@ package appcmd import ( "bufio" + "context" "fmt" "os" "strings" @@ -48,7 +49,7 @@ func setupStorageCreateCommand(cs storage.Storage) *cobra.Command { params[paramName] = val } - if err := cs.CreateBackend(name, driver, params); err != nil { + if err := cs.CreateBackend(context.Background(), name, driver, params); err != nil { logrus.Fatal(err) } diff --git a/server/internal/command/appcmd/user.go b/server/internal/command/appcmd/user.go index 66c7c00c..9ed23bbe 100644 --- a/server/internal/command/appcmd/user.go +++ b/server/internal/command/appcmd/user.go @@ -103,7 +103,7 @@ func setupUserLoginCommand() *cobra.Command { } password := string(bytes) - accessToken, err := core.Default.CreateAccessToken(username, password) + accessToken, err := core.Default.CreateAccessToken(context.Background(), username, password) if err != nil { logrus.Fatal(err) } diff --git a/server/internal/command/command.go b/server/internal/command/command.go index c53805e6..d6c159fb 100644 --- a/server/internal/command/command.go +++ b/server/internal/command/command.go @@ -1,6 +1,7 @@ package command import ( + "context" "os" "path" @@ -46,7 +47,7 @@ func SetupCommand() { } var err error - if database, err = db.NewDb(viper.GetString("database_url"), debug && viper.GetBool("trace_sql")); err != nil { + if database, err = db.NewDb(context.Background(), viper.GetString("database_url"), debug && viper.GetBool("trace_sql")); err != nil { logrus.Fatal(err) } } diff --git a/server/internal/command/schema.go b/server/internal/command/schema.go index 7684edec..1fa7c1de 100644 --- a/server/internal/command/schema.go +++ b/server/internal/command/schema.go @@ -1,6 +1,7 @@ package command import ( + "context" "strconv" "github.com/shroff/phylum/server/internal/db" @@ -25,7 +26,7 @@ func setupSchemaResetCommand(db **db.DbHandler) *cobra.Command { Use: "reset", Short: "Reset Database Schema", Run: func(cmd *cobra.Command, args []string) { - if err := (*db).DeleteSchema(); err != nil { + if err := (*db).DeleteSchema(context.Background()); err != nil { logrus.Fatal(err) } }, @@ -38,17 +39,18 @@ func setupSchemaMigrateCommand(db **db.DbHandler) *cobra.Command { Short: "Migrate Database Schema", Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { + ctx := context.Background() if args[0] == "auto" || args[0] == "latest" { - if err := (*db).CheckVersion(true); err != nil { + if err := (*db).CheckVersion(ctx, true); err != nil { logrus.Fatal(err) } } else { - (*db).CheckVersion(false) + (*db).CheckVersion(ctx, false) v, err := strconv.Atoi(args[0]) if err != nil { logrus.Fatal(err) } - if err = (*db).Migrate(v); err != nil { + if err = (*db).Migrate(ctx, v); err != nil { logrus.Fatal(err) } } diff --git a/server/internal/core/app.go b/server/internal/core/app.go index e8c85172..26e8598e 100644 --- a/server/internal/core/app.go +++ b/server/internal/core/app.go @@ -21,20 +21,18 @@ type App struct { var Default *App -func Create(db *db.DbHandler, cs storage.Storage, debug bool) error { +func Create(ctx context.Context, db *db.DbHandler, cs storage.Storage, debug bool) error { Default = &App{ Debug: debug, db: db, cs: cs, } - return Default.setupAppData() + return Default.setupAppData(ctx) } -func (a App) setupAppData() error { - ctx := context.Background() - - _, err := a.db.Queries().UserByUsername(context.Background(), defaultUserName) +func (a App) setupAppData(ctx context.Context) error { + _, err := a.db.Queries().UserByUsername(ctx, defaultUserName) // Root user found. We can assume that setup has been done before if !errors.Is(err, pgx.ErrNoRows) { return err diff --git a/server/internal/core/auth.go b/server/internal/core/auth.go index 1c81818b..57cf3037 100644 --- a/server/internal/core/auth.go +++ b/server/internal/core/auth.go @@ -25,7 +25,7 @@ var ErrTokenInvalid = errors.New("token invalid") var ErrTokenExpired = errors.New("token expired") func (a App) VerifyUserPassword(ctx context.Context, username, password string) (User, error) { - if user, err := a.db.Queries().UserByUsername(context.Background(), username); err != nil { + if user, err := a.db.Queries().UserByUsername(ctx, username); err != nil { if errors.Is(err, pgx.ErrNoRows) { return nil, ErrCredentialsInvalid } @@ -36,7 +36,7 @@ func (a App) VerifyUserPassword(ctx context.Context, username, password string) } else if !b { return nil, ErrCredentialsInvalid } - if user, err := a.UserByID(context.Background(), user.ID); err != nil { + if user, err := a.UserByID(ctx, user.ID); err != nil { return nil, err } else { return user, nil @@ -44,11 +44,11 @@ func (a App) VerifyUserPassword(ctx context.Context, username, password string) } } -func (a App) CreateAccessToken(username, password string) (db.AccessToken, error) { - if user, err := a.VerifyUserPassword(context.Background(), username, password); err != nil { +func (a App) CreateAccessToken(ctx context.Context, username, password string) (db.AccessToken, error) { + if user, err := a.VerifyUserPassword(ctx, username, password); err != nil { return db.AccessToken{}, err } else { - if token, err := a.db.Queries().InsertAccessToken(context.Background(), db.InsertAccessTokenParams{ + if token, err := a.db.Queries().InsertAccessToken(ctx, db.InsertAccessTokenParams{ ID: GenerateRandomString(accessTokenLength), Validity: accessTokenValiditiy, UserID: user.ID(), @@ -60,8 +60,8 @@ func (a App) CreateAccessToken(username, password string) (db.AccessToken, error } } -func (a App) VerifyAccessToken(accessToken string) (User, error) { - token, err := a.db.Queries().AccessTokenById(context.Background(), accessToken) +func (a App) VerifyAccessToken(ctx context.Context, accessToken string) (User, error) { + token, err := a.db.Queries().AccessTokenById(ctx, accessToken) if errors.Is(err, pgx.ErrNoRows) { return nil, ErrTokenInvalid } else if err != nil { @@ -70,7 +70,7 @@ func (a App) VerifyAccessToken(accessToken string) (User, error) { if time.Now().After(token.Expires.Time) { return nil, ErrTokenExpired } - if user, err := a.UserByID(context.Background(), token.UserID); err != nil { + if user, err := a.UserByID(ctx, token.UserID); err != nil { return nil, err } else { return user, nil diff --git a/server/internal/db/db_handler.go b/server/internal/db/db_handler.go index dc7f6ffc..8efb6eb3 100644 --- a/server/internal/db/db_handler.go +++ b/server/internal/db/db_handler.go @@ -25,7 +25,7 @@ type DbHandler struct { queries *Queries } -func NewDb(dsn string, trace bool) (*DbHandler, error) { +func NewDb(ctx context.Context, dsn string, trace bool) (*DbHandler, error) { config, err := pgxpool.ParseConfig(dsn) if err != nil { return nil, errors.New("Unable to parse DSN: " + err.Error()) @@ -35,22 +35,22 @@ func NewDb(dsn string, trace bool) (*DbHandler, error) { config.ConnConfig.Tracer = tracer{} } - pool, err := pgxpool.NewWithConfig(context.Background(), config) + pool, err := pgxpool.NewWithConfig(ctx, config) if err != nil { return nil, errors.New("Unable to connect to database: " + err.Error()) } logrus.Info("Connected to " + config.ConnConfig.Database + " at " + config.ConnConfig.Host + ":" + fmt.Sprint(config.ConnConfig.Port)) - conn, err := pool.Acquire(context.Background()) + conn, err := pool.Acquire(ctx) if err != nil { return nil, err } - migrator, err := migrations.NewMigrator(conn.Conn()) + migrator, err := migrations.NewMigrator(ctx, conn.Conn()) if err != nil { return nil, err } - currentSchemaVersion, err = migrator.GetCurrentVersion() + currentSchemaVersion, err = migrator.GetCurrentVersion(ctx) if err != nil { return nil, err } @@ -65,7 +65,7 @@ func NewDb(dsn string, trace bool) (*DbHandler, error) { return handler, nil } -func (d DbHandler) CheckVersion(autoMigrate bool) error { +func (d DbHandler) CheckVersion(ctx context.Context, autoMigrate bool) error { logrus.Info(fmt.Sprintf("Schema version %d", currentSchemaVersion)) if currentSchemaVersion != latestSchemaVersion { if autoMigrate { @@ -75,7 +75,7 @@ func (d DbHandler) CheckVersion(autoMigrate bool) error { if currentSchemaVersion == latestSchemaVersion { return nil } - return d.Migrate(latestSchemaVersion) + return d.Migrate(ctx, latestSchemaVersion) } else { logrus.Warn(fmt.Sprintf("Schema version is not at latest (%d)", latestSchemaVersion)) return nil @@ -84,8 +84,8 @@ func (d DbHandler) CheckVersion(autoMigrate bool) error { return nil } -func (d DbHandler) Migrate(version int) error { - conn, err := d.pool.Acquire(context.Background()) +func (d DbHandler) Migrate(ctx context.Context, version int) error { + conn, err := d.pool.Acquire(ctx) if err != nil { return err } @@ -96,22 +96,21 @@ func (d DbHandler) Migrate(version int) error { if version > latestSchemaVersion { return ErrMigrationTargetTooHigh } - migrator, err := migrations.NewMigrator(conn.Conn()) + migrator, err := migrations.NewMigrator(ctx, conn.Conn()) if err != nil { return err } logrus.Info(fmt.Sprintf("Migrating database from version %d to %d", currentSchemaVersion, version)) - if err = migrator.MigrateTo(int32(version)); err != nil { + if err = migrator.MigrateTo(ctx, int32(version)); err != nil { return err } return nil } -func (d DbHandler) DeleteSchema() (e error) { - ctx := context.Background() +func (d DbHandler) DeleteSchema(ctx context.Context) (e error) { return pgx.BeginFunc(ctx, d.pool, func(tx pgx.Tx) error { user := d.pool.Config().ConnConfig.User - _, err := tx.Exec(context.Background(), "DROP SCHEMA public CASCADE;"+ + _, err := tx.Exec(ctx, "DROP SCHEMA public CASCADE;"+ "CREATE SCHEMA public;"+ "GRANT ALL ON SCHEMA public TO "+user+";"+ "GRANT ALL ON SCHEMA public TO public;"+ diff --git a/server/internal/db/migrations/migrations.go b/server/internal/db/migrations/migrations.go index a7609882..a9874725 100644 --- a/server/internal/db/migrations/migrations.go +++ b/server/internal/db/migrations/migrations.go @@ -18,9 +18,9 @@ type Migrator struct { //go:embed data/*.sql var migrationFiles embed.FS -func NewMigrator(conn *pgx.Conn) (Migrator, error) { +func NewMigrator(ctx context.Context, conn *pgx.Conn) (Migrator, error) { migrator, err := migrate.NewMigratorEx( - context.Background(), conn, versionTable, + ctx, conn, versionTable, &migrate.MigratorOptions{ DisableTx: false, }) @@ -40,8 +40,8 @@ func NewMigrator(conn *pgx.Conn) (Migrator, error) { }, nil } -func (m Migrator) GetCurrentVersion() (int, error) { - version, err := m.migrator.GetCurrentVersion(context.Background()) +func (m Migrator) GetCurrentVersion(ctx context.Context) (int, error) { + version, err := m.migrator.GetCurrentVersion(ctx) if err != nil { return 0, err } @@ -58,13 +58,13 @@ func (m Migrator) Migrations() []*migrate.Migration { } // Migrate migrates the DB to the most recent version of the schema. -func (m Migrator) Migrate() error { - err := m.migrator.Migrate(context.Background()) +func (m Migrator) Migrate(ctx context.Context) error { + err := m.migrator.Migrate(ctx) return err } // MigrateTo migrates to a specific version of the schema. Use '0' to undo all migrations. -func (m Migrator) MigrateTo(ver int32) error { - err := m.migrator.MigrateTo(context.Background(), ver) +func (m Migrator) MigrateTo(ctx context.Context, ver int32) error { + err := m.migrator.MigrateTo(ctx, ver) return err } diff --git a/server/internal/storage/storage.go b/server/internal/storage/storage.go index c4ede113..d0912993 100644 --- a/server/internal/storage/storage.go +++ b/server/internal/storage/storage.go @@ -10,7 +10,7 @@ import ( ) type Storage interface { - CreateBackend(name string, driver string, params map[string]string) error + CreateBackend(ctx context.Context, name string, driver string, params map[string]string) error ListBackends() map[string]Backend OpenRead(id uuid.UUID, start, length int64) (io.ReadCloser, error) OpenWrite(id uuid.UUID, callback func(int, string) error) (io.WriteCloser, error) @@ -24,8 +24,8 @@ type storage struct { defaultBackend Backend } -func Open(db *db.DbHandler, contentDir string) (Storage, error) { - if backends, err := restoreBackends(db); err != nil { +func Open(db *db.DbHandler, ctx context.Context, contentDir string) (Storage, error) { + if backends, err := restoreBackends(db, ctx); err != nil { return nil, err } else { return storage{ @@ -76,12 +76,12 @@ func (s *storage) findStorageBackend(id uuid.UUID) (Backend, error) { return s.defaultBackend, nil } -func (s storage) CreateBackend(name string, driver string, params map[string]string) error { +func (s storage) CreateBackend(ctx context.Context, name string, driver string, params map[string]string) error { backend, err := openBackend(name, driver, params) if err != nil { return nil } - err = s.db.Queries().CreateStorageBackend(context.Background(), db.CreateStorageBackendParams{ + err = s.db.Queries().CreateStorageBackend(ctx, db.CreateStorageBackendParams{ Name: name, Driver: driver, Params: params, @@ -97,8 +97,8 @@ func (s storage) ListBackends() map[string]Backend { return s.backends } -func restoreBackends(db *db.DbHandler) (map[string]Backend, error) { - backends, err := db.Queries().AllStorageBackends(context.Background()) +func restoreBackends(db *db.DbHandler, ctx context.Context) (map[string]Backend, error) { + backends, err := db.Queries().AllStorageBackends(ctx) if err != nil { return nil, err }