package main

import (
	"context"
	"crypto/tls"
	"errors"
	"flag"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/http"
	"os"
	"os/signal"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/api"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/billing"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/common"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/config"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/db"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/difficulty"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/email"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/maintenance"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/monitoring"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/portal"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/session"
	"github.com/PrivateCaptcha/PrivateCaptcha/pkg/session/store/memory"
	"github.com/PrivateCaptcha/PrivateCaptcha/web"
	"github.com/PrivateCaptcha/PrivateCaptcha/widget"
	"github.com/justinas/alice"
)

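// Shutdown timing: on SIGINT/SIGTERM the readiness probe is failed first,
// traffic is allowed to drain for _readinessDrainDelay, the HTTP server then
// gets _shutdownPeriod to finish in-flight requests, and a final
// _shutdownHardPeriod grace window applies if graceful shutdown fails.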
const (
	modeMigrate  = "migrate"
	modeRollback = "rollback"
	modeServer   = "server"

	_readinessDrainDelay = 1 * time.Second
	_shutdownHardPeriod  = 3 * time.Second
	_shutdownPeriod      = 10 * time.Second
	_dbConnectTimeout    = 30 * time.Second
)

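// GitCommit identifies the running build. It is presumably injected at link
// time; a typical (assumed, not repo-verified) invocation would be:
//
//	go build -ldflags "-X main.GitCommit=$(git rev-parse HEAD)" .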
var (
	GitCommit string

	flagMode        = flag.String("mode", "", strings.Join([]string{modeMigrate, modeRollback, modeServer}, " | "))
	envFileFlag     = flag.String("env", "", "Path to .env file, 'stdin' or empty")
	versionFlag     = flag.Bool("version", false, "Print version and exit")
	migrateHashFlag = flag.String("migrate-hash", "", "Target migration version (git commit)")
	certFileFlag    = flag.String("certfile", "", "certificate PEM file (e.g. cert.pem)")
	keyFileFlag     = flag.String("keyfile", "", "key PEM file (e.g. key.pem)")

	env *common.EnvMap
)

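// listenAddress builds the host:port address the public server binds to,
// defaulting to localhost:8080 when host or port is not configured.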
func listenAddress(cfg common.ConfigStore) string {
	host := cfg.Get(common.HostKey).Value()
	if host == "" {
		host = "localhost"
	}

	port := cfg.Get(common.PortKey).Value()
	if port == "" {
		port = "8080"
	}
	address := net.JoinHostPort(host, port)
	return address
}

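// createListener opens the TCP listener and, when both -certfile and
// -keyfile are given, wraps it with TLS.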
func createListener(ctx context.Context, cfg common.ConfigStore) (net.Listener, error) {
	address := listenAddress(cfg)
	listener, err := net.Listen("tcp", address)
	if err != nil {
		slog.ErrorContext(ctx, "Failed to listen", "address", address, common.ErrAttr(err))
		return nil, err
	}

	if useTLS := (*certFileFlag != "") && (*keyFileFlag != ""); useTLS {
		cert, err := tls.LoadX509KeyPair(*certFileFlag, *keyFileFlag)
		if err != nil {
			slog.ErrorContext(ctx, "Failed to load certificates", "cert", *certFileFlag, "key", *keyFileFlag, common.ErrAttr(err))
			return nil, err
		}
		tlsConfig := &tls.Config{
			Certificates: []tls.Certificate{cert},
		}
		listener = tls.NewListener(listener, tlsConfig)
	}

	return listener, nil
}

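// run wires the API, portal, CDN and maintenance subsystems onto a shared
// router and serves them on the provided listener until SIGINT/SIGTERM
// triggers a graceful shutdown; SIGHUP reloads environment and configuration.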
func run(ctx context.Context, cfg common.ConfigStore, stderr io.Writer, listener net.Listener) error {
	stage := cfg.Get(common.StageKey).Value()
	verbose := config.AsBool(cfg.Get(common.VerboseKey))
	common.SetupLogs(stage, verbose)

	planService := billing.NewPlanService(nil)

	pool, clickhouse, dberr := db.Connect(ctx, cfg, _dbConnectTimeout, false /*admin*/)
	if dberr != nil {
		return dberr
	}

	defer pool.Close()
	defer clickhouse.Close()

	businessDB := db.NewBusiness(pool)
	timeSeriesDB := db.NewTimeSeries(clickhouse)

	metrics := monitoring.NewService()

	cdnURLConfig := config.AsURL(ctx, cfg.Get(common.CDNBaseURLKey))
	portalURLConfig := config.AsURL(ctx, cfg.Get(common.PortalBaseURLKey))

	mailer := email.NewMailer(cfg)
	portalMailer := email.NewPortalMailer("https:"+cdnURLConfig.URL(), portalURLConfig.Domain(), mailer, cfg)

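	// The API server handles puzzle generation and verification; verify
	// results are queued on VerifyLogChan and flushed to the time-series
	// store in batches (see the flush interval passed to Init below).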
	apiServer := &api.Server{
		Stage:              stage,
		BusinessDB:         businessDB,
		TimeSeries:         timeSeriesDB,
		Auth:               api.NewAuthMiddleware(cfg, businessDB, api.NewUserLimiter(businessDB), planService),
		VerifyLogChan:      make(chan *common.VerifyRecord, 10*api.VerifyBatchSize),
		Salt:               api.NewPuzzleSalt(cfg.Get(common.APISaltKey)),
		UserFingerprintKey: api.NewUserFingerprintKey(cfg.Get(common.UserFingerprintIVKey)),
		Metrics:            metrics,
		Mailer:             portalMailer,
		Levels:             difficulty.NewLevels(timeSeriesDB, 100 /*levelsBatchSize*/, api.PropertyBucketSize),
		VerifyLogCancel:    func() {},
	}
	if err := apiServer.Init(ctx, 10*time.Second /*flush interval*/, 1*time.Second /*backfill duration*/); err != nil {
		return err
	}

	router := http.NewServeMux()

	apiURLConfig := config.AsURL(ctx, cfg.Get(common.APIBaseURLKey))
	apiDomain := apiURLConfig.Domain()
	apiServer.Setup(router, apiDomain, verbose, common.NoopMiddleware)

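	// Sessions are stored via Postgres combined with an in-memory store
	// (db.NewSessionStore); the session cookie is only marked Secure when
	// this process terminates TLS itself.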
	sessionStore := db.NewSessionStore(pool, memory.New(), 1*time.Minute, session.KeyPersistent)
	portalServer := &portal.Server{
		Stage:      stage,
		Store:      businessDB,
		TimeSeries: timeSeriesDB,
		XSRF:       &common.XSRFMiddleware{Key: "pckey", Timeout: 1 * time.Hour},
		Sessions: &session.Manager{
			CookieName:   "pcsid",
			Store:        sessionStore,
			MaxLifetime:  sessionStore.MaxLifetime(),
			SecureCookie: (*certFileFlag != "") && (*keyFileFlag != ""),
		},
		PlanService:  planService,
		APIURL:       apiURLConfig.URL(),
		CDNURL:       cdnURLConfig.URL(),
		PuzzleEngine: apiServer,
		Metrics:      metrics,
		Mailer:       portalMailer,
		Auth:         portal.NewAuthMiddleware(portal.NewRateLimiter(cfg)),
	}

	templatesBuilder := portal.NewTemplatesBuilder()
	if err := templatesBuilder.AddFS(ctx, web.Templates(), "core"); err != nil {
		return err
	}

	if err := portalServer.Init(ctx, templatesBuilder); err != nil {
		return err
	}

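	// The health check drives the /ready endpoint on the local server below
	// and is failed first during shutdown so load balancers stop routing
	// new traffic.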
	healthCheck := &maintenance.HealthCheckJob{
		BusinessDB:    businessDB,
		TimeSeriesDB:  timeSeriesDB,
		CheckInterval: cfg.Get(common.HealthCheckIntervalKey),
		Metrics:       metrics,
	}

	portalDomain := portalURLConfig.Domain()
	_ = portalServer.Setup(router, portalDomain, common.NoopMiddleware)
	rateLimiter := portalServer.Auth.RateLimit()
	cdnDomain := cdnURLConfig.Domain()
	cdnChain := alice.New(common.Recovered, metrics.CDNHandler, rateLimiter)
	router.Handle("GET "+cdnDomain+"/portal/", http.StripPrefix("/portal/", cdnChain.Then(web.Static())))
	router.Handle("GET "+cdnDomain+"/widget/", http.StripPrefix("/widget/", cdnChain.Then(widget.Static())))
	// "protection" (NOTE: different from the usual order of monitoring)
	publicChain := alice.New(common.Recovered, metrics.IgnoredHandler, rateLimiter)
	portalServer.SetupCatchAll(router, portalDomain, publicChain)
	router.Handle("/", publicChain.ThenFunc(common.CatchAll))

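	// Request contexts derive from ongoingCtx rather than from the server
	// itself, so in-flight requests are only cancelled after the graceful
	// Shutdown call below has returned (or timed out).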
	ongoingCtx, stopOngoingGracefully := context.WithCancel(context.Background())
	httpServer := &http.Server{
		Handler:           router,
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       10 * time.Second,
		WriteTimeout:      10 * time.Second,
		IdleTimeout:       60 * time.Second,
		MaxHeaderBytes:    1024 * 1024,
		BaseContext: func(_ net.Listener) context.Context {
			return ongoingCtx
		},
	}

	updateConfigFunc := func(ctx context.Context) {
		cfg.Update(ctx)
		maintenanceMode := config.AsBool(cfg.Get(common.MaintenanceModeKey))
		businessDB.UpdateConfig(maintenanceMode)
		timeSeriesDB.UpdateConfig(maintenanceMode)
		portalServer.UpdateConfig(ctx, cfg)
		apiServer.UpdateConfig(ctx, cfg)
	}
	updateConfigFunc(ctx)

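	// Signal handling: SIGHUP re-reads the env file and re-applies the
	// configuration; SIGINT/SIGTERM fail the readiness probe, wait for load
	// balancers to drain, then close quit to begin graceful shutdown.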
	quit := make(chan struct{})
	go func(ctx context.Context) {
		signals := make(chan os.Signal, 1)
		signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
		defer func() {
			signal.Stop(signals)
			close(signals)
		}()
		for {
			sig, ok := <-signals
			if !ok {
				slog.DebugContext(ctx, "Signals channel closed")
				return
			}
			slog.DebugContext(ctx, "Received signal", "signal", sig)
			switch sig {
			case syscall.SIGHUP:
				if uerr := env.Update(); uerr != nil {
					slog.ErrorContext(ctx, "Failed to update environment", common.ErrAttr(uerr))
				}
				updateConfigFunc(ctx)
			case syscall.SIGINT, syscall.SIGTERM:
				healthCheck.Shutdown(ctx)
				// Give time for readiness check to propagate
				time.Sleep(min(_readinessDrainDelay, healthCheck.Interval()))
				close(quit)
				return
			}
		}
	}(common.TraceContext(context.Background(), "signal_handler"))

	go func() {
		slog.InfoContext(ctx, "Listening", "address", listener.Addr().String(), "version", GitCommit, "stage", stage)
		if err := httpServer.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
			slog.ErrorContext(ctx, "Error serving", common.ErrAttr(err))
		}
	}()

	// start maintenance jobs
	jobs := maintenance.NewJobs(businessDB)
	jobs.Add(healthCheck)
	jobs.Add(&maintenance.SessionsCleanupJob{
		Session: portalServer.Sessions,
	})
	jobs.Add(&maintenance.CleanupDBCacheJob{Store: businessDB})
	jobs.Add(&maintenance.CleanupDeletedRecordsJob{Store: businessDB, Age: 365 * 24 * time.Hour})
	jobs.AddLocked(24*time.Hour, &maintenance.GarbageCollectDataJob{
		Age:        30 * 24 * time.Hour,
		BusinessDB: businessDB,
		TimeSeries: timeSeriesDB,
	})
	jobs.AddOneOff(&maintenance.WarmupPortalAuth{
		Store: businessDB,
	})
	jobs.Run()

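	// An optional local (operational) server exposes metrics, maintenance
	// job handlers, and the liveness/readiness probes on a separate address.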
	var localServer *http.Server
	if localAddress := cfg.Get(common.LocalAddressKey).Value(); len(localAddress) > 0 {
		localRouter := http.NewServeMux()
		metrics.Setup(localRouter)
		jobs.Setup(localRouter)
		localRouter.Handle(http.MethodGet+" /"+common.LiveEndpoint, common.Recovered(http.HandlerFunc(healthCheck.LiveHandler)))
		localRouter.Handle(http.MethodGet+" /"+common.ReadyEndpoint, common.Recovered(http.HandlerFunc(healthCheck.ReadyHandler)))
		localServer = &http.Server{
			Addr:    localAddress,
			Handler: localRouter,
		}
		go func() {
			slog.InfoContext(ctx, "Serving local API", "address", localServer.Addr)
			if err := localServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
				slog.ErrorContext(ctx, "Error serving local API", common.ErrAttr(err))
			}
		}()
	} else {
		slog.DebugContext(ctx, "Skipping serving local API")
	}

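	// Graceful shutdown: stop background work first, then drain the HTTP
	// server within _shutdownPeriod; if that fails, wait _shutdownHardPeriod
	// before letting the process exit.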
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		<-quit
		slog.DebugContext(ctx, "Shutting down gracefully")
		jobs.Shutdown()
		sessionStore.Shutdown()
		apiServer.Shutdown()
		portalServer.Shutdown()
		shutdownCtx, cancel := context.WithTimeout(context.Background(), _shutdownPeriod)
		defer cancel()
		httpServer.SetKeepAlivesEnabled(false)
		serr := httpServer.Shutdown(shutdownCtx)
		stopOngoingGracefully()
		if serr != nil {
			slog.ErrorContext(ctx, "Failed to shutdown gracefully", common.ErrAttr(serr))
			fmt.Fprintf(stderr, "error shutting down http server gracefully: %s\n", serr)
			time.Sleep(_shutdownHardPeriod)
		}
		if localServer != nil {
			localServer.Close()
		}
		slog.DebugContext(ctx, "Shutdown finished")
	}()

	wg.Wait()
	return nil
}

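// migrate applies (up=true) or rolls back (up=false) the Postgres and
// ClickHouse schema migrations. It refuses to run unless -migrate-hash
// matches the built commit or is explicitly set to "ignore".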
func migrate(ctx context.Context, cfg common.ConfigStore, up bool) error {
	if len(*migrateHashFlag) == 0 {
		return errors.New("empty migrate hash")
	}

	if *migrateHashFlag != "ignore" && *migrateHashFlag != GitCommit {
		return fmt.Errorf("target version (%v) does not match built version (%v)", *migrateHashFlag, GitCommit)
	}

	stage := cfg.Get(common.StageKey).Value()
	verbose := config.AsBool(cfg.Get(common.VerboseKey))

	common.SetupLogs(stage, verbose)
	slog.InfoContext(ctx, "Migrating", "up", up, "version", GitCommit, "stage", stage)

	planService := billing.NewPlanService(nil)

	pool, clickhouse, dberr := db.Connect(ctx, cfg, _dbConnectTimeout, true /*admin*/)
	if dberr != nil {
		return dberr
	}

	defer pool.Close()
	defer clickhouse.Close()

	if err := db.MigratePostgres(ctx, pool, cfg, planService, up); err != nil {
		return err
	}

	if err := db.MigrateClickHouse(ctx, clickhouse, cfg, up); err != nil {
		return err
	}

	return nil
}

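// main dispatches on -mode. Assumed example invocations (the binary name and
// values below are illustrative, not taken from the repository):
//
//	privatecaptcha -mode server -env .env
//	privatecaptcha -mode migrate -migrate-hash <commit> -env .env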
func main() {
	flag.Parse()

	if *versionFlag {
		fmt.Print(GitCommit)
		return
	}

	var err error
	env, err = common.NewEnvMap(*envFileFlag)
	if err != nil {
		fmt.Fprintf(os.Stderr, "%s\n", err)
	}

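	// Note: a missing or unreadable env file is reported to stderr above but
	// is not fatal; the process can still run on ambient environment variables.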
	cfg := config.NewEnvConfig(config.DefaultMapper, env.Get)

	if err = checkLicense(context.Background(), cfg); err != nil {
		fmt.Fprintf(os.Stderr, "%s\n", err)
		os.Exit(1)
	}

	switch *flagMode {
	case modeServer:
		ctx := common.TraceContext(context.Background(), "main")
		if listener, lerr := createListener(ctx, cfg); lerr == nil {
			err = run(ctx, cfg, os.Stderr, listener)
		} else {
			err = lerr
		}
	case modeMigrate:
		ctx := common.TraceContext(context.Background(), "migration")
		err = migrate(ctx, cfg, true /*up*/)
	case modeRollback:
		ctx := common.TraceContext(context.Background(), "migration")
		err = migrate(ctx, cfg, false /*up*/)
	default:
		err = fmt.Errorf("unknown mode: '%s'", *flagMode)
	}

	if err != nil {
		fmt.Fprintf(os.Stderr, "%s\n", err)
		os.Exit(1)
	}
}