package main

import (
	"context"
	"crypto/tls"
	"errors"
	"flag"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/http"
	"os"
	"os/signal"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/api"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/billing"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/common"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/config"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/db"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/difficulty"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/email"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/leakybucket"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/maintenance"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/monitoring"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/portal"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/ratelimit"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/session"
	"github.com/PrivateCaptcha/PrivateCaptcha/web"
	"github.com/PrivateCaptcha/PrivateCaptcha/widget"
	"github.com/justinas/alice"
)
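
// Run modes plus shutdown and persistence timing. The drain delay below is how
// long the server keeps serving after its readiness probe starts failing, so
// load balancers can route traffic away before connections are closed.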
const (
	modeMigrate  = "migrate"
	modeRollback = "rollback"
	modeServer   = "server"
	modeAuto     = "auto"

	_readinessDrainDelay    = 5 * time.Second
	_shutdownHardPeriod     = 5 * time.Second
	_shutdownPeriod         = 10 * time.Second
	_dbConnectTimeout       = 30 * time.Second
	_sessionPersistInterval = 10 * time.Second
	_auditLogInterval       = 10 * time.Second
)

const (
	// For puzzles the logic is that if something becomes popular there will be
	// a spike, but normal usage should stay "low".
	// NOTE: this assumes correct configuration of the whole chain of reverse proxies.
	// The main problem is NATs/VPNs, which make it possible for a clump of
	// legitimate users to arrive from a single public IP.
	generalLeakyBucketCap = 20
	generalLeakInterval   = 1 * time.Second
	// Public defaults are reasonably low; we assume responses are fully cached
	// at the CDN level.
	publicLeakyBucketCap = 8
	publicLeakInterval   = 2 * time.Second
	// Catch-all defaults are even lower.
	catchAllLeakyBucketCap = 2
	catchAllLeakInterval   = 30 * time.Second
)
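
// Assuming the usual leaky-bucket semantics, the general defaults above let a
// single IP burst up to 20 requests and then sustain roughly one per second.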

var (
	GitCommit string

	flagMode        = flag.String("mode", "", strings.Join([]string{modeMigrate, modeServer, modeRollback, modeAuto}, " | "))
	envFileFlag     = flag.String("env", "", "Path to .env file, 'stdin' or empty")
	versionFlag     = flag.Bool("version", false, "Print version and exit")
	migrateHashFlag = flag.String("migrate-hash", "", "Target migration version (git commit)")
	certFileFlag    = flag.String("certfile", "", "certificate PEM file (e.g. cert.pem)")
	keyFileFlag     = flag.String("keyfile", "", "key PEM file (e.g. key.pem)")

	env *common.EnvMap
)
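
// Typical invocations (the binary name is illustrative, not taken from the build):
//
//	captcha-server -mode auto -migrate-hash ignore -env .env
//	captcha-server -mode migrate -migrate-hash <git-commit>
//	captcha-server -mode server -certfile cert.pem -keyfile key.pem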

func listenAddress(cfg common.ConfigStore) string {
	host := cfg.Get(common.HostKey).Value()
	if host == "" {
		host = "localhost"
	}

	port := cfg.Get(common.PortKey).Value()
	if port == "" {
		port = "8080"
	}

	address := net.JoinHostPort(host, port)
	return address
}
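
// createListener binds the configured TCP address and, when both -certfile and
// -keyfile are provided, wraps the listener in TLS.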
func createListener(ctx context.Context, cfg common.ConfigStore) (net.Listener, error) {
	address := listenAddress(cfg)
	listener, err := net.Listen("tcp", address)
	if err != nil {
		slog.ErrorContext(ctx, "Failed to listen", "address", address, common.ErrAttr(err))
		return nil, err
	}

	if useTLS := (*certFileFlag != "") && (*keyFileFlag != ""); useTLS {
		cert, err := tls.LoadX509KeyPair(*certFileFlag, *keyFileFlag)
		if err != nil {
			slog.ErrorContext(ctx, "Failed to load certificates", "cert", *certFileFlag, "key", *keyFileFlag, common.ErrAttr(err))
			return nil, err
		}
		tlsConfig := &tls.Config{
			Certificates: []tls.Certificate{cert},
		}
		listener = tls.NewListener(listener, tlsConfig)
	}

	return listener, nil
}
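
// newIPAddrBuckets builds the per-IP leaky buckets shared by the public
// endpoints, taking cap and leak interval from config with the general
// constants as fallbacks.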
func newIPAddrBuckets(cfg common.ConfigStore) *ratelimit.IPAddrBuckets {
	const (
		// number of distinct simultaneous clients tracked for public APIs
		// (/puzzle, /siteverify etc.) before forcing a cleanup
		maxBuckets = 1_000_000
	)

	puzzleBucketRate := cfg.Get(common.RateLimitRateKey)
	puzzleBucketBurst := cfg.Get(common.RateLimitBurstKey)

	return ratelimit.NewIPAddrBuckets(maxBuckets,
		leakybucket.Cap(puzzleBucketBurst.Value(), generalLeakyBucketCap),
		leakybucket.Interval(puzzleBucketRate.Value(), generalLeakInterval))
}

func updateIPBuckets(cfg common.ConfigStore, rateLimiter ratelimit.HTTPRateLimiter) {
	bucketRate := cfg.Get(common.RateLimitRateKey)
	bucketBurst := cfg.Get(common.RateLimitBurstKey)
	rateLimiter.UpdateLimits(
		leakybucket.Cap(bucketBurst.Value(), generalLeakyBucketCap),
		leakybucket.Interval(bucketRate.Value(), generalLeakInterval))
}
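
// run wires up the databases, API and portal servers, routing, maintenance
// jobs and signal handling, then blocks until a graceful shutdown completes.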
func run(ctx context.Context, cfg common.ConfigStore, stderr io.Writer, listener net.Listener) error {
	stage := cfg.Get(common.StageKey).Value()
	verbose := config.AsBool(cfg.Get(common.VerboseKey))
	logLevel := common.SetupLogs(stage, verbose)

	planService := billing.NewPlanService(nil)

	pool, clickhouse, dberr := db.Connect(ctx, cfg, _dbConnectTimeout, false /*admin*/)
	if dberr != nil {
		return dberr
	}

	defer pool.Close()
	defer clickhouse.Close()

	businessDB := db.NewBusiness(pool)
	timeSeriesDB := db.NewTimeSeries(clickhouse, businessDB.Cache)

	puzzleVerifier := api.NewVerifier(cfg, businessDB)

	metrics := monitoring.NewService()

	cdnURLConfig := config.AsURL(ctx, cfg.Get(common.CDNBaseURLKey))
	portalURLConfig := config.AsURL(ctx, cfg.Get(common.PortalBaseURLKey))

	sender := email.NewMailSender(cfg)
	mailer := portal.NewPortalMailer("https:"+cdnURLConfig.URL(), "https:"+portalURLConfig.URL(), sender, cfg)

	rateLimitHeader := cfg.Get(common.RateLimitHeaderKey).Value()
	ipRateLimiter := ratelimit.NewIPAddrRateLimiter(rateLimitHeader, newIPAddrBuckets(cfg))
	userLimiter := api.NewUserLimiter(businessDB)
	subscriptionLimits := db.NewSubscriptionLimits(stage, businessDB, planService)
	idHasher := common.NewIDHasher(cfg.Get(common.IDHasherSaltKey))

	// special case for async jobs (register handlers before adding)
	asyncTasksJob := maintenance.NewAsyncTasksJob(businessDB)

	apiServer := &api.Server{
		Stage:              stage,
		BusinessDB:         businessDB,
		TimeSeries:         timeSeriesDB,
		RateLimiter:        ipRateLimiter,
		Auth:               api.NewAuthMiddleware(businessDB, userLimiter, planService, metrics),
		VerifyLogChan:      make(chan *common.VerifyRecord, 10*api.VerifyBatchSize),
		Verifier:           puzzleVerifier,
		Metrics:            metrics,
		Mailer:             mailer,
		Levels:             difficulty.NewLevels(timeSeriesDB, 100 /*levelsBatchSize*/, api.PropertyBucketSize),
		VerifyLogCancel:    func() {},
		SubscriptionLimits: subscriptionLimits,
		IDHasher:           idHasher,
		AsyncTasks:         asyncTasksJob,
	}
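	// VerifyLogChan is buffered at 10x the batch size so request handlers can
	// enqueue verification records without blocking while the flusher
	// (presumably started by Init below, using the flush interval) catches up.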
	if err := apiServer.Init(ctx, 10*time.Second /*flush interval*/, 1*time.Second /*backfill duration*/, 50*time.Millisecond /*backpressure timeout*/); err != nil {
		return err
	}

	dataCtx, err := web.LoadData()
	if err != nil {
		return err
	}

	apiURLConfig := config.AsURL(ctx, cfg.Get(common.APIBaseURLKey))
	sessionStore := db.NewSessionStore(businessDB, session.KeyPersistent, metrics)
	xsrfKey := cfg.Get(common.XSRFKeyKey)
	portalServer := &portal.Server{
		Stage:      stage,
		Store:      businessDB,
		TimeSeries: timeSeriesDB,
		XSRF:       &common.XSRFMiddleware{Key: xsrfKey.Value(), Timeout: 1 * time.Hour},
		Sessions: &session.Manager{
			CookieName:   "pcsid",
			Store:        sessionStore,
			MaxLifetime:  sessionStore.TTL(),
			SecureCookie: (*certFileFlag != "") && (*keyFileFlag != ""),
		},
		PlanService:        planService,
		APIURL:             apiURLConfig.URL(),
		CDNURL:             cdnURLConfig.URL(),
		PuzzleEngine:       apiServer.ReportingVerifier(),
		Metrics:            metrics,
		Mailer:             mailer,
		RateLimiter:        ipRateLimiter,
		DataCtx:            dataCtx,
		IDHasher:           idHasher,
		CountryCodeHeader:  cfg.Get(common.CountryCodeHeaderKey),
		UserLimiter:        userLimiter,
		SubscriptionLimits: subscriptionLimits,
		EmailVerifier:      &portal.PortalEmailVerifier{},
		TwoFactorDuration:  10*time.Minute + 5*time.Minute,
	}

	templatesBuilder := portal.NewTemplatesBuilder()
	if err := templatesBuilder.AddFS(ctx, web.Templates(), "core"); err != nil {
		return err
	}

	if err := portalServer.Init(ctx, templatesBuilder, GitCommit, _sessionPersistInterval); err != nil {
		return err
	}

	healthCheck := &maintenance.HealthCheckJob{
		BusinessDB:    businessDB,
		TimeSeriesDB:  timeSeriesDB,
		CheckInterval: cfg.Get(common.HealthCheckIntervalKey),
		Metrics:       metrics,
	}
	jobs := maintenance.NewJobs(businessDB)

	updateConfigFunc := func(ctx context.Context) {
		cfg.Update(ctx)
		updateIPBuckets(cfg, ipRateLimiter)
		maintenanceMode := config.AsBool(cfg.Get(common.MaintenanceModeKey))
		businessDB.UpdateConfig(maintenanceMode)
		timeSeriesDB.UpdateConfig(maintenanceMode)
		portalServer.UpdateConfig(ctx, cfg)
		jobs.UpdateConfig(cfg)
		verboseLogs := config.AsBool(cfg.Get(common.VerboseKey))
		common.SetLogLevel(logLevel, verboseLogs)
	}
	updateConfigFunc(ctx)
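	// The same updater runs again on SIGHUP (see the signal handler below),
	// re-reading the environment and propagating new settings without a restart.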

	quit := make(chan struct{})
	quitFunc := func(ctx context.Context) {
		slog.DebugContext(ctx, "Server quit triggered")
		healthCheck.Shutdown(ctx)
		// Give the readiness check time to propagate
		time.Sleep(min(_readinessDrainDelay, healthCheck.Interval()))
		close(quit)
	}

	checkLicenseJob, err := maintenance.NewCheckLicenseJob(businessDB, cfg, GitCommit, quitFunc)
	if err != nil {
		return err
	}
	// nolint:errcheck
	go common.RunPeriodicJobOnce(common.TraceContext(context.Background(), "check_license"), checkLicenseJob, checkLicenseJob.NewParams())

	router := http.NewServeMux()
	apiServer.Setup(apiURLConfig.Domain(), verbose, common.NoopMiddleware).Register(router)
	portalDomain := portalURLConfig.Domain()
	portalServer.Setup(portalDomain, common.NoopMiddleware).Register(router)
	rateLimiter := ipRateLimiter.RateLimitExFunc(publicLeakyBucketCap, publicLeakInterval)
	cdnDomain := cdnURLConfig.Domain()
	cdnChain := alice.New(common.Recovered, metrics.CDNHandler, rateLimiter)
	router.Handle("GET "+cdnDomain+"/portal/", http.StripPrefix("/portal/", cdnChain.Then(web.Static(GitCommit))))
	router.Handle("GET "+cdnDomain+"/widget/", http.StripPrefix("/widget/", cdnChain.Then(widget.Static(GitCommit))))
	// "protection" (NOTE: different from the usual order of monitoring)
	publicChain := alice.New(common.Recovered, metrics.IgnoredHandler, rateLimiter)
	portalServer.SetupCatchAll(router, portalDomain, publicChain)
	// catch-all routes get a stricter limit
	catchAllRateLimiter := ipRateLimiter.RateLimitExFunc(catchAllLeakyBucketCap, catchAllLeakInterval)
	catchAllChain := alice.New(common.Recovered, metrics.IgnoredHandler, catchAllRateLimiter)
	router.Handle("/", catchAllChain.ThenFunc(common.CatchAll))

	ongoingCtx, stopOngoingGracefully := context.WithCancel(context.Background())
	httpServer := &http.Server{
		Handler:           router,
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       10 * time.Second,
		WriteTimeout:      10 * time.Second,
		IdleTimeout:       60 * time.Second,
		MaxHeaderBytes:    1024 * 1024,
		BaseContext: func(_ net.Listener) context.Context {
			return ongoingCtx
		},
		ErrorLog: slog.NewLogLogger(slog.Default().Handler(), slog.LevelError),
	}
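	// ongoingCtx is the base context for every request; it is cancelled only
	// after the graceful Shutdown below returns, so in-flight handlers are not
	// interrupted mid-request.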

	go func(ctx context.Context) {
		signals := make(chan os.Signal, 1)
		signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
		defer func() {
			signal.Stop(signals)
			close(signals)
		}()
		for {
			sig, ok := <-signals
			if !ok {
				slog.DebugContext(ctx, "Signals channel closed")
				return
			}
			slog.DebugContext(ctx, "Received signal", "signal", sig)
			switch sig {
			case syscall.SIGHUP:
				if uerr := env.Update(); uerr != nil {
					slog.ErrorContext(ctx, "Failed to update environment", common.ErrAttr(uerr))
				}
				updateConfigFunc(ctx)
			case syscall.SIGINT, syscall.SIGTERM:
				quitFunc(ctx)
				return
			}
		}
	}(common.TraceContext(context.Background(), "signal_handler"))

	go func() {
		slog.InfoContext(ctx, "Listening", "address", listener.Addr().String(), "version", GitCommit, "stage", stage)
		if err := httpServer.Serve(listener); err != nil && err != http.ErrServerClosed {
			slog.ErrorContext(ctx, "Error serving", common.ErrAttr(err))
		}
	}()

	businessDB.Start(ctx, _auditLogInterval)

	jobs.Spawn(healthCheck)
	// start maintenance jobs
	jobs.Add(&maintenance.CleanupDBCacheJob{Store: businessDB})
	jobs.Add(&maintenance.CleanupDeletedRecordsJob{Store: businessDB, Age: 365 * 24 * time.Hour})
	jobs.AddLocked(24*time.Hour, &maintenance.GarbageCollectDataJob{
		Age:        30 * 24 * time.Hour,
		BusinessDB: businessDB,
		TimeSeries: timeSeriesDB,
	})
	jobs.AddOneOff(&maintenance.WarmupPortalAuthJob{
		Store:               businessDB,
		RegistrationAllowed: config.AsBool(cfg.Get(common.RegistrationAllowedKey)),
	})
	jobs.AddOneOff(&maintenance.WarmupAPICacheJob{
		Store:      businessDB,
		TimeSeries: timeSeriesDB,
		Backoff:    200 * time.Millisecond,
		Limit:      50,
	})
	jobs.AddLocked(2*time.Hour, checkLicenseJob)
	jobs.AddOneOff(&maintenance.RegisterEmailTemplatesJob{
		Templates: email.Templates(),
		Store:     businessDB,
	})
	jobs.AddLocked(1*time.Hour, &maintenance.UserEmailNotificationsJob{
		RunInterval:  3 * time.Hour, // overlap a few locked intervals to cover possibly unprocessed notifications
		Store:        businessDB,
		Templates:    email.Templates(),
		Sender:       sender,
		ChunkSize:    50,
		MaxAttempts:  5,
		EmailFrom:    cfg.Get(common.EmailFromKey),
		ReplyToEmail: cfg.Get(common.ReplyToEmailKey),
		PlanService:  planService,
		CDNURL:       mailer.CDNURL,
		PortalURL:    mailer.PortalURL,
	})
	jobs.AddLocked(24*time.Hour, &maintenance.CleanupUserNotificationsJob{
		Store:              businessDB,
		NotificationMonths: 6,
		TemplateMonths:     7,
	})
	jobs.AddLocked(3*time.Hour, &maintenance.ExpireInternalTrialsJob{
		PastInterval: 3 * time.Hour,
		Age:          24 * time.Hour,
		BusinessDB:   businessDB,
		PlanService:  planService,
	})
	jobs.AddLocked(24*time.Hour, &maintenance.CleanupAuditLogJob{
		PastInterval: portal.MaxAuditLogsRetention(cfg),
		BusinessDB:   businessDB,
	})
	jobs.AddLocked(24*time.Hour, &maintenance.CleanupAsyncTasksJob{
		PastInterval: 30 * 24 * time.Hour,
		BusinessDB:   businessDB,
	})
	jobs.AddLocked(10*time.Minute, asyncTasksJob)

	jobs.RunAll()
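	// AddOneOff jobs run once at startup (cache warmup, template registration);
	// AddLocked jobs presumably acquire a shared lock so that only one instance
	// runs them when several replicas share the databases.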

	var localServer *http.Server
	if localAddress := cfg.Get(common.LocalAddressKey).Value(); len(localAddress) > 0 {
		localRouter := http.NewServeMux()
		metrics.Setup(localRouter)
		jobs.Setup(localRouter, cfg)
		localRouter.Handle(http.MethodGet+" /"+common.LiveEndpoint, common.Recovered(http.HandlerFunc(healthCheck.LiveHandler)))
		localRouter.Handle(http.MethodGet+" /"+common.ReadyEndpoint, common.Recovered(http.HandlerFunc(healthCheck.ReadyHandler)))
		localServer = &http.Server{
			Addr:              localAddress,
			Handler:           localRouter,
			ReadHeaderTimeout: 5 * time.Second,
		}
		go func() {
			slog.InfoContext(ctx, "Serving local API", "address", localServer.Addr)
			if err := localServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
				slog.ErrorContext(ctx, "Error serving local API", common.ErrAttr(err))
			}
		}()
	} else {
		slog.DebugContext(ctx, "Skipping serving local API")
	}
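	// The local API (metrics, job controls, liveness/readiness probes) binds a
	// separate, typically non-public address, keeping operational endpoints off
	// the main listener.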

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		<-quit
		slog.DebugContext(ctx, "Shutting down gracefully")
		jobs.Shutdown()
		sessionStore.Shutdown()
		apiServer.Shutdown()
		businessDB.Shutdown()
		shutdownCtx, cancel := context.WithTimeout(context.Background(), _shutdownPeriod)
		defer cancel()
		httpServer.SetKeepAlivesEnabled(false)
		serr := httpServer.Shutdown(shutdownCtx)
		stopOngoingGracefully()
		if serr != nil {
			slog.ErrorContext(ctx, "Failed to shutdown gracefully", common.ErrAttr(serr))
			fmt.Fprintf(stderr, "error shutting down http server gracefully: %s\n", serr)
			time.Sleep(_shutdownHardPeriod)
		}
		if localServer != nil {
			if lerr := localServer.Close(); lerr != nil {
				slog.ErrorContext(ctx, "Failed to shutdown local server", common.ErrAttr(lerr))
			}
		}
		slog.DebugContext(ctx, "Shutdown finished")
	}()
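	// Shutdown order: background work stops first, then the HTTP server drains
	// for up to _shutdownPeriod; if draining fails, _shutdownHardPeriod gives
	// stragglers one last grace window before the process exits.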

	wg.Wait()
	return nil
}
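
// migrate applies (up) or rolls back the Postgres and ClickHouse migrations,
// refusing to run unless -migrate-hash matches the built commit or is the
// literal "ignore".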
func migrate(ctx context.Context, cfg common.ConfigStore, up bool, autoClose bool) error {
	if len(*migrateHashFlag) == 0 {
		return errors.New("empty migrate hash")
	}

	if *migrateHashFlag != "ignore" && *migrateHashFlag != GitCommit {
		return fmt.Errorf("target version (%v) does not match built version (%v)", *migrateHashFlag, GitCommit)
	}

	stage := cfg.Get(common.StageKey).Value()
	verbose := config.AsBool(cfg.Get(common.VerboseKey))

	common.SetupLogs(stage, verbose)
	slog.InfoContext(ctx, "Migrating", "up", up, "version", GitCommit, "stage", stage)

	planService := billing.NewPlanService(nil)

	pool, clickhouse, dberr := db.Connect(ctx, cfg, _dbConnectTimeout, true /*admin*/)
	if dberr != nil {
		return dberr
	}

	if pool != nil {
		if autoClose {
			defer pool.Close()
		}

		if err := db.MigratePostgres(ctx, pool, cfg, planService, up); err != nil {
			return err
		}
	}

	if clickhouse != nil {
		if autoClose {
			defer clickhouse.Close()
		}

		if err := db.MigrateClickHouse(ctx, clickhouse, cfg, up); err != nil {
			return err
		}
	}

	return nil
}
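
// serve creates the (optionally TLS) listener and runs the server until shutdown.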
func serve(cfg common.ConfigStore) (err error) {
	ctx := common.TraceContext(context.Background(), "main")
	if listener, lerr := createListener(ctx, cfg); lerr == nil {
		err = run(ctx, cfg, os.Stderr, listener)
	} else {
		err = lerr
	}
	return
}

func main() {
	flag.Parse()

	if *versionFlag {
		fmt.Print(GitCommit)
		return
	}

	var err error
	env, err = common.NewEnvMap(*envFileFlag)
	if err != nil {
		fmt.Fprintf(os.Stderr, "%s\n", err)
	}

	cfg := config.NewEnvConfig(env.Get)

	switch *flagMode {
	case modeServer:
		err = serve(cfg)
	case modeMigrate:
		mctx := common.TraceContext(context.Background(), "migration")
		err = migrate(mctx, cfg, true /*up*/, true /*auto close*/)
	case modeRollback:
		rctx := common.TraceContext(context.Background(), "migration")
		err = migrate(rctx, cfg, false /*up*/, true /*auto close*/)
	case modeAuto:
		mctx := common.TraceContext(context.Background(), "migration")
		if err = migrate(mctx, cfg, true /*up*/, false /*auto close*/); err == nil {
			err = serve(cfg)
		}
	default:
		err = fmt.Errorf("unknown mode: '%s'", *flagMode)
	}

	if err != nil {
		fmt.Fprintf(os.Stderr, "%s\n", err)
		os.Exit(1)
	}
}