mirror of
https://github.com/hatchet-dev/hatchet.git
synced 2025-12-20 08:10:26 -06:00
411 lines
11 KiB
Go
411 lines
11 KiB
Go
// Adapted from: https://github.com/hatchet-dev/hatchet-v1-archived/blob/3c2c13168afa1af68d4baaf5ed02c9d49c5f0323/internal/config/loader/loader.go
|
|
|
|
package loader
|
|
|
|
import (
	"context"
	"fmt"
	"log"
	"net"
	"net/url"
	"os"
	"path/filepath"
	"strings"

	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/jackc/pgx/v5/tracelog"

	pgxzero "github.com/jackc/pgx-zerolog"

	"github.com/hatchet-dev/hatchet/internal/auth/cookie"
	"github.com/hatchet-dev/hatchet/internal/auth/oauth"
	"github.com/hatchet-dev/hatchet/internal/auth/token"
	clientconfig "github.com/hatchet-dev/hatchet/internal/config/client"
	"github.com/hatchet-dev/hatchet/internal/config/database"
	"github.com/hatchet-dev/hatchet/internal/config/loader/loaderutils"
	"github.com/hatchet-dev/hatchet/internal/config/server"
	"github.com/hatchet-dev/hatchet/internal/encryption"
	"github.com/hatchet-dev/hatchet/internal/integrations/vcs"
	"github.com/hatchet-dev/hatchet/internal/integrations/vcs/github"
	"github.com/hatchet-dev/hatchet/internal/logger"
	"github.com/hatchet-dev/hatchet/internal/repository/prisma"
	"github.com/hatchet-dev/hatchet/internal/repository/prisma/db"
	"github.com/hatchet-dev/hatchet/internal/services/ingestor"
	"github.com/hatchet-dev/hatchet/internal/taskqueue/rabbitmq"
	"github.com/hatchet-dev/hatchet/internal/validator"
	"github.com/hatchet-dev/hatchet/pkg/client"
)
|
|
|
|
// LoadDatabaseConfigFile loads the database config file via viper
|
|
func LoadDatabaseConfigFile(files ...[]byte) (*database.ConfigFile, error) {
|
|
configFile := &database.ConfigFile{}
|
|
f := database.BindAllEnv
|
|
|
|
_, err := loaderutils.LoadConfigFromViper(f, configFile, files...)
|
|
|
|
return configFile, err
|
|
}
|
|
|
|
// LoadServerConfigFile loads the server config file via viper
|
|
func LoadServerConfigFile(files ...[]byte) (*server.ServerConfigFile, error) {
|
|
configFile := &server.ServerConfigFile{}
|
|
f := server.BindAllEnv
|
|
|
|
_, err := loaderutils.LoadConfigFromViper(f, configFile, files...)
|
|
|
|
return configFile, err
|
|
}
|
|
|
|
// ConfigLoader reads Hatchet's YAML config files (database.yaml, server.yaml)
// from a single directory.
type ConfigLoader struct {
	// directory is the folder the config files are looked up in.
	directory string
}

// NewConfigLoader returns a ConfigLoader that reads its config files from
// directory.
func NewConfigLoader(directory string) *ConfigLoader {
	loader := ConfigLoader{directory: directory}
	return &loader
}
|
|
|
|
// LoadDatabaseConfig reads database.yaml from the loader's directory, parses it,
// and hands the result to GetDatabaseConfigFromConfigFile. That function opens
// the database connections, so a successful result is already connected.
func (c *ConfigLoader) LoadDatabaseConfig() (res *database.Config, err error) {
	raw, err := loaderutils.GetConfigBytes(filepath.Join(c.directory, "database.yaml"))
	if err != nil {
		return nil, err
	}

	parsed, err := LoadDatabaseConfigFile(raw...)
	if err != nil {
		return nil, err
	}

	return GetDatabaseConfigFromConfigFile(parsed)
}
|
|
|
|
// LoadServerConfig reads server.yaml and database.yaml from the loader's
// directory and builds the full server configuration.
//
// The returned cleanup func is the one produced by GetServerConfigFromConfigfile.
// On success the caller owns the database connection carried in res. On any
// failure after the database has been connected, that connection is
// disconnected before returning, so nothing is leaked.
func (c *ConfigLoader) LoadServerConfig() (cleanup func() error, res *server.ServerConfig, err error) {
	log.Printf("Loading server config from %s", c.directory)

	sharedFilePath := filepath.Join(c.directory, "server.yaml")
	log.Printf("Shared file path: %s", sharedFilePath)

	configFileBytes, err := loaderutils.GetConfigBytes(sharedFilePath)
	if err != nil {
		return nil, nil, err
	}

	// Parse the server config before touching the database. Parsing has no
	// side effects, whereas LoadDatabaseConfig opens connections that would be
	// leaked if parsing failed afterwards.
	cf, err := LoadServerConfigFile(configFileBytes...)
	if err != nil {
		return nil, nil, err
	}

	dc, err := c.LoadDatabaseConfig()
	if err != nil {
		return nil, nil, err
	}

	cleanup, res, err = GetServerConfigFromConfigfile(dc, cf)
	if err != nil {
		// dc is only handed to the caller on success, so release it here.
		if dErr := dc.Disconnect(); dErr != nil {
			log.Printf("could not disconnect database after failed server config load: %v", dErr)
		}

		return nil, nil, err
	}

	return cleanup, res, nil
}
|
|
|
|
// GetDatabaseConfigFromConfigFile opens the Postgres connections described by
// cf and wraps them in a database.Config.
//
// Two clients are created against the same database:
//   - a Prisma client, which gets its datasource from the DATABASE_URL
//     environment variable set here (a process-wide side effect);
//   - a pgx connection pool capped at 20 connections. When cf.LogQueries is
//     set, its queries are traced at debug level through the "database" logger.
//
// If any step fails, every connection opened so far is released before the
// error is returned.
func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile) (res *database.Config, err error) {
	l := logger.NewStdErr(&cf.Logger, "database")

	databaseURL := postgresURL(cf)

	// The Prisma client below is created without db.WithDatasourceURL, so it
	// relies on DATABASE_URL being present in the environment.
	if err := os.Setenv("DATABASE_URL", databaseURL); err != nil {
		return nil, fmt.Errorf("could not set DATABASE_URL: %w", err)
	}

	// Build the pgx config before opening any connection, so a bad URL cannot
	// leave a dangling Prisma connection behind.
	poolConfig, err := pgxpool.ParseConfig(databaseURL)
	if err != nil {
		return nil, err
	}

	if cf.LogQueries {
		poolConfig.ConnConfig.Tracer = &tracelog.TraceLog{
			Logger:   pgxzero.NewLogger(l),
			LogLevel: tracelog.LogLevelDebug,
		}
	}

	poolConfig.MaxConns = 20

	prismaClient := db.NewClient()

	if err := prismaClient.Prisma.Connect(); err != nil {
		return nil, err
	}

	pool, err := pgxpool.NewWithConfig(context.Background(), poolConfig)
	if err != nil {
		// Release the Prisma connection opened above. The pool error is the one
		// worth reporting, so a disconnect failure is deliberately dropped.
		_ = prismaClient.Prisma.Disconnect()

		return nil, fmt.Errorf("could not connect to database: %w", err)
	}

	return &database.Config{
		Disconnect: prismaClient.Prisma.Disconnect,
		Repository: prisma.NewPrismaRepository(prismaClient, pool, prisma.WithLogger(&l)),
		Seed:       cf.Seed,
	}, nil
}

// postgresURL builds the postgresql:// connection URL for cf. Credentials are
// percent-encoded and host/port are joined with net.JoinHostPort, so passwords
// containing reserved characters (such as '@', ':' or '/') and IPv6 hosts still
// yield a valid URL. For plain values the result matches
// "postgresql://user:pass@host:port/db?sslmode=mode".
func postgresURL(cf *database.ConfigFile) string {
	u := url.URL{
		Scheme:   "postgresql",
		User:     url.UserPassword(cf.PostgresUsername, cf.PostgresPassword),
		Host:     net.JoinHostPort(cf.PostgresHost, fmt.Sprint(cf.PostgresPort)),
		Path:     "/" + cf.PostgresDbName,
		RawQuery: url.Values{"sslmode": {cf.PostgresSSLMode}}.Encode(),
	}

	return u.String()
}
|
|
|
|
// GetServerConfigFromConfigfile assembles a server.ServerConfig from an
// already-connected database config and a parsed server config file.
//
// It loads the TLS config, creates the cookie session store, opens the RabbitMQ
// task queue, and builds the event ingestor. It then configures Google OAuth
// (when enabled), loads the encryption service, creates the JWT manager, and
// registers the GitHub VCS provider (when enabled). When
// cf.Runtime.WorkerEnabled is set, it also creates an internal client
// authenticated with a fresh token for the tenant with slug "internal".
//
// The returned cleanup func closes the RabbitMQ task queue. If any step after
// the task queue is opened fails, the task queue is closed before returning, so
// the caller only owns resources on success. dc is never disconnected here.
func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigFile) (cleanup func() error, res *server.ServerConfig, err error) {
	l := logger.NewStdErr(&cf.Logger, "server")

	tlsConfig, err := loaderutils.LoadServerTLSConfig(&cf.TLS)
	if err != nil {
		return nil, nil, fmt.Errorf("could not load TLS config: %w", err)
	}

	ss, err := cookie.NewUserSessionStore(
		cookie.WithSessionRepository(dc.Repository.UserSession()),
		cookie.WithCookieAllowInsecure(cf.Auth.Cookie.Insecure),
		cookie.WithCookieDomain(cf.Auth.Cookie.Domain),
		cookie.WithCookieName(cf.Auth.Cookie.Name),
		cookie.WithCookieSecrets(getStrArr(cf.Auth.Cookie.Secrets)...),
	)
	if err != nil {
		return nil, nil, fmt.Errorf("could not create session store: %w", err)
	}

	cleanupTaskQueue, tq := rabbitmq.New(
		rabbitmq.WithURL(cf.TaskQueue.RabbitMQ.URL),
		rabbitmq.WithLogger(&l),
	)

	// From here on, every failure must close the task queue, or a failed load
	// leaks it. Each explicit return assigns the named result err before this
	// deferred func runs, so it sees the error even when an inner scope
	// shadows err.
	defer func() {
		if err == nil {
			return
		}

		if cErr := cleanupTaskQueue(); cErr != nil {
			log.Printf("error cleaning up rabbitmq after failed server config load: %v", cErr)
		}
	}()

	eventIngestor, err := ingestor.NewIngestor(
		ingestor.WithEventRepository(dc.Repository.Event()),
		ingestor.WithTaskQueue(tq),
	)
	if err != nil {
		return nil, nil, fmt.Errorf("could not create ingestor: %w", err)
	}

	auth := server.AuthConfig{
		ConfigFile: cf.Auth,
	}

	if cf.Auth.Google.Enabled {
		if cf.Auth.Google.ClientID == "" {
			return nil, nil, fmt.Errorf("google client id is required")
		}

		if cf.Auth.Google.ClientSecret == "" {
			return nil, nil, fmt.Errorf("google client secret is required")
		}

		auth.GoogleOAuthConfig = oauth.NewGoogleClient(&oauth.Config{
			ClientID:     cf.Auth.Google.ClientID,
			ClientSecret: cf.Auth.Google.ClientSecret,
			BaseURL:      cf.Runtime.ServerURL,
			Scopes:       cf.Auth.Google.Scopes,
		})
	}

	encryptionSvc, err := loadEncryptionSvc(cf)
	if err != nil {
		return nil, nil, fmt.Errorf("could not load encryption service: %w", err)
	}

	// create a new JWT manager
	auth.JWTManager, err = token.NewJWTManager(encryptionSvc, dc.Repository.APIToken(), &token.TokenOpts{
		Issuer:               cf.Runtime.ServerURL,
		Audience:             cf.Runtime.ServerURL,
		GRPCBroadcastAddress: cf.Runtime.GRPCBroadcastAddress,
		ServerURL:            cf.Runtime.ServerURL,
	})
	if err != nil {
		return nil, nil, fmt.Errorf("could not create JWT manager: %w", err)
	}

	vcsProviders := make(map[vcs.VCSRepositoryKind]vcs.VCSProvider)

	if cf.VCS.Github.Enabled {
		githubAppConf, err := github.NewGithubAppConf(
			&oauth.Config{
				ClientID:     cf.VCS.Github.GithubAppClientID,
				ClientSecret: cf.VCS.Github.GithubAppClientSecret,
				Scopes:       []string{"read:user"},
				BaseURL:      cf.Runtime.ServerURL,
			},
			cf.VCS.Github.GithubAppName,
			cf.VCS.Github.GithubAppSecretPath,
			cf.VCS.Github.GithubAppWebhookSecret,
			cf.VCS.Github.GithubAppWebhookURL,
			cf.VCS.Github.GithubAppID,
		)
		if err != nil {
			return nil, nil, fmt.Errorf("could not create github app config: %w", err)
		}

		vcsProviders[vcs.VCSRepositoryKindGithub] = github.NewGithubVCSProvider(githubAppConf, dc.Repository, cf.Runtime.ServerURL, encryptionSvc)
	}

	var internalClient client.Client

	if cf.Runtime.WorkerEnabled {
		// Look up the tenant with slug "internal". It is not created here, so a
		// missing tenant aborts the load.
		internalTenant, err := dc.Repository.Tenant().GetTenantBySlug("internal")
		if err != nil {
			return nil, nil, fmt.Errorf("could not get internal tenant: %w", err)
		}

		tokenSuffix, err := encryption.GenerateRandomBytes(4)
		if err != nil {
			return nil, nil, fmt.Errorf("could not generate token suffix: %w", err)
		}

		// generate a token for the internal client
		internalToken, err := auth.JWTManager.GenerateTenantToken(internalTenant.ID, fmt.Sprintf("internal-%s", tokenSuffix))
		if err != nil {
			return nil, nil, fmt.Errorf("could not generate internal token: %w", err)
		}

		internalClient, err = client.NewFromConfigFile(
			&clientconfig.ClientConfigFile{
				Token:    internalToken,
				HostPort: cf.Runtime.GRPCBroadcastAddress,
			},
		)
		if err != nil {
			return nil, nil, fmt.Errorf("could not create internal client: %w", err)
		}
	}

	cleanup = func() error {
		log.Printf("cleaning up server config")

		if err := cleanupTaskQueue(); err != nil {
			return fmt.Errorf("error cleaning up rabbitmq: %w", err)
		}

		return nil
	}

	return cleanup, &server.ServerConfig{
		Runtime:        cf.Runtime,
		Auth:           auth,
		Encryption:     encryptionSvc,
		Config:         dc,
		TaskQueue:      tq,
		Services:       cf.Services,
		Logger:         &l,
		TLSConfig:      tlsConfig,
		SessionStore:   ss,
		Validator:      validator.NewDefaultValidator(),
		Ingestor:       eventIngestor,
		OpenTelemetry:  cf.OpenTelemetry,
		VCSProviders:   vcsProviders,
		InternalClient: internalClient,
	}, nil
}
|
|
|
|
// getStrArr splits a whitespace-separated list (such as the configured cookie
// secrets) into its items.
//
// strings.Fields is used instead of splitting on a single space. Runs of
// spaces, tabs, newlines, and leading/trailing whitespace therefore never
// produce empty entries, and an empty or all-whitespace input yields no items
// rather than a single "" secret.
func getStrArr(v string) []string {
	return strings.Fields(v)
}
|
|
|
|
func loadEncryptionSvc(cf *server.ServerConfigFile) (encryption.EncryptionService, error) {
|
|
var err error
|
|
|
|
hasLocalMasterKeyset := cf.Encryption.MasterKeyset != "" || cf.Encryption.MasterKeysetFile != ""
|
|
isCloudKMSEnabled := cf.Encryption.CloudKMS.Enabled
|
|
|
|
if !hasLocalMasterKeyset && !isCloudKMSEnabled {
|
|
return nil, fmt.Errorf("encryption is required")
|
|
}
|
|
|
|
if hasLocalMasterKeyset && isCloudKMSEnabled {
|
|
return nil, fmt.Errorf("cannot use both encryption and cloud kms")
|
|
}
|
|
|
|
hasJWTKeys := (cf.Encryption.JWT.PublicJWTKeyset != "" || cf.Encryption.JWT.PublicJWTKeysetFile != "") &&
|
|
(cf.Encryption.JWT.PrivateJWTKeyset != "" || cf.Encryption.JWT.PrivateJWTKeysetFile != "")
|
|
|
|
if !hasJWTKeys {
|
|
return nil, fmt.Errorf("jwt encryption is required")
|
|
}
|
|
|
|
privateJWT := cf.Encryption.JWT.PrivateJWTKeyset
|
|
|
|
if cf.Encryption.JWT.PrivateJWTKeysetFile != "" {
|
|
privateJWTBytes, err := loaderutils.GetFileBytes(cf.Encryption.JWT.PrivateJWTKeysetFile)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not load private jwt keyset file: %w", err)
|
|
}
|
|
|
|
privateJWT = string(privateJWTBytes)
|
|
}
|
|
|
|
publicJWT := cf.Encryption.JWT.PublicJWTKeyset
|
|
|
|
if cf.Encryption.JWT.PublicJWTKeysetFile != "" {
|
|
publicJWTBytes, err := loaderutils.GetFileBytes(cf.Encryption.JWT.PublicJWTKeysetFile)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not load public jwt keyset file: %w", err)
|
|
}
|
|
|
|
publicJWT = string(publicJWTBytes)
|
|
}
|
|
|
|
var encryptionSvc encryption.EncryptionService
|
|
|
|
if hasLocalMasterKeyset {
|
|
masterKeyset := cf.Encryption.MasterKeyset
|
|
|
|
if cf.Encryption.MasterKeysetFile != "" {
|
|
masterKeysetBytes, err := loaderutils.GetFileBytes(cf.Encryption.MasterKeysetFile)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not load master keyset file: %w", err)
|
|
}
|
|
|
|
masterKeyset = string(masterKeysetBytes)
|
|
}
|
|
|
|
encryptionSvc, err = encryption.NewLocalEncryption(
|
|
[]byte(masterKeyset),
|
|
[]byte(privateJWT),
|
|
[]byte(publicJWT),
|
|
)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not create raw keyset encryption service: %w", err)
|
|
}
|
|
}
|
|
|
|
if isCloudKMSEnabled {
|
|
encryptionSvc, err = encryption.NewCloudKMSEncryption(
|
|
cf.Encryption.CloudKMS.KeyURI,
|
|
[]byte(cf.Encryption.CloudKMS.CredentialsJSON),
|
|
[]byte(privateJWT),
|
|
[]byte(publicJWT),
|
|
)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not create CloudKMS encryption service: %w", err)
|
|
}
|
|
}
|
|
|
|
return encryptionSvc, nil
|
|
}
|