hatchet/pkg/config/loader/loader.go
matt 0a947924fa Feat: Parallelize replication from PG -> External (#2637)
* feat: chunking query

* feat: first pass at range chunking

* fix: bug bashing

* fix: function geq

* fix: use maps.Copy

* fix: olap func

* feat: olap side

* refactor: external id

* fix: order by

* feat: wire up env vars

* fix: pass var through

* fix: naming

* fix: append to returnErr properly

* fix: use eg.Go
2025-12-10 17:11:03 -05:00


// Adapted from: https://github.com/hatchet-dev/hatchet-v1-archived/blob/3c2c13168afa1af68d4baaf5ed02c9d49c5f0323/internal/config/loader/loader.go
package loader

import (
"context"
"fmt"
"log"
"net"
"os"
"path/filepath"
"strings"
"time"
"github.com/exaring/otelpgx"
pgxzero "github.com/jackc/pgx-zerolog"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/tracelog"
"github.com/rs/zerolog"
"golang.org/x/oauth2"
"github.com/hatchet-dev/hatchet/internal/integrations/alerting"
"github.com/hatchet-dev/hatchet/internal/msgqueue"
"github.com/hatchet-dev/hatchet/internal/msgqueue/postgres"
"github.com/hatchet-dev/hatchet/internal/msgqueue/rabbitmq"
"github.com/hatchet-dev/hatchet/internal/services/ingestor"
"github.com/hatchet-dev/hatchet/pkg/analytics"
"github.com/hatchet-dev/hatchet/pkg/analytics/posthog"
"github.com/hatchet-dev/hatchet/pkg/auth/cookie"
"github.com/hatchet-dev/hatchet/pkg/auth/oauth"
"github.com/hatchet-dev/hatchet/pkg/auth/token"
"github.com/hatchet-dev/hatchet/pkg/config/client"
"github.com/hatchet-dev/hatchet/pkg/config/database"
"github.com/hatchet-dev/hatchet/pkg/config/loader/loaderutils"
"github.com/hatchet-dev/hatchet/pkg/config/server"
"github.com/hatchet-dev/hatchet/pkg/config/shared"
"github.com/hatchet-dev/hatchet/pkg/encryption"
"github.com/hatchet-dev/hatchet/pkg/errors"
"github.com/hatchet-dev/hatchet/pkg/errors/sentry"
"github.com/hatchet-dev/hatchet/pkg/integrations/email"
"github.com/hatchet-dev/hatchet/pkg/integrations/email/postmark"
"github.com/hatchet-dev/hatchet/pkg/logger"
"github.com/hatchet-dev/hatchet/pkg/repository"
"github.com/hatchet-dev/hatchet/pkg/repository/cache"
"github.com/hatchet-dev/hatchet/pkg/repository/debugger"
"github.com/hatchet-dev/hatchet/pkg/repository/metered"
postgresdb "github.com/hatchet-dev/hatchet/pkg/repository/postgres"
v0 "github.com/hatchet-dev/hatchet/pkg/scheduling/v0"
v1 "github.com/hatchet-dev/hatchet/pkg/scheduling/v1"
"github.com/hatchet-dev/hatchet/pkg/security"
"github.com/hatchet-dev/hatchet/pkg/validator"
msgqueuev1 "github.com/hatchet-dev/hatchet/internal/msgqueue/v1"
pgmqv1 "github.com/hatchet-dev/hatchet/internal/msgqueue/v1/postgres"
rabbitmqv1 "github.com/hatchet-dev/hatchet/internal/msgqueue/v1/rabbitmq"
clientv1 "github.com/hatchet-dev/hatchet/pkg/client/v1"
repov1 "github.com/hatchet-dev/hatchet/pkg/repository/v1"
)
// LoadDatabaseConfigFile loads the database config file via viper
func LoadDatabaseConfigFile(files ...[]byte) (*database.ConfigFile, error) {
configFile := &database.ConfigFile{}
f := database.BindAllEnv
_, err := loaderutils.LoadConfigFromViper(f, configFile, files...)
return configFile, err
}
// LoadServerConfigFile loads the server config file via viper
func LoadServerConfigFile(files ...[]byte) (*server.ServerConfigFile, error) {
configFile := &server.ServerConfigFile{}
f := server.BindAllEnv
_, err := loaderutils.LoadConfigFromViper(f, configFile, files...)
return configFile, err
}
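// RepositoryOverrides allows callers to substitute alternate implementations
// for the log repositories (for example, an external log store) in place of
// the default Postgres-backed ones.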
type RepositoryOverrides struct {
LogsEngineRepository repository.LogsEngineRepository
LogsAPIRepository repository.LogsAPIRepository
}
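// ConfigLoader loads database and server configuration from a directory of
// YAML files (database.yaml, server.yaml), with environment variable bindings
// applied via viper.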
type ConfigLoader struct {
directory string
RepositoryOverrides RepositoryOverrides
}
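// NewConfigLoader returns a ConfigLoader rooted at the given config directory.
//
// Usage sketch (illustrative path):
//
//	layer, err := NewConfigLoader("/etc/hatchet").InitDataLayer()
//	if err != nil {
//		return err
//	}
//	defer layer.Disconnect() //nolint:errcheck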
func NewConfigLoader(directory string) *ConfigLoader {
return &ConfigLoader{directory: directory}
}
// InitDataLayer initializes the database layer from the configuration
func (c *ConfigLoader) InitDataLayer() (res *database.Layer, err error) {
sharedFilePath := filepath.Join(c.directory, "database.yaml")
configFileBytes, err := loaderutils.GetConfigBytes(sharedFilePath)
if err != nil {
return nil, err
}
cf, err := LoadDatabaseConfigFile(configFileBytes...)
if err != nil {
return nil, err
}
serverSharedFilePath := filepath.Join(c.directory, "server.yaml")
serverConfigFileBytes, err := loaderutils.GetConfigBytes(serverSharedFilePath)
if err != nil {
return nil, err
}
scf, err := LoadServerConfigFile(serverConfigFileBytes...)
if err != nil {
return nil, err
}
l := logger.NewStdErr(&cf.Logger, "database")
databaseUrl := os.Getenv("DATABASE_URL")
if databaseUrl == "" {
databaseUrl = fmt.Sprintf(
"postgresql://%s:%s@%s:%d/%s?sslmode=%s",
cf.PostgresUsername,
cf.PostgresPassword,
cf.PostgresHost,
cf.PostgresPort,
cf.PostgresDbName,
cf.PostgresSSLMode,
)
_ = os.Setenv("DATABASE_URL", databaseUrl)
}
pgxpoolConnAfterConnect := func(ctx context.Context, conn *pgx.Conn) error {
// Set timezone to UTC for all connections
if _, err := conn.Exec(ctx, "SET TIME ZONE 'UTC'"); err != nil {
return err
}
// ref: https://github.com/jackc/pgx/issues/1549
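// pgx cannot decode values of custom Postgres types it has not loaded, so load
// the enum type and its array form (the "_"-prefixed name) once per connection
// and register them with the connection's type map.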
t, err := conn.LoadType(ctx, "v1_readable_status_olap")
if err != nil {
return err
}
conn.TypeMap().RegisterType(t)
t, err = conn.LoadType(ctx, "_v1_readable_status_olap")
if err != nil {
return err
}
conn.TypeMap().RegisterType(t)
return nil
}
config, err := pgxpool.ParseConfig(databaseUrl)
if err != nil {
return nil, err
}
config.AfterConnect = pgxpoolConnAfterConnect
// default to the OpenTelemetry tracer; when query logging is enabled, use the
// tracelog tracer instead, since assigning to Tracer twice would overwrite the first
config.ConnConfig.Tracer = otelpgx.NewTracer()
if cf.LogQueries {
config.ConnConfig.Tracer = &tracelog.TraceLog{
Logger: pgxzero.NewLogger(l),
LogLevel: tracelog.LogLevelDebug,
}
}
if cf.MaxConns != 0 {
config.MaxConns = int32(cf.MaxConns) // nolint: gosec
}
if cf.MinConns != 0 {
config.MinConns = int32(cf.MinConns) // nolint: gosec
}
config.MaxConnLifetime = 15 * time.Minute
// Check database instance timezone if enforcement is enabled
if cf.EnforceUTCTimezone {
if err := checkDatabaseTimezone(config.ConnConfig, cf.PostgresDbName, "primary database", &l); err != nil {
return nil, err
}
}
var debug *debugger.Debugger
if cf.Logger.Level == "debug" {
debug = debugger.NewDebugger(&l)
config.BeforeAcquire = debug.BeforeAcquire // nolint: staticcheck
config.AfterRelease = debug.AfterRelease
}
pool, err := pgxpool.NewWithConfig(context.Background(), config)
if err != nil {
return nil, fmt.Errorf("could not connect to database: %w", err)
}
if debug != nil {
// pool needs the debugger hooks (BeforeAcquire/AfterRelease) but debugger needs the pool
// to track active connections, so we add the pool later
debug.Setup(pool)
}
// a pool for read replicas, if enabled
var readReplicaPool *pgxpool.Pool
if cf.ReadReplicaEnabled {
if cf.ReadReplicaDatabaseURL == "" {
return nil, fmt.Errorf("read replica database url is required if read replica is enabled")
}
readReplicaConfig, err := pgxpool.ParseConfig(cf.ReadReplicaDatabaseURL)
if err != nil {
return nil, fmt.Errorf("could not parse read replica database url: %w", err)
}
if cf.ReadReplicaMaxConns != 0 {
readReplicaConfig.MaxConns = int32(cf.ReadReplicaMaxConns) // nolint: gosec
}
if cf.ReadReplicaMinConns != 0 {
readReplicaConfig.MinConns = int32(cf.ReadReplicaMinConns) // nolint: gosec
}
readReplicaConfig.MaxConnLifetime = 15 * time.Minute
readReplicaConfig.ConnConfig.Tracer = otelpgx.NewTracer()
// Check read replica database instance timezone if enforcement is enabled
if cf.EnforceUTCTimezone {
if err := checkDatabaseTimezone(readReplicaConfig.ConnConfig, "", "read replica database", &l); err != nil {
return nil, err
}
}
readReplicaConfig.AfterConnect = pgxpoolConnAfterConnect
readReplicaPool, err = pgxpool.NewWithConfig(context.Background(), readReplicaConfig)
if err != nil {
return nil, fmt.Errorf("could not connect to read replica database: %w", err)
}
}
ch := cache.New(cf.CacheDuration)
entitlementRepo := postgresdb.NewEntitlementRepository(pool, &scf.Runtime, postgresdb.WithLogger(&l), postgresdb.WithCache(ch))
meter := metered.NewMetered(entitlementRepo, &l)
var opts []postgresdb.PostgresRepositoryOpt
opts = append(opts, postgresdb.WithLogger(&l), postgresdb.WithCache(ch), postgresdb.WithMetered(meter))
if c.RepositoryOverrides.LogsEngineRepository != nil {
opts = append(opts, postgresdb.WithLogsEngineRepository(c.RepositoryOverrides.LogsEngineRepository))
}
cleanupEngine, engineRepo, err := postgresdb.NewEngineRepository(pool, &scf.Runtime, opts...)
if err != nil {
return nil, fmt.Errorf("could not create engine repository: %w", err)
}
if c.RepositoryOverrides.LogsAPIRepository != nil {
opts = append(opts, postgresdb.WithLogsAPIRepository(c.RepositoryOverrides.LogsAPIRepository))
}
retentionPeriod, err := time.ParseDuration(scf.Runtime.Limits.DefaultTenantRetentionPeriod)
if err != nil {
return nil, fmt.Errorf("could not parse retention period %s: %w", scf.Runtime.Limits.DefaultTenantRetentionPeriod, err)
}
taskLimits := repov1.TaskOperationLimits{
TimeoutLimit: scf.Runtime.TaskOperationLimits.TimeoutLimit,
ReassignLimit: scf.Runtime.TaskOperationLimits.ReassignLimit,
RetryQueueLimit: scf.Runtime.TaskOperationLimits.RetryQueueLimit,
DurableSleepLimit: scf.Runtime.TaskOperationLimits.DurableSleepLimit,
}
inlineStoreTTLDays := scf.PayloadStore.InlineStoreTTLDays
if inlineStoreTTLDays <= 0 {
return nil, fmt.Errorf("inline store TTL days must be greater than 0")
}
inlineStoreTTL := time.Duration(inlineStoreTTLDays) * 24 * time.Hour
payloadStoreOpts := repov1.PayloadStoreRepositoryOpts{
EnablePayloadDualWrites: scf.PayloadStore.EnablePayloadDualWrites,
EnableTaskEventPayloadDualWrites: scf.PayloadStore.EnableTaskEventPayloadDualWrites,
EnableOLAPPayloadDualWrites: scf.PayloadStore.EnableOLAPPayloadDualWrites,
EnableDagDataPayloadDualWrites: scf.PayloadStore.EnableDagDataPayloadDualWrites,
ExternalCutoverProcessInterval: scf.PayloadStore.ExternalCutoverProcessInterval,
ExternalCutoverBatchSize: scf.PayloadStore.ExternalCutoverBatchSize,
ExternalCutoverNumConcurrentOffloads: scf.PayloadStore.ExternalCutoverNumConcurrentOffloads,
InlineStoreTTL: &inlineStoreTTL,
EnableImmediateOffloads: scf.PayloadStore.EnableImmediateOffloads,
}
statusUpdateOpts := repov1.StatusUpdateBatchSizeLimits{
Task: int32(scf.OLAPStatusUpdates.TaskBatchSizeLimit),
DAG: int32(scf.OLAPStatusUpdates.DagBatchSizeLimit),
}
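// note: this local v1 (the v1 repository layer) shadows the scheduling/v1
// package import for the rest of this function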
v1, cleanupV1 := repov1.NewRepository(pool, &l, retentionPeriod, retentionPeriod, scf.Runtime.MaxInternalRetryCount, entitlementRepo, taskLimits, payloadStoreOpts, statusUpdateOpts)
apiRepo, cleanupApiRepo, err := postgresdb.NewAPIRepository(pool, &scf.Runtime, opts...)
if err != nil {
return nil, fmt.Errorf("could not create api repository: %w", err)
}
if readReplicaPool != nil {
v1.OLAP().SetReadReplicaPool(readReplicaPool)
}
return &database.Layer{
Disconnect: func() error {
if err := cleanupEngine(); err != nil {
return err
}
ch.Stop()
meter.Stop()
if err := cleanupV1(); err != nil {
return err
}
return cleanupApiRepo()
},
Pool: pool,
QueuePool: pool,
APIRepository: apiRepo,
EngineRepository: engineRepo,
EntitlementRepository: entitlementRepo,
V1: v1,
Seed: cf.Seed,
}, nil
}
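// ServerConfigFileOverride mutates the loaded server config file before the
// server config is constructed from it.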
type ServerConfigFileOverride func(*server.ServerConfigFile)
// CreateServerFromConfig loads the server configuration and returns a server config along with a cleanup function
func (c *ConfigLoader) CreateServerFromConfig(version string, overrides ...ServerConfigFileOverride) (cleanup func() error, res *server.ServerConfig, err error) {
sharedFilePath := filepath.Join(c.directory, "server.yaml")
configFileBytes, err := loaderutils.GetConfigBytes(sharedFilePath)
if err != nil {
return nil, nil, err
}
dc, err := c.InitDataLayer()
if err != nil {
return nil, nil, err
}
cf, err := LoadServerConfigFile(configFileBytes...)
if err != nil {
return nil, nil, err
}
for _, override := range overrides {
override(cf)
}
return createControllerLayer(dc, cf, version)
}
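// Usage sketch for CreateServerFromConfig (illustrative version string and override):
//
//	cleanup, cfg, err := NewConfigLoader("/etc/hatchet").CreateServerFromConfig(
//		"v0.0.0",
//		func(scf *server.ServerConfigFile) { scf.Logger.Level = "debug" },
//	)
//	if err != nil {
//		return err
//	}
//	defer cleanup() //nolint:errcheck
//	_ = cfg
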
func createControllerLayer(dc *database.Layer, cf *server.ServerConfigFile, version string) (cleanup func() error, res *server.ServerConfig, err error) {
l := logger.NewStdErr(&cf.Logger, "server")
queueLogger := logger.NewStdErr(&cf.AdditionalLoggers.Queue, "queue")
tls, err := loaderutils.LoadServerTLSConfig(&cf.TLS)
if err != nil {
return nil, nil, fmt.Errorf("could not load TLS config: %w", err)
}
ss, err := cookie.NewUserSessionStore(
cookie.WithSessionRepository(dc.APIRepository.UserSession()),
cookie.WithCookieAllowInsecure(cf.Auth.Cookie.Insecure),
cookie.WithCookieDomain(cf.Auth.Cookie.Domain),
cookie.WithCookieName(cf.Auth.Cookie.Name),
cookie.WithCookieSecrets(getStrArr(cf.Auth.Cookie.Secrets)...),
)
if err != nil {
return nil, nil, fmt.Errorf("could not create session store: %w", err)
}
var mq msgqueue.MessageQueue
var mqv1 msgqueuev1.MessageQueue
cleanup1 := func() error {
return nil
}
var ing ingestor.Ingestor
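// construct both the legacy (v0) and current (v1) message queue clients so
// engine components on either code path can publish and consume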
if cf.MessageQueue.Enabled {
switch strings.ToLower(cf.MessageQueue.Kind) {
case "postgres":
var cleanupv0 func() error
var cleanupv1 func() error
cleanupv0, mq = postgres.NewPostgresMQ(
dc.EngineRepository.MessageQueue(),
postgres.WithLogger(&l),
postgres.WithQos(cf.MessageQueue.Postgres.Qos),
)
cleanupv1, mqv1, err = pgmqv1.NewPostgresMQ(
dc.EngineRepository.MessageQueue(),
pgmqv1.WithLogger(&l),
pgmqv1.WithQos(cf.MessageQueue.Postgres.Qos),
)
if err != nil {
return nil, nil, fmt.Errorf("could not init postgres queue: %w", err)
}
cleanup1 = func() error {
if err := cleanupv0(); err != nil {
return err
}
return cleanupv1()
}
case "rabbitmq":
if cf.MessageQueue.RabbitMQ.URL == "" {
return nil, nil, fmt.Errorf("using RabbitMQ as message queue requires a URL to be set")
}
var cleanupv0 func() error
var cleanupv1 func() error
cleanupv0, mq = rabbitmq.New(
rabbitmq.WithURL(cf.MessageQueue.RabbitMQ.URL),
rabbitmq.WithLogger(&l),
rabbitmq.WithQos(cf.MessageQueue.RabbitMQ.Qos),
rabbitmq.WithDisableTenantExchangePubs(cf.Runtime.DisableTenantPubs),
rabbitmq.WithMessageRejection(cf.MessageQueue.RabbitMQ.EnableMessageRejection, cf.MessageQueue.RabbitMQ.MaxDeathCount),
)
cleanupv1, mqv1, err = rabbitmqv1.New(
rabbitmqv1.WithURL(cf.MessageQueue.RabbitMQ.URL),
rabbitmqv1.WithLogger(&l),
rabbitmqv1.WithQos(cf.MessageQueue.RabbitMQ.Qos),
rabbitmqv1.WithDisableTenantExchangePubs(cf.Runtime.DisableTenantPubs),
rabbitmqv1.WithMaxPubChannels(cf.MessageQueue.RabbitMQ.MaxPubChans),
rabbitmqv1.WithMaxSubChannels(cf.MessageQueue.RabbitMQ.MaxSubChans),
rabbitmqv1.WithGzipCompression(
cf.MessageQueue.RabbitMQ.CompressionEnabled,
cf.MessageQueue.RabbitMQ.CompressionThreshold,
),
rabbitmqv1.WithMessageRejection(cf.MessageQueue.RabbitMQ.EnableMessageRejection, cf.MessageQueue.RabbitMQ.MaxDeathCount),
)
if err != nil {
return nil, nil, fmt.Errorf("could not init rabbitmq: %w", err)
}
cleanup1 = func() error {
if err := cleanupv0(); err != nil {
return err
}
return cleanupv1()
}
}
ing, err = ingestor.NewIngestor(
ingestor.WithEventRepository(dc.EngineRepository.Event()),
ingestor.WithStreamEventsRepository(dc.EngineRepository.StreamEvent()),
ingestor.WithLogRepository(dc.EngineRepository.Log()),
ingestor.WithMessageQueue(mq),
ingestor.WithMessageQueueV1(mqv1),
ingestor.WithEntitlementsRepository(dc.EntitlementRepository),
ingestor.WithStepRunRepository(dc.EngineRepository.StepRun()),
ingestor.WithRepositoryV1(dc.V1),
)
if err != nil {
return nil, nil, fmt.Errorf("could not create ingestor: %w", err)
}
}
var alerter errors.Alerter
if cf.Alerting.Sentry.Enabled {
alerter, err = sentry.NewSentryAlerter(&sentry.SentryAlerterOpts{
DSN: cf.Alerting.Sentry.DSN,
Environment: cf.Alerting.Sentry.Environment,
SampleRate: cf.Alerting.Sentry.SampleRate,
})
if err != nil {
return nil, nil, fmt.Errorf("could not create sentry alerter: %w", err)
}
} else {
alerter = errors.NoOpAlerter{}
}
if cf.SecurityCheck.Enabled {
securityCheck := security.NewSecurityCheck(&security.DefaultSecurityCheck{
Enabled: cf.SecurityCheck.Enabled,
Endpoint: cf.SecurityCheck.Endpoint,
Logger: &l,
Version: version,
}, dc.APIRepository.SecurityCheck())
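// deferred so the check runs after the rest of the server config has been built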
defer securityCheck.Check()
}
var analyticsEmitter analytics.Analytics
var feAnalyticsConfig *server.FePosthogConfig
if cf.Analytics.Posthog.Enabled {
analyticsEmitter, err = posthog.NewPosthogAnalytics(&posthog.PosthogAnalyticsOpts{
ApiKey: cf.Analytics.Posthog.ApiKey,
Endpoint: cf.Analytics.Posthog.Endpoint,
})
if err != nil {
return nil, nil, fmt.Errorf("could not create posthog analytics: %w", err)
}
if cf.Analytics.Posthog.FeApiKey != "" && cf.Analytics.Posthog.FeApiHost != "" {
feAnalyticsConfig = &server.FePosthogConfig{
ApiKey: cf.Analytics.Posthog.FeApiKey,
ApiHost: cf.Analytics.Posthog.FeApiHost,
}
}
} else {
analyticsEmitter = analytics.NoOpAnalytics{}
}
var pylon server.PylonConfig
if cf.Pylon.Enabled {
if cf.Pylon.AppID == "" {
return nil, nil, fmt.Errorf("pylon app id is required")
}
pylon.AppID = cf.Pylon.AppID
pylon.Secret = cf.Pylon.Secret
}
auth := server.AuthConfig{
RestrictedEmailDomains: getStrArr(cf.Auth.RestrictedEmailDomains),
ConfigFile: cf.Auth,
}
if cf.Auth.Google.Enabled {
if cf.Auth.Google.ClientID == "" {
return nil, nil, fmt.Errorf("google client id is required")
}
if cf.Auth.Google.ClientSecret == "" {
return nil, nil, fmt.Errorf("google client secret is required")
}
gClient := oauth.NewGoogleClient(&oauth.Config{
ClientID: cf.Auth.Google.ClientID,
ClientSecret: cf.Auth.Google.ClientSecret,
BaseURL: cf.Runtime.ServerURL,
Scopes: cf.Auth.Google.Scopes,
})
auth.GoogleOAuthConfig = gClient
}
if cf.Auth.Github.Enabled {
if cf.Auth.Github.ClientID == "" {
return nil, nil, fmt.Errorf("github client id is required")
}
if cf.Auth.Github.ClientSecret == "" {
return nil, nil, fmt.Errorf("github client secret is required")
}
auth.GithubOAuthConfig = oauth.NewGithubClient(&oauth.Config{
ClientID: cf.Auth.Github.ClientID,
ClientSecret: cf.Auth.Github.ClientSecret,
BaseURL: cf.Runtime.ServerURL,
Scopes: cf.Auth.Github.Scopes,
})
}
encryptionSvc, err := LoadEncryptionSvc(cf)
if err != nil {
return nil, nil, fmt.Errorf("could not load encryption service: %w", err)
}
// create a new JWT manager
auth.JWTManager, err = token.NewJWTManager(encryptionSvc, dc.EngineRepository.APIToken(), &token.TokenOpts{
Issuer: cf.Runtime.ServerURL,
Audience: cf.Runtime.ServerURL,
GRPCBroadcastAddress: cf.Runtime.GRPCBroadcastAddress,
ServerURL: cf.Runtime.ServerURL,
})
if err != nil {
return nil, nil, fmt.Errorf("could not create JWT manager: %w", err)
}
var emailSvc email.EmailService = &email.NoOpService{}
if cf.Email.Postmark.Enabled {
emailSvc = postmark.NewPostmarkClient(
cf.Email.Postmark.ServerKey,
cf.Email.Postmark.FromEmail,
cf.Email.Postmark.FromName,
cf.Email.Postmark.SupportEmail,
)
}
additionalOAuthConfigs := make(map[string]*oauth2.Config)
if cf.TenantAlerting.Slack.Enabled {
additionalOAuthConfigs["slack"] = oauth.NewSlackClient(&oauth.Config{
ClientID: cf.TenantAlerting.Slack.SlackAppClientID,
ClientSecret: cf.TenantAlerting.Slack.SlackAppClientSecret,
BaseURL: cf.Runtime.ServerURL,
Scopes: cf.TenantAlerting.Slack.SlackAppScopes,
})
}
v := validator.NewDefaultValidator()
schedulingPool, cleanupSchedulingPool, err := v0.NewSchedulingPool(
dc.EngineRepository.Scheduler(),
&queueLogger,
cf.Runtime.SingleQueueLimit,
)
if err != nil {
return nil, nil, fmt.Errorf("could not create scheduling pool: %w", err)
}
schedulingPoolV1, cleanupSchedulingPoolV1, err := v1.NewSchedulingPool(
dc.V1.Scheduler(),
&queueLogger,
cf.Runtime.SingleQueueLimit,
cf.Runtime.SchedulerConcurrencyRateLimit,
cf.Runtime.SchedulerConcurrencyPollingMinInterval,
cf.Runtime.SchedulerConcurrencyPollingMaxInterval,
)
if err != nil {
return nil, nil, fmt.Errorf("could not create scheduling pool (v1): %w", err)
}
schedulingPoolV1.Extensions.Add(v1.NewPrometheusExtension())
cleanup = func() error {
log.Printf("cleaning up server config")
if err := cleanupSchedulingPool(); err != nil {
return fmt.Errorf("error cleaning up scheduling pool: %w", err)
}
if err := cleanupSchedulingPoolV1(); err != nil {
return fmt.Errorf("error cleaning up scheduling pool (v1): %w", err)
}
if err := cleanup1(); err != nil {
return fmt.Errorf("error cleaning up rabbitmq: %w", err)
}
return nil
}
services := cf.Services
// edge case to support backwards-compatibility with the services array in the config file
if cf.ServicesString != "" {
services = strings.Split(cf.ServicesString, " ")
}
pausedControllers := make(map[string]bool)
if cf.PausedControllers != "" {
for _, controller := range strings.Split(cf.PausedControllers, " ") {
pausedControllers[controller] = true
}
}
if cf.Runtime.Monitoring.TLSRootCAFile == "" {
cf.Runtime.Monitoring.TLSRootCAFile = cf.TLS.TLSRootCAFile
}
internalClientFactory, err := loadInternalClient(&l, &cf.InternalClient, cf.TLS, cf.Runtime.GRPCBroadcastAddress, cf.Runtime.GRPCInsecure)
if err != nil {
return nil, nil, fmt.Errorf("could not load internal client: %w", err)
}
return cleanup, &server.ServerConfig{
Alerter: alerter,
Analytics: analyticsEmitter,
FePosthog: feAnalyticsConfig,
Pylon: &pylon,
Runtime: cf.Runtime,
Auth: auth,
Encryption: encryptionSvc,
Layer: dc,
MessageQueue: mq,
MessageQueueV1: mqv1,
Services: services,
PausedControllers: pausedControllers,
InternalClientFactory: internalClientFactory,
Logger: &l,
TLSConfig: tls,
SessionStore: ss,
Validator: v,
Ingestor: ing,
OpenTelemetry: cf.OpenTelemetry,
Prometheus: cf.Prometheus,
Email: emailSvc,
TenantAlerter: alerting.New(dc.EngineRepository, encryptionSvc, cf.Runtime.ServerURL, emailSvc),
AdditionalOAuthConfigs: additionalOAuthConfigs,
AdditionalLoggers: cf.AdditionalLoggers,
EnableDataRetention: cf.EnableDataRetention,
EnableWorkerRetention: cf.EnableWorkerRetention,
SchedulingPool: schedulingPool,
SchedulingPoolV1: schedulingPoolV1,
Version: version,
Sampling: cf.Sampling,
Operations: cf.OLAP,
CronOperations: cf.CronOperations,
OLAPStatusUpdates: cf.OLAPStatusUpdates,
}, nil
}
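// getStrArr splits a space-separated config value into a slice. Note that an
// empty input yields a single empty-string element rather than an empty slice.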
func getStrArr(v string) []string {
return strings.Split(v, " ")
}
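// LoadEncryptionSvc constructs the encryption service from either a local
// master keyset or CloudKMS (mutually exclusive), together with the JWT
// public/private keysets, each of which may be provided inline or as a file.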
func LoadEncryptionSvc(cf *server.ServerConfigFile) (encryption.EncryptionService, error) {
var err error
hasLocalMasterKeyset := cf.Encryption.MasterKeyset != "" || cf.Encryption.MasterKeysetFile != ""
isCloudKMSEnabled := cf.Encryption.CloudKMS.Enabled
if !hasLocalMasterKeyset && !isCloudKMSEnabled {
return nil, fmt.Errorf("encryption is required")
}
if hasLocalMasterKeyset && isCloudKMSEnabled {
return nil, fmt.Errorf("cannot use both encryption and cloud kms")
}
hasJWTKeys := (cf.Encryption.JWT.PublicJWTKeyset != "" || cf.Encryption.JWT.PublicJWTKeysetFile != "") &&
(cf.Encryption.JWT.PrivateJWTKeyset != "" || cf.Encryption.JWT.PrivateJWTKeysetFile != "")
if !hasJWTKeys {
return nil, fmt.Errorf("jwt encryption is required")
}
privateJWT := cf.Encryption.JWT.PrivateJWTKeyset
if cf.Encryption.JWT.PrivateJWTKeysetFile != "" {
privateJWTBytes, err := loaderutils.GetFileBytes(cf.Encryption.JWT.PrivateJWTKeysetFile)
if err != nil {
return nil, fmt.Errorf("could not load private jwt keyset file: %w", err)
}
privateJWT = string(privateJWTBytes)
}
publicJWT := cf.Encryption.JWT.PublicJWTKeyset
if cf.Encryption.JWT.PublicJWTKeysetFile != "" {
publicJWTBytes, err := loaderutils.GetFileBytes(cf.Encryption.JWT.PublicJWTKeysetFile)
if err != nil {
return nil, fmt.Errorf("could not load public jwt keyset file: %w", err)
}
publicJWT = string(publicJWTBytes)
}
var encryptionSvc encryption.EncryptionService
if hasLocalMasterKeyset {
masterKeyset := cf.Encryption.MasterKeyset
if cf.Encryption.MasterKeysetFile != "" {
masterKeysetBytes, err := loaderutils.GetFileBytes(cf.Encryption.MasterKeysetFile)
if err != nil {
return nil, fmt.Errorf("could not load master keyset file: %w", err)
}
masterKeyset = string(masterKeysetBytes)
}
encryptionSvc, err = encryption.NewLocalEncryption(
[]byte(masterKeyset),
[]byte(privateJWT),
[]byte(publicJWT),
)
if err != nil {
return nil, fmt.Errorf("could not create raw keyset encryption service: %w", err)
}
}
if isCloudKMSEnabled {
encryptionSvc, err = encryption.NewCloudKMSEncryption(
cf.Encryption.CloudKMS.KeyURI,
[]byte(cf.Encryption.CloudKMS.CredentialsJSON),
[]byte(privateJWT),
[]byte(publicJWT),
)
if err != nil {
return nil, fmt.Errorf("could not create CloudKMS encryption service: %w", err)
}
}
return encryptionSvc, nil
}
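// loadInternalClient builds a gRPC client factory for engine-internal calls,
// deriving the TLS server name from the broadcast address when one is not set
// explicitly.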
func loadInternalClient(l *zerolog.Logger, conf *server.InternalClientTLSConfigFile, baseServerTLS shared.TLSConfigFile, grpcBroadcastAddress string, grpcInsecure bool) (*clientv1.GRPCClientFactory, error) {
// get gRPC broadcast address
broadcastAddress := grpcBroadcastAddress
if conf.InternalGRPCBroadcastAddress != "" {
broadcastAddress = conf.InternalGRPCBroadcastAddress
}
tlsServerName := conf.TLSServerName
if tlsServerName == "" {
// parse host from broadcast address
host, _, err := net.SplitHostPort(broadcastAddress)
if err != nil {
return nil, fmt.Errorf("could not parse host from broadcast address %s: %w", broadcastAddress, err)
}
tlsServerName = host
}
// construct TLS config
var base shared.TLSConfigFile
if conf.InheritBase {
base = baseServerTLS
if grpcInsecure {
base.TLSStrategy = "none"
}
} else {
base = conf.Base
}
tlsConfig, err := loaderutils.LoadClientTLSConfig(&client.ClientTLSConfigFile{
Base: base,
TLSServerName: tlsServerName,
}, tlsServerName)
if err != nil {
return nil, fmt.Errorf("could not load client TLS config: %w", err)
}
return clientv1.NewGRPCClientFactory(
clientv1.WithHostPort(broadcastAddress),
clientv1.WithTLS(tlsConfig),
clientv1.WithLogger(l),
), nil
}
// checkDatabaseTimezone validates that the database instance timezone is set to UTC.
// It creates a temporary connection to check the timezone without using the AfterConnect hook.
func checkDatabaseTimezone(connConfig *pgx.ConnConfig, dbName string, dbLabel string, l *zerolog.Logger) error {
tempConn, err := pgx.ConnectConfig(context.Background(), connConfig)
if err != nil {
return fmt.Errorf("could not create temporary connection to %s to check timezone: %w", dbLabel, err)
}
defer tempConn.Close(context.Background())
var dbTimezone string
if err := tempConn.QueryRow(context.Background(), "SHOW timezone").Scan(&dbTimezone); err != nil {
return fmt.Errorf("could not query %s timezone: %w", dbLabel, err)
}
// Accept both "UTC" and "Etc/UTC" as valid UTC timezones
if dbTimezone != "UTC" && dbTimezone != "Etc/UTC" {
if dbName == "" {
dbName = "<your_database_name>"
}
return fmt.Errorf(
"%s instance timezone is set to '%s' but must be 'UTC' or 'Etc/UTC'\n"+
"This check ensures time-based operations work correctly across all sessions\n"+
"To fix this issue, you have two options:\n"+
" 1. Set your PostgreSQL instance timezone to UTC by running: ALTER DATABASE %s SET TIMEZONE='UTC'\n"+
" 2. Disable this check by setting the environment variable: DATABASE_ENFORCE_UTC_TIMEZONE=false\n"+
"Note: Disabling this check is not recommended as it may lead to timezone-related issues",
dbLabel, dbTimezone, dbName,
)
}
l.Info().Msgf("%s instance timezone verified: %s", dbLabel, dbTimezone)
return nil
}