From a8dd33c61f30d0fc9cd285b25943da3678110b04 Mon Sep 17 00:00:00 2001 From: Sean Reilly Date: Fri, 17 Jan 2025 15:34:10 -0800 Subject: [PATCH] Feature - configurable logging backend (#1188) * allow us to configure different repos * make the struct contents public * pass in config values to new log repo * rename functions - possibly breaking changes so lets discuss * make the logging backend configurable * fix tests * don't allow calls to WithAdditionalConfig * cleanup * replace sc with server Co-authored-by: abelanger5 * rename sc to server * add a LRU cache for the step run lookup * lets not use an expirable cache and just use the regular one - we cannot close the go func in exirable --------- Co-authored-by: abelanger5 --- cmd/hatchet-admin/cli/k8s.go | 8 +- cmd/hatchet-admin/cli/seed.go | 2 +- cmd/hatchet-admin/cli/token.go | 8 +- cmd/hatchet-api/api/run.go | 10 +- cmd/hatchet-engine/engine/run.go | 12 ++- cmd/hatchet-lite/main.go | 2 +- examples/logging/main.go | 127 +++++++++++++++++++++++++ internal/services/ingestor/ingestor.go | 37 +++++-- internal/services/ingestor/server.go | 14 +++ internal/testutils/env.go | 10 +- internal/testutils/with_database.go | 4 +- pkg/auth/cookie/sessionstore_test.go | 6 +- pkg/auth/token/token_test.go | 8 +- pkg/config/database/config.go | 2 +- pkg/config/loader/loader.go | 112 +++++++++++++--------- pkg/config/server/server.go | 2 +- pkg/repository/logs.go | 5 + pkg/repository/prisma/log.go | 11 +++ pkg/repository/prisma/repository.go | 40 ++++++-- 19 files changed, 324 insertions(+), 96 deletions(-) create mode 100644 examples/logging/main.go diff --git a/cmd/hatchet-admin/cli/k8s.go b/cmd/hatchet-admin/cli/k8s.go index 057e64ec2..2bc333f54 100644 --- a/cmd/hatchet-admin/cli/k8s.go +++ b/cmd/hatchet-admin/cli/k8s.go @@ -236,7 +236,7 @@ func runCreateWorkerToken() error { // read in the local config configLoader := loader.NewConfigLoader(configDirectory) - cleanup, serverConf, err := configLoader.LoadServerConfig("", func(scf *server.ServerConfigFile) { + cleanup, server, err := configLoader.CreateServerFromConfig("", func(scf *server.ServerConfigFile) { // disable rabbitmq since it's not needed to create the api token scf.MessageQueue.Enabled = false @@ -250,17 +250,17 @@ func runCreateWorkerToken() error { defer cleanup() // nolint:errcheck - defer serverConf.Disconnect() // nolint:errcheck + defer server.Disconnect() // nolint:errcheck expiresAt := time.Now().UTC().Add(100 * 365 * 24 * time.Hour) tenantId := tokenTenantId if tenantId == "" { - tenantId = serverConf.Seed.DefaultTenantID + tenantId = server.Seed.DefaultTenantID } - defaultTok, err := serverConf.Auth.JWTManager.GenerateTenantToken(context.Background(), tenantId, tokenName, false, &expiresAt) + defaultTok, err := server.Auth.JWTManager.GenerateTenantToken(context.Background(), tenantId, tokenName, false, &expiresAt) if err != nil { return err diff --git a/cmd/hatchet-admin/cli/seed.go b/cmd/hatchet-admin/cli/seed.go index 4f3935bb7..fd7f0f974 100644 --- a/cmd/hatchet-admin/cli/seed.go +++ b/cmd/hatchet-admin/cli/seed.go @@ -39,7 +39,7 @@ func init() { func runSeed(cf *loader.ConfigLoader) error { // load the config - dc, err := cf.LoadDatabaseConfig() + dc, err := cf.InitDataLayer() if err != nil { panic(err) diff --git a/cmd/hatchet-admin/cli/token.go b/cmd/hatchet-admin/cli/token.go index 45a102fd3..119aca953 100644 --- a/cmd/hatchet-admin/cli/token.go +++ b/cmd/hatchet-admin/cli/token.go @@ -72,7 +72,7 @@ func runCreateAPIToken(expiresIn time.Duration) error { // read in the local config configLoader := loader.NewConfigLoader(configDirectory) - cleanup, serverConf, err := configLoader.LoadServerConfig("", func(scf *server.ServerConfigFile) { + cleanup, server, err := configLoader.CreateServerFromConfig("", func(scf *server.ServerConfigFile) { // disable rabbitmq since it's not needed to create the api token scf.MessageQueue.Enabled = false @@ -86,17 +86,17 @@ func runCreateAPIToken(expiresIn time.Duration) error { defer cleanup() // nolint:errcheck - defer serverConf.Disconnect() // nolint:errcheck + defer server.Disconnect() // nolint:errcheck expiresAt := time.Now().UTC().Add(expiresIn) tenantId := tokenTenantId if tenantId == "" { - tenantId = serverConf.Seed.DefaultTenantID + tenantId = server.Seed.DefaultTenantID } - defaultTok, err := serverConf.Auth.JWTManager.GenerateTenantToken(context.Background(), tenantId, tokenName, false, &expiresAt) + defaultTok, err := server.Auth.JWTManager.GenerateTenantToken(context.Background(), tenantId, tokenName, false, &expiresAt) if err != nil { return err diff --git a/cmd/hatchet-api/api/run.go b/cmd/hatchet-api/api/run.go index 3075423c8..5241ad606 100644 --- a/cmd/hatchet-api/api/run.go +++ b/cmd/hatchet-api/api/run.go @@ -9,14 +9,14 @@ import ( func Start(cf *loader.ConfigLoader, interruptCh <-chan interface{}, version string) error { // init the repository - configCleanup, sc, err := cf.LoadServerConfig(version) + configCleanup, server, err := cf.CreateServerFromConfig(version) if err != nil { return fmt.Errorf("error loading server config: %w", err) } var teardown []func() error - runner := run.NewAPIServer(sc) + runner := run.NewAPIServer(server) if err != nil { return err @@ -30,11 +30,11 @@ func Start(cf *loader.ConfigLoader, interruptCh <-chan interface{}, version stri teardown = append(teardown, apiCleanup) teardown = append(teardown, configCleanup) - sc.Logger.Debug().Msgf("api started successfully") + server.Logger.Debug().Msgf("api started successfully") <-interruptCh - sc.Logger.Debug().Msgf("api is shutting down...") + server.Logger.Debug().Msgf("api is shutting down...") for _, teardown := range teardown { if err := teardown(); err != nil { @@ -42,7 +42,7 @@ func Start(cf *loader.ConfigLoader, interruptCh <-chan interface{}, version stri } } - sc.Logger.Debug().Msgf("api successfully shut down") + server.Logger.Debug().Msgf("api successfully shut down") return nil } diff --git a/cmd/hatchet-engine/engine/run.go b/cmd/hatchet-engine/engine/run.go index e39f8c6c5..2aa1d07af 100644 --- a/cmd/hatchet-engine/engine/run.go +++ b/cmd/hatchet-engine/engine/run.go @@ -60,14 +60,14 @@ func init() { } func Run(ctx context.Context, cf *loader.ConfigLoader, version string) error { - serverCleanup, sc, err := cf.LoadServerConfig(version) + serverCleanup, server, err := cf.CreateServerFromConfig(version) if err != nil { return fmt.Errorf("could not load server config: %w", err) } - var l = sc.Logger + var l = server.Logger - teardown, err := RunWithConfig(ctx, sc) + teardown, err := RunWithConfig(ctx, server) if err != nil { return fmt.Errorf("could not run with config: %w", err) @@ -82,11 +82,11 @@ func Run(ctx context.Context, cf *loader.ConfigLoader, version string) error { teardown = append(teardown, Teardown{ Name: "database", Fn: func() error { - return sc.Disconnect() + return server.Disconnect() }, }) - time.Sleep(sc.Runtime.ShutdownWait) + time.Sleep(server.Runtime.ShutdownWait) l.Debug().Msgf("interrupt received, shutting down") @@ -360,6 +360,7 @@ func runV0Config(ctx context.Context, sc *server.ServerConfig) ([]Teardown, erro ), ingestor.WithMessageQueue(sc.MessageQueue), ingestor.WithEntitlementsRepository(sc.EntitlementRepository), + ingestor.WithStepRunRepository(sc.EngineRepository.StepRun()), ) if err != nil { @@ -720,6 +721,7 @@ func runV1Config(ctx context.Context, sc *server.ServerConfig) ([]Teardown, erro ), ingestor.WithMessageQueue(sc.MessageQueue), ingestor.WithEntitlementsRepository(sc.EntitlementRepository), + ingestor.WithStepRunRepository(sc.EngineRepository.StepRun()), ) if err != nil { diff --git a/cmd/hatchet-lite/main.go b/cmd/hatchet-lite/main.go index ed42cfa14..7380e9f79 100644 --- a/cmd/hatchet-lite/main.go +++ b/cmd/hatchet-lite/main.go @@ -91,7 +91,7 @@ func start(cf *loader.ConfigLoader, interruptCh <-chan interface{}, version stri return fmt.Errorf("error parsing frontend URL: %w", err) } - _, sc, err := cf.LoadServerConfig(version) + _, sc, err := cf.CreateServerFromConfig(version) if err != nil { return fmt.Errorf("error loading server config: %w", err) diff --git a/examples/logging/main.go b/examples/logging/main.go new file mode 100644 index 000000000..59064100d --- /dev/null +++ b/examples/logging/main.go @@ -0,0 +1,127 @@ +package main + +import ( + "context" + "fmt" + "log" + + "github.com/joho/godotenv" + + "github.com/hatchet-dev/hatchet/pkg/client" + "github.com/hatchet-dev/hatchet/pkg/cmdutils" + "github.com/hatchet-dev/hatchet/pkg/worker" +) + +type userCreateEvent struct { + Username string `json:"username"` + UserID string `json:"user_id"` + Data map[string]string `json:"data"` +} + +type stepOneOutput struct { + Message string `json:"message"` +} + +func main() { + err := godotenv.Load() + if err != nil { + panic(err) + } + + events := make(chan string, 50) + interrupt := cmdutils.InterruptChan() + + cleanup, err := run(events) + if err != nil { + panic(err) + } + + <-interrupt + + if err := cleanup(); err != nil { + panic(fmt.Errorf("error cleaning up: %w", err)) + } +} + +func run(events chan<- string) (func() error, error) { + c, err := client.New() + + if err != nil { + return nil, fmt.Errorf("error creating client: %w", err) + } + + w, err := worker.NewWorker( + worker.WithClient( + c, + ), + ) + if err != nil { + return nil, fmt.Errorf("error creating worker: %w", err) + } + + err = w.RegisterWorkflow( + &worker.WorkflowJob{ + On: worker.Events("user:log:simple"), + Name: "simple", + Description: "This runs after an update to the user model.", + Concurrency: worker.Expression("input.user_id"), + Steps: []*worker.WorkflowStep{ + worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) { + input := &userCreateEvent{} + + err = ctx.WorkflowInput(input) + + if err != nil { + return nil, err + } + + log.Printf("step-one") + events <- "step-one" + + for i := 0; i < 1000; i++ { + ctx.Log(fmt.Sprintf("step-one: %d", i)) + } + + return &stepOneOutput{ + Message: "Username is: " + input.Username, + }, nil + }, + ).SetName("step-one"), + }, + }, + ) + if err != nil { + return nil, fmt.Errorf("error registering workflow: %w", err) + } + + go func() { + testEvent := userCreateEvent{ + Username: "echo-test", + UserID: "1234", + Data: map[string]string{ + "test": "test", + }, + } + + log.Printf("pushing event user:create:simple") + // push an event + err := c.Event().Push( + context.Background(), + "user:log:simple", + testEvent, + client.WithEventMetadata(map[string]string{ + "hello": "world", + }), + ) + if err != nil { + panic(fmt.Errorf("error pushing event: %w", err)) + } + }() + + cleanup, err := w.Start() + if err != nil { + panic(err) + } + + return cleanup, nil +} diff --git a/internal/services/ingestor/ingestor.go b/internal/services/ingestor/ingestor.go index 71e0b1580..d888b65a2 100644 --- a/internal/services/ingestor/ingestor.go +++ b/internal/services/ingestor/ingestor.go @@ -4,6 +4,8 @@ import ( "context" "fmt" + lru "github.com/hashicorp/golang-lru/v2" + "github.com/hatchet-dev/hatchet/internal/datautils" "github.com/hatchet-dev/hatchet/internal/msgqueue" "github.com/hatchet-dev/hatchet/internal/services/ingestor/contracts" @@ -30,6 +32,7 @@ type IngestorOpts struct { streamEventRepository repository.StreamEventsEngineRepository logRepository repository.LogsEngineRepository entitlementsRepository repository.EntitlementsRepository + stepRunRepository repository.StepRunEngineRepository mq msgqueue.MessageQueue } @@ -63,6 +66,12 @@ func WithMessageQueue(mq msgqueue.MessageQueue) IngestorOptFunc { } } +func WithStepRunRepository(r repository.StepRunEngineRepository) IngestorOptFunc { + return func(opts *IngestorOpts) { + opts.stepRunRepository = r + } +} + func defaultIngestorOpts() *IngestorOpts { return &IngestorOpts{} } @@ -70,10 +79,12 @@ func defaultIngestorOpts() *IngestorOpts { type IngestorImpl struct { contracts.UnimplementedEventsServiceServer - eventRepository repository.EventEngineRepository - logRepository repository.LogsEngineRepository - streamEventRepository repository.StreamEventsEngineRepository - entitlementsRepository repository.EntitlementsRepository + eventRepository repository.EventEngineRepository + logRepository repository.LogsEngineRepository + streamEventRepository repository.StreamEventsEngineRepository + entitlementsRepository repository.EntitlementsRepository + stepRunRepository repository.StepRunEngineRepository + steprunTenantLookupCache *lru.Cache[string, string] mq msgqueue.MessageQueue v validator.Validator @@ -102,10 +113,22 @@ func NewIngestor(fs ...IngestorOptFunc) (Ingestor, error) { return nil, fmt.Errorf("task queue is required. use WithMessageQueue") } + if opts.stepRunRepository == nil { + return nil, fmt.Errorf("step run repository is required. use WithStepRunRepository") + } + // estimate of 1000 * 2 * UUID string size (roughly 104kb max) + stepRunCache, err := lru.New[string, string](1000) + + if err != nil { + return nil, fmt.Errorf("could not create step run cache: %w", err) + } + return &IngestorImpl{ - eventRepository: opts.eventRepository, - streamEventRepository: opts.streamEventRepository, - entitlementsRepository: opts.entitlementsRepository, + eventRepository: opts.eventRepository, + streamEventRepository: opts.streamEventRepository, + entitlementsRepository: opts.entitlementsRepository, + stepRunRepository: opts.stepRunRepository, + steprunTenantLookupCache: stepRunCache, logRepository: opts.logRepository, mq: opts.mq, diff --git a/internal/services/ingestor/server.go b/internal/services/ingestor/server.go index a10061e08..d416bb321 100644 --- a/internal/services/ingestor/server.go +++ b/internal/services/ingestor/server.go @@ -228,6 +228,20 @@ func (i *IngestorImpl) PutLog(ctx context.Context, req *contracts.PutLogRequest) return nil, status.Errorf(codes.InvalidArgument, "Invalid request: %s", apiErrors.String()) } + // Make sure we are writing to a step run owned by this tenant + if t, ok := i.steprunTenantLookupCache.Get(opts.StepRunId); ok { + if t != tenantId { + return nil, status.Errorf(codes.PermissionDenied, "Permission denied: step run does not belong to tenant") + } + // cache hit + } else { + if _, err := i.stepRunRepository.GetStepRunForEngine(ctx, tenantId, opts.StepRunId); err != nil { + return nil, err + } + + i.steprunTenantLookupCache.Add(opts.StepRunId, tenantId) + } + _, err := i.logRepository.PutLog(ctx, tenantId, opts) if err != nil { diff --git a/internal/testutils/env.go b/internal/testutils/env.go index c1c7bb5e4..01b2b8525 100644 --- a/internal/testutils/env.go +++ b/internal/testutils/env.go @@ -58,7 +58,7 @@ func Prepare(t *testing.T) { // read in the local config configLoader := loader.NewConfigLoader(path.Join(testPath, baseDir, "generated")) - cleanup, serverConf, err := configLoader.LoadServerConfig("", func(scf *server.ServerConfigFile) { + cleanup, server, err := configLoader.CreateServerFromConfig("", func(scf *server.ServerConfigFile) { // disable security checks since we're not running the server scf.SecurityCheck.Enabled = false }) @@ -67,10 +67,10 @@ func Prepare(t *testing.T) { } // check if tenant exists - _, err = serverConf.APIRepository.Tenant().GetTenantByID(tenantId) + _, err = server.APIRepository.Tenant().GetTenantByID(tenantId) if err != nil { if errors.Is(err, db.ErrNotFound) { - _, err = serverConf.APIRepository.Tenant().CreateTenant(&repository.CreateTenantOpts{ + _, err = server.APIRepository.Tenant().CreateTenant(&repository.CreateTenantOpts{ ID: &tenantId, Name: "test-tenant", Slug: "test-tenant", @@ -83,14 +83,14 @@ func Prepare(t *testing.T) { } } - defaultTok, err := serverConf.Auth.JWTManager.GenerateTenantToken(context.Background(), tenantId, "default", false, nil) + defaultTok, err := server.Auth.JWTManager.GenerateTenantToken(context.Background(), tenantId, "default", false, nil) if err != nil { t.Fatalf("could not generate default token: %v", err) } _ = os.Setenv("HATCHET_CLIENT_TOKEN", defaultTok.Token) - if err := serverConf.Disconnect(); err != nil { + if err := server.Disconnect(); err != nil { t.Fatalf("could not disconnect from server: %v", err) } diff --git a/internal/testutils/with_database.go b/internal/testutils/with_database.go index d0c6b7610..c4218517c 100644 --- a/internal/testutils/with_database.go +++ b/internal/testutils/with_database.go @@ -7,13 +7,13 @@ import ( "github.com/hatchet-dev/hatchet/pkg/config/loader" ) -func RunTestWithDatabase(t *testing.T, test func(config *database.Config) error) { +func RunTestWithDatabase(t *testing.T, test func(config *database.Layer) error) { t.Helper() Prepare(t) confLoader := &loader.ConfigLoader{} - conf, err := confLoader.LoadDatabaseConfig() + conf, err := confLoader.InitDataLayer() if err != nil { t.Fatalf("failed to load database config: %v\n", err) } diff --git a/pkg/auth/cookie/sessionstore_test.go b/pkg/auth/cookie/sessionstore_test.go index 4b78a83cf..7f7c8733b 100644 --- a/pkg/auth/cookie/sessionstore_test.go +++ b/pkg/auth/cookie/sessionstore_test.go @@ -18,7 +18,7 @@ import ( func TestSessionStoreSave(t *testing.T) { time.Sleep(10 * time.Second) // TODO temp hack for tenant non-upsert issue - testutils.RunTestWithDatabase(t, func(conf *database.Config) error { + testutils.RunTestWithDatabase(t, func(conf *database.Layer) error { const cookieName = "hatchet" ss := newSessionStore(t, conf, cookieName) @@ -36,7 +36,7 @@ func TestSessionStoreSave(t *testing.T) { } func TestSessionStoreGet(t *testing.T) { - testutils.RunTestWithDatabase(t, func(conf *database.Config) error { + testutils.RunTestWithDatabase(t, func(conf *database.Layer) error { const cookieName = "hatchet" ss := newSessionStore(t, conf, cookieName) @@ -64,7 +64,7 @@ func TestSessionStoreGet(t *testing.T) { }) } -func newSessionStore(t *testing.T, conf *database.Config, cookieName string) *cookie.UserSessionStore { +func newSessionStore(t *testing.T, conf *database.Layer, cookieName string) *cookie.UserSessionStore { hashKey, err := random.Generate(16) if err != nil { diff --git a/pkg/auth/token/token_test.go b/pkg/auth/token/token_test.go index 9011e3cf5..b597cf2e4 100644 --- a/pkg/auth/token/token_test.go +++ b/pkg/auth/token/token_test.go @@ -20,7 +20,7 @@ import ( ) func TestCreateTenantToken(t *testing.T) { // make sure no cache is used for tests - testutils.RunTestWithDatabase(t, func(conf *database.Config) error { + testutils.RunTestWithDatabase(t, func(conf *database.Layer) error { jwtManager := getJWTManager(t, conf) tenantId := uuid.New().String() @@ -61,7 +61,7 @@ func TestCreateTenantToken(t *testing.T) { // make sure no cache is used for tes func TestRevokeTenantToken(t *testing.T) { _ = os.Setenv("CACHE_DURATION", "0") - testutils.RunTestWithDatabase(t, func(conf *database.Config) error { + testutils.RunTestWithDatabase(t, func(conf *database.Layer) error { jwtManager := getJWTManager(t, conf) tenantId := uuid.New().String() @@ -121,7 +121,7 @@ func TestRevokeTenantToken(t *testing.T) { func TestRevokeTenantTokenCache(t *testing.T) { _ = os.Setenv("CACHE_DURATION", "60s") - testutils.RunTestWithDatabase(t, func(conf *database.Config) error { + testutils.RunTestWithDatabase(t, func(conf *database.Layer) error { jwtManager := getJWTManager(t, conf) tenantId := uuid.New().String() @@ -178,7 +178,7 @@ func TestRevokeTenantTokenCache(t *testing.T) { }) } -func getJWTManager(t *testing.T, conf *database.Config) token.JWTManager { +func getJWTManager(t *testing.T, conf *database.Layer) token.JWTManager { t.Helper() masterKeyBytes, privateJWTBytes, publicJWTBytes, err := encryption.GenerateLocalKeys() diff --git a/pkg/config/database/config.go b/pkg/config/database/config.go index 798da05f8..ea64a631e 100644 --- a/pkg/config/database/config.go +++ b/pkg/config/database/config.go @@ -45,7 +45,7 @@ type SeedConfigFile struct { IsDevelopment bool `mapstructure:"isDevelopment" json:"isDevelopment,omitempty" default:"false"` } -type Config struct { +type Layer struct { Disconnect func() error Pool *pgxpool.Pool diff --git a/pkg/config/loader/loader.go b/pkg/config/loader/loader.go index edd40b83e..6b89ff299 100644 --- a/pkg/config/loader/loader.go +++ b/pkg/config/loader/loader.go @@ -36,6 +36,7 @@ import ( "github.com/hatchet-dev/hatchet/pkg/errors" "github.com/hatchet-dev/hatchet/pkg/errors/sentry" "github.com/hatchet-dev/hatchet/pkg/logger" + "github.com/hatchet-dev/hatchet/pkg/repository" "github.com/hatchet-dev/hatchet/pkg/repository/cache" "github.com/hatchet-dev/hatchet/pkg/repository/metered" "github.com/hatchet-dev/hatchet/pkg/repository/prisma" @@ -64,16 +65,22 @@ func LoadServerConfigFile(files ...[]byte) (*server.ServerConfigFile, error) { return configFile, err } +type RepositoryOverrides struct { + LogsEngineRepository repository.LogsEngineRepository + LogsAPIRepository repository.LogsAPIRepository +} + type ConfigLoader struct { - directory string + directory string + RepositoryOverrides RepositoryOverrides } func NewConfigLoader(directory string) *ConfigLoader { - return &ConfigLoader{directory} + return &ConfigLoader{directory: directory} } -// LoadDatabaseConfig loads the database configuration -func (c *ConfigLoader) LoadDatabaseConfig() (res *database.Config, err error) { +// InitDataLayer initializes the database layer from the configuration +func (c *ConfigLoader) InitDataLayer() (res *database.Layer, err error) { sharedFilePath := filepath.Join(c.directory, "database.yaml") configFileBytes, err := loaderutils.GetConfigBytes(sharedFilePath) @@ -93,40 +100,6 @@ func (c *ConfigLoader) LoadDatabaseConfig() (res *database.Config, err error) { return nil, err } - return GetDatabaseConfigFromConfigFile(cf, &scf.Runtime) -} - -type ServerConfigFileOverride func(*server.ServerConfigFile) - -// LoadServerConfig loads the server configuration -func (c *ConfigLoader) LoadServerConfig(version string, overrides ...ServerConfigFileOverride) (cleanup func() error, res *server.ServerConfig, err error) { - log.Printf("Loading server config from %s", c.directory) - sharedFilePath := filepath.Join(c.directory, "server.yaml") - log.Printf("Shared file path: %s", sharedFilePath) - - configFileBytes, err := loaderutils.GetConfigBytes(sharedFilePath) - if err != nil { - return nil, nil, err - } - - dc, err := c.LoadDatabaseConfig() - if err != nil { - return nil, nil, err - } - - cf, err := LoadServerConfigFile(configFileBytes...) - if err != nil { - return nil, nil, err - } - - for _, override := range overrides { - override(cf) - } - - return GetServerConfigFromConfigfile(dc, cf, version) -} - -func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile, runtime *server.ConfigFileRuntime) (res *database.Config, err error) { l := logger.NewStdErr(&cf.Logger, "database") databaseUrl := os.Getenv("DATABASE_URL") @@ -146,9 +119,9 @@ func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile, runtime *server.Co _ = os.Setenv("DATABASE_URL", databaseUrl) } - c := db.NewClient() + client := db.NewClient() - if err := c.Prisma.Connect(); err != nil { + if err := client.Prisma.Connect(); err != nil { return nil, err } @@ -210,23 +183,35 @@ func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile, runtime *server.Co ch := cache.New(cf.CacheDuration) - entitlementRepo := prisma.NewEntitlementRepository(pool, runtime, prisma.WithLogger(&l), prisma.WithCache(ch)) + entitlementRepo := prisma.NewEntitlementRepository(pool, &scf.Runtime, prisma.WithLogger(&l), prisma.WithCache(ch)) meter := metered.NewMetered(entitlementRepo, &l) - cleanupEngine, engineRepo, err := prisma.NewEngineRepository(pool, essentialPool, runtime, prisma.WithLogger(&l), prisma.WithCache(ch), prisma.WithMetered(meter)) + var opts []prisma.PrismaRepositoryOpt + + opts = append(opts, prisma.WithLogger(&l), prisma.WithCache(ch), prisma.WithMetered(meter)) + + if c.RepositoryOverrides.LogsEngineRepository != nil { + opts = append(opts, prisma.WithLogsEngineRepository(c.RepositoryOverrides.LogsEngineRepository)) + } + + cleanupEngine, engineRepo, err := prisma.NewEngineRepository(pool, essentialPool, &scf.Runtime, opts...) if err != nil { return nil, fmt.Errorf("could not create engine repository: %w", err) } - apiRepo, cleanupApiRepo, err := prisma.NewAPIRepository(c, pool, runtime, prisma.WithLogger(&l), prisma.WithCache(ch), prisma.WithMetered(meter)) + if c.RepositoryOverrides.LogsAPIRepository != nil { + opts = append(opts, prisma.WithLogsAPIRepository(c.RepositoryOverrides.LogsAPIRepository)) + } + + apiRepo, cleanupApiRepo, err := prisma.NewAPIRepository(client, pool, &scf.Runtime, opts...) if err != nil { return nil, fmt.Errorf("could not create api repository: %w", err) } - return &database.Config{ + return &database.Layer{ Disconnect: func() error { if err := cleanupEngine(); err != nil { return err @@ -237,7 +222,7 @@ func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile, runtime *server.Co if err = cleanupApiRepo(); err != nil { return err } - return c.Prisma.Disconnect() + return client.Prisma.Disconnect() }, Pool: pool, EssentialPool: essentialPool, @@ -247,9 +232,41 @@ func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile, runtime *server.Co EntitlementRepository: entitlementRepo, Seed: cf.Seed, }, nil + } -func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigFile, version string) (cleanup func() error, res *server.ServerConfig, err error) { +type ServerConfigFileOverride func(*server.ServerConfigFile) + +// CreateServerFromConfig loads the server configuration and returns a server +func (c *ConfigLoader) CreateServerFromConfig(version string, overrides ...ServerConfigFileOverride) (cleanup func() error, res *server.ServerConfig, err error) { + + log.Printf("Loading server config from %s", c.directory) + sharedFilePath := filepath.Join(c.directory, "server.yaml") + log.Printf("Shared file path: %s", sharedFilePath) + + configFileBytes, err := loaderutils.GetConfigBytes(sharedFilePath) + if err != nil { + return nil, nil, err + } + + dc, err := c.InitDataLayer() + if err != nil { + return nil, nil, err + } + + cf, err := LoadServerConfigFile(configFileBytes...) + if err != nil { + return nil, nil, err + } + + for _, override := range overrides { + override(cf) + } + + return createControllerLayer(dc, cf, version) +} + +func createControllerLayer(dc *database.Layer, cf *server.ServerConfigFile, version string) (cleanup func() error, res *server.ServerConfig, err error) { l := logger.NewStdErr(&cf.Logger, "server") queueLogger := logger.NewStdErr(&cf.AdditionalLoggers.Queue, "queue") @@ -303,6 +320,7 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF ingestor.WithLogRepository(dc.EngineRepository.Log()), ingestor.WithMessageQueue(mq), ingestor.WithEntitlementsRepository(dc.EntitlementRepository), + ingestor.WithStepRunRepository(dc.EngineRepository.StepRun()), ) if err != nil { @@ -496,7 +514,7 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF Runtime: cf.Runtime, Auth: auth, Encryption: encryptionSvc, - Config: dc, + Layer: dc, MessageQueue: mq, Services: services, Logger: &l, diff --git a/pkg/config/server/server.go b/pkg/config/server/server.go index d3bf2bc14..269e4773a 100644 --- a/pkg/config/server/server.go +++ b/pkg/config/server/server.go @@ -409,7 +409,7 @@ type FePosthogConfig struct { } type ServerConfig struct { - *database.Config + *database.Layer Auth AuthConfig diff --git a/pkg/repository/logs.go b/pkg/repository/logs.go index 08b11b362..7ce734f2d 100644 --- a/pkg/repository/logs.go +++ b/pkg/repository/logs.go @@ -4,7 +4,10 @@ import ( "context" "time" + "github.com/rs/zerolog" + "github.com/hatchet-dev/hatchet/pkg/repository/prisma/dbsqlc" + "github.com/hatchet-dev/hatchet/pkg/validator" ) type CreateLogLineOpts struct { @@ -55,9 +58,11 @@ type ListLogsResult struct { type LogsAPIRepository interface { // ListLogLines returns a list of log lines for a given step run. ListLogLines(tenantId string, opts *ListLogsOpts) (*ListLogsResult, error) + WithAdditionalConfig(validator.Validator, *zerolog.Logger) LogsAPIRepository } type LogsEngineRepository interface { // PutLog creates a new log line. PutLog(ctx context.Context, tenantId string, opts *CreateLogLineOpts) (*dbsqlc.LogLine, error) + WithAdditionalConfig(validator.Validator, *zerolog.Logger) LogsEngineRepository } diff --git a/pkg/repository/prisma/log.go b/pkg/repository/prisma/log.go index c7983ecf4..9f92bea2b 100644 --- a/pkg/repository/prisma/log.go +++ b/pkg/repository/prisma/log.go @@ -139,6 +139,17 @@ type logEngineRepository struct { l *zerolog.Logger } +// Used as hook a hook to allow for additional configuration to be passed to the repository if it is instantiated a different way +func (le *logAPIRepository) WithAdditionalConfig(v validator.Validator, l *zerolog.Logger) repository.LogsAPIRepository { + panic("not implemented in this repo") + +} + +// Used as hook a hook to allow for additional configuration to be passed to the repository if it is instantiated a different way +func (le *logEngineRepository) WithAdditionalConfig(v validator.Validator, l *zerolog.Logger) repository.LogsEngineRepository { + panic("not implemented in this repo") +} + func NewLogEngineRepository(pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.LogsEngineRepository { queries := dbsqlc.New() diff --git a/pkg/repository/prisma/repository.go b/pkg/repository/prisma/repository.go index 23c60b086..702fb70e2 100644 --- a/pkg/repository/prisma/repository.go +++ b/pkg/repository/prisma/repository.go @@ -40,10 +40,12 @@ type apiRepository struct { type PrismaRepositoryOpt func(*PrismaRepositoryOpts) type PrismaRepositoryOpts struct { - v validator.Validator - l *zerolog.Logger - cache cache.Cacheable - metered *metered.Metered + v validator.Validator + l *zerolog.Logger + cache cache.Cacheable + metered *metered.Metered + logsEngineRepository repository.LogsEngineRepository + logsAPIRepository repository.LogsAPIRepository } func defaultPrismaRepositoryOpts() *PrismaRepositoryOpts { @@ -76,6 +78,18 @@ func WithMetered(metered *metered.Metered) PrismaRepositoryOpt { } } +func WithLogsEngineRepository(newLogsEngine repository.LogsEngineRepository) PrismaRepositoryOpt { + return func(opts *PrismaRepositoryOpts) { + opts.logsEngineRepository = newLogsEngine + } +} + +func WithLogsAPIRepository(newLogsAPI repository.LogsAPIRepository) PrismaRepositoryOpt { + return func(opts *PrismaRepositoryOpts) { + opts.logsAPIRepository = newLogsAPI + } +} + func NewAPIRepository(client *db.PrismaClient, pool *pgxpool.Pool, cf *server.ConfigFileRuntime, fs ...PrismaRepositoryOpt) (repository.APIRepository, func() error, error) { opts := defaultPrismaRepositoryOpts() @@ -95,11 +109,18 @@ func NewAPIRepository(client *db.PrismaClient, pool *pgxpool.Pool, cf *server.Co if err != nil { return nil, nil, err } + var logsAPIRepo repository.LogsAPIRepository + + if opts.logsAPIRepository == nil { + logsAPIRepo = NewLogAPIRepository(pool, opts.v, opts.l) + } else { + logsAPIRepo = opts.logsAPIRepository.WithAdditionalConfig(opts.v, opts.l) + } return &apiRepository{ apiToken: NewAPITokenRepository(client, opts.v, opts.cache), event: NewEventAPIRepository(client, pool, opts.v, opts.l), - log: NewLogAPIRepository(pool, opts.v, opts.l), + log: logsAPIRepo, tenant: NewTenantAPIRepository(pool, client, opts.v, opts.l, opts.cache), tenantAlerting: NewTenantAlertingAPIRepository(client, opts.v, opts.cache), tenantInvite: NewTenantInviteRepository(client, opts.v, opts.l), @@ -322,6 +343,13 @@ func NewEngineRepository(pool *pgxpool.Pool, essentialPool *pgxpool.Pool, cf *se if err != nil { return nil, nil, err } + var logRepo repository.LogsEngineRepository + + if opts.logsEngineRepository == nil { + logRepo = NewLogEngineRepository(pool, opts.v, opts.l) + } else { + logRepo = opts.logsEngineRepository.WithAdditionalConfig(opts.v, opts.l) + } return func() error { rlCache.Stop() @@ -344,7 +372,7 @@ func NewEngineRepository(pool *pgxpool.Pool, essentialPool *pgxpool.Pool, cf *se workflow: NewWorkflowEngineRepository(shared, opts.metered, opts.cache), workflowRun: NewWorkflowRunEngineRepository(shared, opts.metered, cf), streamEvent: NewStreamEventsEngineRepository(pool, opts.v, opts.l), - log: NewLogEngineRepository(pool, opts.v, opts.l), + log: logRepo, rateLimit: NewRateLimitEngineRepository(pool, opts.v, opts.l), webhookWorker: NewWebhookWorkerEngineRepository(pool, opts.v, opts.l), scheduler: newSchedulerRepository(shared),