Files
PrivateCaptcha/pkg/db/postgres.go
2025-09-12 21:45:05 +03:00

214 lines
6.8 KiB
Go

package db
import (
"context"
"embed"
"io/fs"
"log/slog"
"strconv"
"time"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/billing"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/common"
config_pkg "github.com/PrivateCaptcha/PrivateCaptcha/pkg/config"
"github.com/golang-migrate/migrate/v4"
pgxmigrate "github.com/golang-migrate/migrate/v4/database/pgx/v5"
"github.com/golang-migrate/migrate/v4/source/iofs"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib"
)
// pgMigrationsSchema is the schema in which golang-migrate keeps its
// bookkeeping table. It is pinned explicitly (see MigratePostgresEx) so that
// the connection's search_path cannot make that table resolve elsewhere.
const pgMigrationsSchema = "public"

// Server-side timeouts that createPgxConfig sends as runtime parameters
// (in milliseconds) on every Postgres connection.
const (
	pgIdleInTransactionSessionTimeout time.Duration = 10 * time.Second
	pgStatementTimeout                time.Duration = 10 * time.Second
)
// postgresMigrationsFS holds the Postgres SQL migration files embedded into the
// binary at build time. Its layout matches the "migrations/postgres" directory
// that MigratePostgresEx reads from the fs.FS it is given.
//
//go:embed migrations/postgres/*.sql
var postgresMigrationsFS embed.FS
// myQueryTracer implements pgx.QueryTracer. It is installed as
// ConnConfig.Tracer by createPgxConfig and emits trace-level log records for
// every SQL command, including how long the command took.
type myQueryTracer struct {
}

// Compile-time check that myQueryTracer satisfies pgx.QueryTracer.
var _ pgx.QueryTracer = (*myQueryTracer)(nil)

// enabled reports whether the default logger emits records at common.LevelTrace.
// The tracer runs for every query. When trace logging is off, it therefore
// skips all of its work: boxing log arguments, reading the clock and
// wrapping ctx.
func (tracer *myQueryTracer) enabled(ctx context.Context) bool {
	return slog.Default().Enabled(ctx, common.LevelTrace)
}

// TraceQueryStart logs the SQL text and its arguments. It returns a child
// context that stores the start time under common.TimeContextKey;
// TraceQueryEnd reads that value to compute the command duration.
func (tracer *myQueryTracer) TraceQueryStart(
	ctx context.Context,
	_ *pgx.Conn,
	data pgx.TraceQueryStartData) context.Context {
	if !tracer.enabled(ctx) {
		return ctx
	}
	slog.Log(ctx, common.LevelTrace, "Starting SQL command", "sql", data.SQL, "args", data.Args, "source", "postgres")
	return context.WithValue(ctx, common.TimeContextKey, time.Now())
}

// TraceQueryEnd logs whether the command failed or finished, together with its
// duration in milliseconds. If ctx carries no start time, the reported
// duration is 0.
func (tracer *myQueryTracer) TraceQueryEnd(ctx context.Context, _ *pgx.Conn, data pgx.TraceQueryEndData) {
	if !tracer.enabled(ctx) {
		return
	}
	start, ok := ctx.Value(common.TimeContextKey).(time.Time)
	if !ok {
		start = time.Now()
	}
	durationMs := time.Since(start).Milliseconds()
	if data.Err != nil {
		slog.Log(ctx, common.LevelTrace, "SQL command failed", common.ErrAttr(data.Err), "source", "postgres", "duration", durationMs)
		return
	}
	slog.Log(ctx, common.LevelTrace, "SQL command finished", "source", "postgres", "duration", durationMs)
}
// postgresUser picks the Postgres user name from cfg.
//
// Regular connections (admin == false) always use PostgresUserKey. Admin
// connections prefer PostgresAdminKey and fall back to PostgresUserKey when
// the admin value is empty.
func postgresUser(cfg common.ConfigStore, admin bool) string {
	if !admin {
		return cfg.Get(common.PostgresUserKey).Value()
	}
	adminUser := cfg.Get(common.PostgresAdminKey).Value()
	if adminUser == "" {
		return cfg.Get(common.PostgresUserKey).Value()
	}
	return adminUser
}
// postgresPassword picks the Postgres password from cfg, mirroring postgresUser.
//
// Regular connections use PostgresPasswordKey. Admin connections use
// PostgresAdminPasswordKey when it is non-empty and PostgresPasswordKey
// otherwise.
func postgresPassword(cfg common.ConfigStore, admin bool) string {
	if !admin {
		return cfg.Get(common.PostgresPasswordKey).Value()
	}
	if adminPwd := cfg.Get(common.PostgresAdminPasswordKey).Value(); adminPwd != "" {
		return adminPwd
	}
	return cfg.Get(common.PostgresPasswordKey).Value()
}
// createPgxConfig builds the pgxpool configuration for the Postgres database.
//
// When common.PostgresKey holds a connection string, it is parsed as-is and the
// individual host/database/credential settings are not consulted. Otherwise
// the connection is assembled from the pieces below:
//   - PostgresHostKey, on port 5432, without TLS;
//   - PostgresDBKey as the database;
//   - credentials from postgresUser/postgresPassword, where migrate selects
//     the admin credentials when they are configured.
//
// In both cases the config also gets the query tracer, the application name,
// and server-side idle-in-transaction and statement timeouts.
// On failure, it returns the parse error from pgxpool.ParseConfig.
func createPgxConfig(ctx context.Context, cfg common.ConfigStore, migrate bool) (*pgxpool.Config, error) {
	dbURL := cfg.Get(common.PostgresKey).Value()

	config, err := pgxpool.ParseConfig(dbURL)
	if err != nil {
		// dbURL is deliberately not logged: a Postgres connection string may embed the password.
		slog.ErrorContext(ctx, "Failed to parse Postgres URL", common.ErrAttr(err))
		return nil, err
	}

	if len(dbURL) == 0 {
		config.ConnConfig.Host = cfg.Get(common.PostgresHostKey).Value()
		config.ConnConfig.Port = 5432 // Default PostgreSQL port
		config.ConnConfig.Database = cfg.Get(common.PostgresDBKey).Value()
		config.ConnConfig.User = postgresUser(cfg, migrate)
		config.ConnConfig.Password = postgresPassword(cfg, migrate)
		config.ConnConfig.TLSConfig = nil // not using SSL
		// Parsing the empty string derives fallback targets from pgx's own
		// defaults (default host, TLS variants). Drop them so a failed
		// connection to the configured host is never retried against that
		// default host.
		config.ConnConfig.Fallbacks = nil
	}

	config.ConnConfig.Tracer = &myQueryTracer{}
	config.ConnConfig.RuntimeParams["application_name"] = "privatecaptcha"
	// Postgres interprets these unit-less values as milliseconds.
	config.ConnConfig.RuntimeParams["idle_in_transaction_session_timeout"] =
		strconv.FormatInt(pgIdleInTransactionSessionTimeout.Milliseconds(), 10)
	config.ConnConfig.RuntimeParams["statement_timeout"] =
		strconv.FormatInt(pgStatementTimeout.Milliseconds(), 10)

	return config, nil
}
// connectPostgres creates a pgxpool from config, retrying once per second.
//
// The first attempt is made immediately, so even a timeout shorter than the
// retry interval gets one try. It stops with:
//   - the pool, as soon as pgxpool.NewWithConfig succeeds;
//   - errConnectionTimeout, once timeout has elapsed;
//   - ctx.Err(), when ctx is cancelled.
func connectPostgres(ctx context.Context, config *pgxpool.Config, timeout time.Duration) (*pgxpool.Pool, error) {
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()

	for {
		slog.DebugContext(ctx, "Connecting to Postgres...")
		pool, err := pgxpool.NewWithConfig(ctx, config)
		if err == nil {
			return pool, nil
		}
		slog.ErrorContext(ctx, "Failed to create pgxpool", common.ErrAttr(err))

		select {
		case <-ctx.Done():
			slog.ErrorContext(ctx, "Connection to Postgres cancelled", common.ErrAttr(ctx.Err()))
			return nil, ctx.Err()
		case <-deadline.C:
			slog.ErrorContext(ctx, "Connection to Postgres failed", "timeout", timeout)
			return nil, errConnectionTimeout
		case <-ticker.C:
			// retry
		}
	}
}
// PostgresMigrateContext carries deployment-specific values for the Postgres
// migrations and is filled in by NewPostgresMigrateContext.
// NOTE(review): presumably consumed as template data by the SQL migration
// files — confirm at the call site.
type PostgresMigrateContext struct {
	// Stage is the value of common.StageKey.
	Stage string
	// ExternalProductID is the product ID of the internal admin billing plan.
	ExternalProductID string
	// ExternalPriceID is the second value of the admin plan's PriceIDs()
	// (the yearly price ID).
	ExternalPriceID string
	// ExternalStatus is the plan service's ActiveTrialStatus().
	ExternalStatus string
	// PortalLoginPropertyID is set from the package-level PortalLoginPropertyID.
	PortalLoginPropertyID string
	// PortalRegisterPropertyID is set from the package-level PortalRegisterPropertyID.
	PortalRegisterPropertyID string
	// PortalDomain is the domain of the configured portal base URL (PortalBaseURLKey).
	PortalDomain string
	// AdminEmail is the value of common.AdminEmailKey.
	AdminEmail string
	// PortalLoginDifficulty is set to common.DifficultyLevelSmall by NewPostgresMigrateContext.
	PortalLoginDifficulty common.DifficultyLevel
	// PortalRegisterDifficulty is set to common.DifficultyLevelSmall by NewPostgresMigrateContext.
	PortalRegisterDifficulty common.DifficultyLevel
}
// NewPostgresMigrateContext gathers the values needed by the Postgres
// migrations:
//   - from cfg: stage, portal domain and admin email;
//   - from the internal admin plan of planService: product ID, yearly price ID
//     and the active-trial status.
//
// The portal property IDs come from package-level values, and both portal
// difficulties are set to common.DifficultyLevelSmall.
func NewPostgresMigrateContext(ctx context.Context, cfg common.ConfigStore, planService billing.PlanService) *PostgresMigrateContext {
	result := &PostgresMigrateContext{
		Stage:                    cfg.Get(common.StageKey).Value(),
		PortalLoginPropertyID:    PortalLoginPropertyID,
		PortalRegisterPropertyID: PortalRegisterPropertyID,
		PortalDomain:             config_pkg.AsURL(ctx, cfg.Get(common.PortalBaseURLKey)).Domain(),
		PortalLoginDifficulty:    common.DifficultyLevelSmall,
		PortalRegisterDifficulty: common.DifficultyLevelSmall,
	}

	adminPlan := planService.GetInternalAdminPlan()
	// PriceIDs returns (monthly, yearly); only the yearly price is used.
	_, result.ExternalPriceID = adminPlan.PriceIDs()
	result.AdminEmail = cfg.Get(common.AdminEmailKey).Value()
	result.ExternalProductID = adminPlan.ProductID()
	result.ExternalStatus = planService.ActiveTrialStatus()

	return result
}
// MigratePostgresEx runs the SQL migrations found under "migrations/postgres"
// in migrationsFS against the database behind pool.
//
// It applies all pending migrations when up is true and rolls them back
// (m.Down) otherwise. golang-migrate records its state in table tableName
// inside schema pgMigrationsSchema. migrate.ErrNoChange (nothing to do) is not
// treated as an error.
//
// Every resource opened here is released before returning, on error paths as
// well: the migration source, the migrate driver, and the *sql.DB wrapping
// pool.
func MigratePostgresEx(ctx context.Context, pool *pgxpool.Pool, migrationsFS fs.FS, tableName string, up bool) error {
	mlog := slog.With("up", up)

	d, err := iofs.New(migrationsFS, "migrations/postgres")
	if err != nil {
		mlog.ErrorContext(ctx, "Failed to read from Postgres migrations IOFS", common.ErrAttr(err))
		return err
	}

	db := stdlib.OpenDBFromPool(pool)
	// The migrate driver may already close db during m.Close. sql.DB.Close is
	// idempotent, so this is then a no-op; on the early-error paths below it is
	// the only thing that releases db. Its error carries no useful information
	// here, hence it is discarded.
	defer func() { _ = db.Close() }()

	// NOTE: beware of the "migrations run twice" problem with migrate, which is related to search_path:
	// https://github.com/golang-migrate/migrate/blob/master/database/postgres/TUTORIAL.md#fix-issue-where-migrations-run-twice
	// The tutorial's fix is to add '&search_path=public' to the connection string to force a specific schema;
	// here SchemaName pins the schema for the migrations table only.
	driver, err := pgxmigrate.WithInstance(db, &pgxmigrate.Config{
		MigrationsTable: tableName,
		SchemaName:      pgMigrationsSchema,
	})
	if err != nil {
		mlog.ErrorContext(ctx, "Failed to create migrate driver", common.ErrAttr(err))
		if srcErr := d.Close(); srcErr != nil {
			mlog.ErrorContext(ctx, "Source error when running migrations", common.ErrAttr(srcErr))
		}
		return err
	}

	m, err := migrate.NewWithInstance("iofs", d, "postgres", driver)
	if err != nil {
		mlog.ErrorContext(ctx, "Failed to create migration engine for Postgres", common.ErrAttr(err))
		// No *migrate.Migrate exists to close them, so release the source and the
		// driver (which may hold a connection taken from pool) explicitly.
		if srcErr := d.Close(); srcErr != nil {
			mlog.ErrorContext(ctx, "Source error when running migrations", common.ErrAttr(srcErr))
		}
		if dstErr := driver.Close(); dstErr != nil {
			mlog.ErrorContext(ctx, "Destination error when running migrations", common.ErrAttr(dstErr))
		}
		return err
	}
	defer func() {
		srcErr, dstErr := m.Close()
		if srcErr != nil {
			mlog.ErrorContext(ctx, "Source error when running migrations", common.ErrAttr(srcErr))
		}
		if dstErr != nil {
			mlog.ErrorContext(ctx, "Destination error when running migrations", common.ErrAttr(dstErr))
		}
		mlog.DebugContext(ctx, "Closed Postgres migrate connection")
	}()

	mlog.DebugContext(ctx, "Running Postgres migrations...")
	if up {
		err = m.Up()
	} else {
		err = m.Down()
	}

	if err != nil && err != migrate.ErrNoChange {
		mlog.ErrorContext(ctx, "Failed to apply migrations in Postgres", common.ErrAttr(err))
		return err
	}

	mlog.DebugContext(ctx, "Postgres migrated", "changes", err != migrate.ErrNoChange)
	return nil
}