feat(engine): standalone tests and engine teardown (#172)

This commit is contained in:
Luca Steeb
2024-02-28 00:15:25 +07:00
committed by GitHub
parent 7709bcb175
commit ae4841031b
26 changed files with 882 additions and 523 deletions

View File

@@ -182,7 +182,73 @@ jobs:
run: | run: |
export HATCHET_CLIENT_TOKEN="$(go run ./cmd/hatchet-admin token create --config ./generated/ --tenant-id 707d0855-80ab-4e1f-a156-f1c4546cbf52)" export HATCHET_CLIENT_TOKEN="$(go run ./cmd/hatchet-admin token create --config ./generated/ --tenant-id 707d0855-80ab-4e1f-a156-f1c4546cbf52)"
go test -tags e2e ./... -race -p 1 -v -failfast go test -tags e2e ./... -p 1 -v -failfast
- name: Teardown
run: docker compose down
load:
runs-on: ubuntu-latest
timeout-minutes: 30
env:
DATABASE_URL: postgresql://hatchet:hatchet@127.0.0.1:5431/hatchet
steps:
- uses: actions/checkout@v4
- name: Install Task
uses: arduino/setup-task@v1
- name: Install Protoc
uses: arduino/setup-protoc@v2
with:
version: "25.1"
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: "1.21"
- name: Compose
run: docker compose up -d
- name: Go deps
run: go mod download
- name: Generate
run: |
go run github.com/steebchen/prisma-client-go db push
task generate-certs
task generate-local-encryption-keys
- name: Prepare
run: |
cat > .env <<EOF
DATABASE_URL='postgresql://hatchet:hatchet@127.0.0.1:5431/hatchet'
SERVER_TLS_CERT_FILE=./hack/dev/certs/cluster.pem
SERVER_TLS_KEY_FILE=./hack/dev/certs/cluster.key
SERVER_TLS_ROOT_CA_FILE=./hack/dev/certs/ca.cert
SERVER_PORT=8080
SERVER_URL=https://app.dev.hatchet-tools.com
SERVER_AUTH_COOKIE_SECRETS="something something"
SERVER_AUTH_COOKIE_DOMAIN=app.dev.hatchet-tools.com
SERVER_AUTH_COOKIE_INSECURE=false
SERVER_AUTH_SET_EMAIL_VERIFIED=true
EOF
- name: Setup
run: |
set -a
. .env
set +a
go run ./cmd/hatchet-admin quickstart --generated-config-dir ./generated/
- name: Test
run: |
export HATCHET_CLIENT_TOKEN="$(go run ./cmd/hatchet-admin token create --config ./generated/ --tenant-id 707d0855-80ab-4e1f-a156-f1c4546cbf52)"
go test -tags load ./... -p 1 -v -failfast
- name: Teardown - name: Teardown
run: docker compose down run: docker compose down

2
.gitignore vendored
View File

@@ -89,3 +89,5 @@ postgres-data
rabbitmq.conf rabbitmq.conf
*encryption-keys *encryption-keys
generated/
certs/

View File

@@ -93,6 +93,7 @@ tasks:
- go mod download - go mod download
- cd frontend/app/ && pnpm install - cd frontend/app/ && pnpm install
- cd frontend/docs/ && pnpm install - cd frontend/docs/ && pnpm install
- cd typescript-sdk/ && pnpm install
generate-api: generate-api:
cmds: cmds:
- task: generate-api-server - task: generate-api-server

View File

@@ -59,7 +59,12 @@ func runCreateAPIToken() error {
// read in the local config // read in the local config
configLoader := loader.NewConfigLoader(configDirectory) configLoader := loader.NewConfigLoader(configDirectory)
serverConf, err := configLoader.LoadServerConfig() cleanup, serverConf, err := configLoader.LoadServerConfig()
defer func() {
if err := cleanup(); err != nil {
panic(fmt.Errorf("could not cleanup server config: %v", err))
}
}()
if err != nil { if err != nil {
return err return err

View File

@@ -59,16 +59,23 @@ func main() {
} }
func startServerOrDie(cf *loader.ConfigLoader, interruptCh <-chan interface{}) { func startServerOrDie(cf *loader.ConfigLoader, interruptCh <-chan interface{}) {
ctx, cancel := cmdutils.InterruptContextFromChan(interruptCh)
defer cancel()
// init the repository // init the repository
sc, err := cf.LoadServerConfig() cleanup, sc, err := cf.LoadServerConfig()
defer func() {
if err := cleanup(); err != nil {
panic(fmt.Errorf("could not cleanup server config: %v", err))
}
}()
if err != nil { if err != nil {
panic(err) panic(err)
} }
errCh := make(chan error) errCh := make(chan error)
ctx, cancel := cmdutils.InterruptContextFromChan(interruptCh)
defer cancel()
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
if sc.InternalClient != nil { if sc.InternalClient != nil {

View File

@@ -0,0 +1,259 @@
package engine
import (
"context"
"fmt"
"github.com/hatchet-dev/hatchet/internal/config/loader"
"github.com/hatchet-dev/hatchet/internal/services/admin"
"github.com/hatchet-dev/hatchet/internal/services/controllers/events"
"github.com/hatchet-dev/hatchet/internal/services/controllers/jobs"
"github.com/hatchet-dev/hatchet/internal/services/controllers/workflows"
"github.com/hatchet-dev/hatchet/internal/services/dispatcher"
"github.com/hatchet-dev/hatchet/internal/services/grpc"
"github.com/hatchet-dev/hatchet/internal/services/heartbeat"
"github.com/hatchet-dev/hatchet/internal/services/ingestor"
"github.com/hatchet-dev/hatchet/internal/services/ticker"
"github.com/hatchet-dev/hatchet/internal/telemetry"
)
// Teardown pairs a human-readable service name with the cleanup function
// returned by that service's Start call. Run registers one Teardown per
// started service and executes them in registration order on shutdown.
type Teardown struct {
// name identifies the service in shutdown log messages.
name string
// fn stops the service; a non-nil error is reported by Run.
fn func() error
}
// Run starts every service enabled in the server configuration, blocks until
// ctx is canceled, and then tears the started services down in registration
// order. It returns a non-nil error if configuration/startup fails or if any
// teardown step fails.
//
// NOTE(review): if startup fails partway through (a later service's Start
// returns an error), the services that already started — and serverCleanup —
// are not torn down before returning; confirm the caller exits the process
// immediately in that case.
func Run(ctx context.Context, cf *loader.ConfigLoader) error {
	serverCleanup, sc, err := cf.LoadServerConfig()
	if err != nil {
		return fmt.Errorf("could not load server config: %w", err)
	}

	var l = sc.Logger

	shutdown, err := telemetry.InitTracer(&telemetry.TracerOpts{
		ServiceName:  sc.OpenTelemetry.ServiceName,
		CollectorURL: sc.OpenTelemetry.CollectorURL,
	})
	if err != nil {
		return fmt.Errorf("could not initialize tracer: %w", err)
	}

	// Teardown callbacks run in registration order once ctx is canceled.
	var teardown []Teardown

	if sc.HasService("ticker") {
		t, err := ticker.New(
			ticker.WithTaskQueue(sc.TaskQueue),
			ticker.WithRepository(sc.Repository),
			ticker.WithLogger(sc.Logger),
		)
		if err != nil {
			return fmt.Errorf("could not create ticker: %w", err)
		}

		cleanup, err := t.Start()
		if err != nil {
			return fmt.Errorf("could not start ticker: %w", err)
		}

		teardown = append(teardown, Teardown{
			name: "ticker",
			fn:   cleanup,
		})
	}

	if sc.HasService("eventscontroller") {
		ec, err := events.New(
			events.WithTaskQueue(sc.TaskQueue),
			events.WithRepository(sc.Repository),
			events.WithLogger(sc.Logger),
		)
		if err != nil {
			return fmt.Errorf("could not create events controller: %w", err)
		}

		cleanup, err := ec.Start()
		if err != nil {
			return fmt.Errorf("could not start events controller: %w", err)
		}

		teardown = append(teardown, Teardown{
			name: "events controller",
			fn:   cleanup,
		})
	}

	if sc.HasService("jobscontroller") {
		jc, err := jobs.New(
			jobs.WithTaskQueue(sc.TaskQueue),
			jobs.WithRepository(sc.Repository),
			jobs.WithLogger(sc.Logger),
		)
		if err != nil {
			return fmt.Errorf("could not create jobs controller: %w", err)
		}

		cleanup, err := jc.Start()
		if err != nil {
			return fmt.Errorf("could not start jobs controller: %w", err)
		}

		teardown = append(teardown, Teardown{
			name: "jobs controller",
			fn:   cleanup,
		})
	}

	if sc.HasService("workflowscontroller") {
		wc, err := workflows.New(
			workflows.WithTaskQueue(sc.TaskQueue),
			workflows.WithRepository(sc.Repository),
			workflows.WithLogger(sc.Logger),
		)
		if err != nil {
			return fmt.Errorf("could not create workflows controller: %w", err)
		}

		cleanup, err := wc.Start()
		if err != nil {
			return fmt.Errorf("could not start workflows controller: %w", err)
		}

		teardown = append(teardown, Teardown{
			name: "workflows controller",
			fn:   cleanup,
		})
	}

	if sc.HasService("heartbeater") {
		h, err := heartbeat.New(
			heartbeat.WithTaskQueue(sc.TaskQueue),
			heartbeat.WithRepository(sc.Repository),
			heartbeat.WithLogger(sc.Logger),
		)
		if err != nil {
			return fmt.Errorf("could not create heartbeater: %w", err)
		}

		cleanup, err := h.Start()
		if err != nil {
			return fmt.Errorf("could not start heartbeater: %w", err)
		}

		teardown = append(teardown, Teardown{
			name: "heartbeater",
			fn:   cleanup,
		})
	}

	if sc.HasService("grpc") {
		// create the dispatcher
		d, err := dispatcher.New(
			dispatcher.WithTaskQueue(sc.TaskQueue),
			dispatcher.WithRepository(sc.Repository),
			dispatcher.WithLogger(sc.Logger),
		)
		if err != nil {
			return fmt.Errorf("could not create dispatcher: %w", err)
		}

		dispatcherCleanup, err := d.Start()
		if err != nil {
			return fmt.Errorf("could not start dispatcher: %w", err)
		}

		teardown = append(teardown, Teardown{
			name: "grpc dispatcher",
			fn:   dispatcherCleanup,
		})

		// create the event ingestor
		ei, err := ingestor.NewIngestor(
			ingestor.WithEventRepository(
				sc.Repository.Event(),
			),
			ingestor.WithTaskQueue(sc.TaskQueue),
		)
		if err != nil {
			return fmt.Errorf("could not create ingestor: %w", err)
		}

		adminSvc, err := admin.NewAdminService(
			admin.WithRepository(sc.Repository),
			admin.WithTaskQueue(sc.TaskQueue),
		)
		if err != nil {
			return fmt.Errorf("could not create admin service: %w", err)
		}

		grpcOpts := []grpc.ServerOpt{
			grpc.WithConfig(sc),
			grpc.WithIngestor(ei),
			grpc.WithDispatcher(d),
			grpc.WithAdmin(adminSvc),
			grpc.WithLogger(sc.Logger),
			grpc.WithTLSConfig(sc.TLSConfig),
			grpc.WithPort(sc.Runtime.GRPCPort),
			grpc.WithBindAddress(sc.Runtime.GRPCBindAddress),
		}

		if sc.Runtime.GRPCInsecure {
			grpcOpts = append(grpcOpts, grpc.WithInsecure())
		}

		// create the grpc server
		s, err := grpc.NewServer(
			grpcOpts...,
		)
		if err != nil {
			return fmt.Errorf("could not create grpc server: %w", err)
		}

		grpcServerCleanup, err := s.Start()
		if err != nil {
			return fmt.Errorf("could not start grpc server: %w", err)
		}

		teardown = append(teardown, Teardown{
			name: "grpc server",
			fn:   grpcServerCleanup,
		})
	}

	teardown = append(teardown, Teardown{
		name: "telemetry",
		fn: func() error {
			// BUGFIX: by the time teardown runs, ctx has already been
			// canceled (see <-ctx.Done() below), so passing it to shutdown
			// would abort the trace flush immediately. Use a fresh context.
			return shutdown(context.Background())
		},
	})

	teardown = append(teardown, Teardown{
		name: "server",
		fn:   serverCleanup,
	})

	teardown = append(teardown, Teardown{
		name: "database",
		fn:   sc.Disconnect,
	})

	l.Debug().Msgf("engine has started")

	<-ctx.Done()

	l.Debug().Msgf("interrupt received, shutting down")
	l.Debug().Msgf("waiting for all other services to gracefully exit...")

	// BUGFIX: run every teardown even when one fails. Returning on the first
	// failure leaked all later services (queue consumers, rabbitmq, DB pool).
	// The first failure is remembered and reported to the caller.
	var firstErr error

	for i, t := range teardown {
		l.Debug().Msgf("shutting down %s (%d/%d)", t.name, i+1, len(teardown))

		if err := t.fn(); err != nil {
			l.Error().Err(err).Msgf("could not teardown %s", t.name)

			if firstErr == nil {
				firstErr = fmt.Errorf("could not teardown %s: %w", t.name, err)
			}

			continue
		}

		l.Debug().Msgf("successfully shutdown %s (%d/%d)", t.name, i+1, len(teardown))
	}

	if firstErr != nil {
		return firstErr
	}

	l.Debug().Msgf("all services have successfully gracefully exited")
	l.Debug().Msgf("successfully shutdown")

	return nil
}

View File

@@ -2,22 +2,13 @@ package main
import ( import (
"fmt" "fmt"
"log"
"os" "os"
"sync"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/hatchet-dev/hatchet/cmd/hatchet-engine/engine"
"github.com/hatchet-dev/hatchet/internal/config/loader" "github.com/hatchet-dev/hatchet/internal/config/loader"
"github.com/hatchet-dev/hatchet/internal/services/admin"
"github.com/hatchet-dev/hatchet/internal/services/controllers/events"
"github.com/hatchet-dev/hatchet/internal/services/controllers/jobs"
"github.com/hatchet-dev/hatchet/internal/services/controllers/workflows"
"github.com/hatchet-dev/hatchet/internal/services/dispatcher"
"github.com/hatchet-dev/hatchet/internal/services/grpc"
"github.com/hatchet-dev/hatchet/internal/services/heartbeat"
"github.com/hatchet-dev/hatchet/internal/services/ingestor"
"github.com/hatchet-dev/hatchet/internal/services/ticker"
"github.com/hatchet-dev/hatchet/internal/telemetry"
"github.com/hatchet-dev/hatchet/pkg/cmdutils" "github.com/hatchet-dev/hatchet/pkg/cmdutils"
) )
@@ -35,9 +26,13 @@ var rootCmd = &cobra.Command{
} }
cf := loader.NewConfigLoader(configDirectory) cf := loader.NewConfigLoader(configDirectory)
interruptChan := cmdutils.InterruptChan() context, cancel := cmdutils.NewInterruptContext()
defer cancel()
startEngineOrDie(cf, interruptChan) if err := engine.Run(context, cf); err != nil {
log.Printf("engine failure: %s", err.Error())
os.Exit(1)
}
}, },
} }
@@ -64,240 +59,3 @@ func main() {
os.Exit(1) os.Exit(1)
} }
} }
// startEngineOrDie loads the server config, starts every enabled service in
// its own goroutine, and blocks until either a service reports an error
// (process exits 1) or interruptCh fires (graceful shutdown: cancel ctx, wait
// for tracked goroutines, disconnect the database).
//
// NOTE(review): errCh is unbuffered and the sends below (dispatcher/ingestor/
// admin/grpc-server creation errors) happen on THIS goroutine, before the
// receive loop at the bottom is reached — such a send blocks forever instead
// of exiting. Also, only the dispatcher goroutine is tracked by wg; the
// controller/ticker/heartbeat goroutines are not waited on at shutdown.
func startEngineOrDie(cf *loader.ConfigLoader, interruptCh <-chan interface{}) {
sc, err := cf.LoadServerConfig()
if err != nil {
panic(err)
}
// unbuffered: every send must be matched by a live receiver (see NOTE above)
errCh := make(chan error)
ctx, cancel := cmdutils.InterruptContextFromChan(interruptCh)
wg := sync.WaitGroup{}
shutdown, err := telemetry.InitTracer(&telemetry.TracerOpts{
ServiceName: sc.OpenTelemetry.ServiceName,
CollectorURL: sc.OpenTelemetry.CollectorURL,
})
if err != nil {
panic(fmt.Sprintf("could not initialize tracer: %s", err))
}
// NOTE(review): os.Exit(1) in the loop below skips this defer, so traces
// are not flushed on the error path.
defer shutdown(ctx) // nolint: errcheck
if sc.HasService("grpc") {
// only the dispatcher goroutine is counted; see wg.Wait() at the bottom
wg.Add(1)
// create the dispatcher
d, err := dispatcher.New(
dispatcher.WithTaskQueue(sc.TaskQueue),
dispatcher.WithRepository(sc.Repository),
dispatcher.WithLogger(sc.Logger),
)
if err != nil {
// send from the main goroutine before the receive loop exists — blocks
errCh <- err
return
}
go func() {
defer wg.Done()
err := d.Start(ctx)
if err != nil {
panic(err)
}
}()
// create the event ingestor
ei, err := ingestor.NewIngestor(
ingestor.WithEventRepository(
sc.Repository.Event(),
),
ingestor.WithTaskQueue(sc.TaskQueue),
)
if err != nil {
errCh <- err
return
}
adminSvc, err := admin.NewAdminService(
admin.WithRepository(sc.Repository),
admin.WithTaskQueue(sc.TaskQueue),
)
if err != nil {
errCh <- err
return
}
grpcOpts := []grpc.ServerOpt{
grpc.WithConfig(sc),
grpc.WithIngestor(ei),
grpc.WithDispatcher(d),
grpc.WithAdmin(adminSvc),
grpc.WithLogger(sc.Logger),
grpc.WithTLSConfig(sc.TLSConfig),
grpc.WithPort(sc.Runtime.GRPCPort),
grpc.WithBindAddress(sc.Runtime.GRPCBindAddress),
}
if sc.Runtime.GRPCInsecure {
grpcOpts = append(grpcOpts, grpc.WithInsecure())
}
// create the grpc server
s, err := grpc.NewServer(
grpcOpts...,
)
if err != nil {
errCh <- err
return
}
go func() {
err = s.Start(ctx)
if err != nil {
errCh <- err
return
}
}()
}
if sc.HasService("eventscontroller") {
// create separate events controller process
go func() {
ec, err := events.New(
events.WithTaskQueue(sc.TaskQueue),
events.WithRepository(sc.Repository),
events.WithLogger(sc.Logger),
)
if err != nil {
errCh <- err
return
}
err = ec.Start(ctx)
if err != nil {
errCh <- err
}
}()
}
if sc.HasService("jobscontroller") {
// create separate jobs controller process
go func() {
jc, err := jobs.New(
jobs.WithTaskQueue(sc.TaskQueue),
jobs.WithRepository(sc.Repository),
jobs.WithLogger(sc.Logger),
)
if err != nil {
errCh <- err
return
}
err = jc.Start(ctx)
if err != nil {
errCh <- err
}
}()
}
if sc.HasService("workflowscontroller") {
// create separate jobs controller process
go func() {
jc, err := workflows.New(
workflows.WithTaskQueue(sc.TaskQueue),
workflows.WithRepository(sc.Repository),
workflows.WithLogger(sc.Logger),
)
if err != nil {
errCh <- err
return
}
err = jc.Start(ctx)
if err != nil {
errCh <- err
}
}()
}
if sc.HasService("ticker") {
// create a ticker
go func() {
t, err := ticker.New(
ticker.WithTaskQueue(sc.TaskQueue),
ticker.WithRepository(sc.Repository),
ticker.WithLogger(sc.Logger),
)
if err != nil {
errCh <- err
return
}
err = t.Start(ctx)
if err != nil {
errCh <- err
}
}()
}
if sc.HasService("heartbeater") {
go func() {
h, err := heartbeat.New(
heartbeat.WithTaskQueue(sc.TaskQueue),
heartbeat.WithRepository(sc.Repository),
heartbeat.WithLogger(sc.Logger),
)
if err != nil {
errCh <- err
return
}
err = h.Start(ctx)
if err != nil {
errCh <- err
}
}()
}
// block until a service error (exit 1, defers skipped) or an interrupt
Loop:
for {
select {
case err := <-errCh:
fmt.Fprintf(os.Stderr, "%s", err)
// exit with non-zero exit code
os.Exit(1) //nolint:gocritic
case <-interruptCh:
break Loop
}
}
// graceful path: cancel service contexts, wait for tracked goroutines,
// then disconnect the database
cancel()
wg.Wait()
err = sc.Disconnect()
if err != nil {
panic(err)
}
}

View File

@@ -1,8 +1,11 @@
//go:build e2e //go:build load
package main package main
import ( import (
"context"
"log"
"sync"
"testing" "testing"
"time" "time"
@@ -44,23 +47,48 @@ func TestLoadCLI(t *testing.T) {
concurrency: 0, concurrency: 0,
}, },
}} }}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
setup := sync.WaitGroup{}
go func() {
setup.Add(1)
log.Printf("setup start")
testutils.SetupEngine(ctx, t)
setup.Done()
log.Printf("setup end")
}()
// TODO instead of waiting, figure out when the engine setup is complete
time.Sleep(10 * time.Second)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
defer func() {
time.Sleep(1 * time.Second)
goleak.VerifyNone(
t,
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
goleak.IgnoreTopFunction("google.golang.org/grpc/internal/grpcsync.(*CallbackSerializer).run"),
goleak.IgnoreTopFunction("internal/poll.runtime_pollWait"),
goleak.IgnoreTopFunction("google.golang.org/grpc/internal/transport.(*controlBuffer).get"),
)
}()
if err := do(tt.args.duration, tt.args.eventsPerSecond, tt.args.delay, tt.args.wait, tt.args.concurrency); (err != nil) != tt.wantErr { if err := do(tt.args.duration, tt.args.eventsPerSecond, tt.args.delay, tt.args.wait, tt.args.concurrency); (err != nil) != tt.wantErr {
t.Errorf("do() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("do() error = %v, wantErr %v", err, tt.wantErr)
} }
}) })
} }
cancel()
log.Printf("test complete")
setup.Wait()
log.Printf("cleanup complete")
goleak.VerifyNone(
t,
// worker
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
goleak.IgnoreTopFunction("google.golang.org/grpc/internal/grpcsync.(*CallbackSerializer).run"),
goleak.IgnoreTopFunction("internal/poll.runtime_pollWait"),
goleak.IgnoreTopFunction("google.golang.org/grpc/internal/transport.(*controlBuffer).get"),
// all engine related packages
goleak.IgnoreTopFunction("github.com/jackc/pgx/v5/pgxpool.(*Pool).backgroundHealthCheck"),
goleak.IgnoreTopFunction("github.com/rabbitmq/amqp091-go.(*Connection).heartbeater"),
goleak.IgnoreTopFunction("github.com/rabbitmq/amqp091-go.(*consumers).buffer"),
goleak.IgnoreTopFunction("google.golang.org/grpc/internal/transport.(*http2Server).keepalive"),
)
} }

View File

@@ -14,7 +14,6 @@ import (
) )
func TestMiddleware(t *testing.T) { func TestMiddleware(t *testing.T) {
t.Skip()
testutils.Prepare(t) testutils.Prepare(t)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)

2
go.mod
View File

@@ -13,6 +13,7 @@ require (
github.com/gorilla/sessions v1.2.2 github.com/gorilla/sessions v1.2.2
github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/jackc/pgx-zerolog v0.0.0-20230315001418-f978528409eb
github.com/jackc/pgx/v5 v5.5.0 github.com/jackc/pgx/v5 v5.5.0
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
github.com/labstack/echo/v4 v4.11.3 github.com/labstack/echo/v4 v4.11.3
@@ -100,7 +101,6 @@ require (
github.com/hashicorp/errwrap v1.0.0 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect
github.com/invopop/jsonschema v0.12.0 github.com/invopop/jsonschema v0.12.0
github.com/jackc/pgx-zerolog v0.0.0-20230315001418-f978528409eb
github.com/magiconair/properties v1.8.7 // indirect github.com/magiconair/properties v1.8.7 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect

View File

@@ -5,6 +5,7 @@ package loader
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@@ -17,6 +18,7 @@ import (
"github.com/hatchet-dev/hatchet/internal/auth/cookie" "github.com/hatchet-dev/hatchet/internal/auth/cookie"
"github.com/hatchet-dev/hatchet/internal/auth/oauth" "github.com/hatchet-dev/hatchet/internal/auth/oauth"
"github.com/hatchet-dev/hatchet/internal/auth/token" "github.com/hatchet-dev/hatchet/internal/auth/token"
clientconfig "github.com/hatchet-dev/hatchet/internal/config/client"
"github.com/hatchet-dev/hatchet/internal/config/database" "github.com/hatchet-dev/hatchet/internal/config/database"
"github.com/hatchet-dev/hatchet/internal/config/loader/loaderutils" "github.com/hatchet-dev/hatchet/internal/config/loader/loaderutils"
"github.com/hatchet-dev/hatchet/internal/config/server" "github.com/hatchet-dev/hatchet/internal/config/server"
@@ -30,8 +32,6 @@ import (
"github.com/hatchet-dev/hatchet/internal/taskqueue/rabbitmq" "github.com/hatchet-dev/hatchet/internal/taskqueue/rabbitmq"
"github.com/hatchet-dev/hatchet/internal/validator" "github.com/hatchet-dev/hatchet/internal/validator"
"github.com/hatchet-dev/hatchet/pkg/client" "github.com/hatchet-dev/hatchet/pkg/client"
clientconfig "github.com/hatchet-dev/hatchet/internal/config/client"
) )
// LoadDatabaseConfigFile loads the database config file via viper // LoadDatabaseConfigFile loads the database config file via viper
@@ -81,24 +81,24 @@ func (c *ConfigLoader) LoadDatabaseConfig() (res *database.Config, err error) {
} }
// LoadServerConfig loads the server configuration // LoadServerConfig loads the server configuration
func (c *ConfigLoader) LoadServerConfig() (res *server.ServerConfig, err error) { func (c *ConfigLoader) LoadServerConfig() (cleanup func() error, res *server.ServerConfig, err error) {
log.Printf("Loading server config from %s", c.directory)
sharedFilePath := filepath.Join(c.directory, "server.yaml") sharedFilePath := filepath.Join(c.directory, "server.yaml")
configFileBytes, err := loaderutils.GetConfigBytes(sharedFilePath) log.Printf("Shared file path: %s", sharedFilePath)
configFileBytes, err := loaderutils.GetConfigBytes(sharedFilePath)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
dc, err := c.LoadDatabaseConfig() dc, err := c.LoadDatabaseConfig()
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
cf, err := LoadServerConfigFile(configFileBytes...) cf, err := LoadServerConfigFile(configFileBytes...)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
return GetServerConfigFromConfigfile(dc, cf) return GetServerConfigFromConfigfile(dc, cf)
@@ -119,16 +119,15 @@ func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile) (res *database.Con
os.Setenv("DATABASE_URL", databaseUrl) os.Setenv("DATABASE_URL", databaseUrl)
client := db.NewClient( c := db.NewClient(
// db.WithDatasourceURL(databaseUrl), // db.WithDatasourceURL(databaseUrl),
) )
if err := client.Prisma.Connect(); err != nil { if err := c.Prisma.Connect(); err != nil {
return nil, err return nil, err
} }
config, err := pgxpool.ParseConfig(databaseUrl) config, err := pgxpool.ParseConfig(databaseUrl)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -143,25 +142,24 @@ func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile) (res *database.Con
config.MaxConns = 20 config.MaxConns = 20
pool, err := pgxpool.NewWithConfig(context.Background(), config) pool, err := pgxpool.NewWithConfig(context.Background(), config)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not connect to database: %w", err) return nil, fmt.Errorf("could not connect to database: %w", err)
} }
return &database.Config{ return &database.Config{
Disconnect: client.Prisma.Disconnect, Disconnect: c.Prisma.Disconnect,
Repository: prisma.NewPrismaRepository(client, pool, prisma.WithLogger(&l)), Repository: prisma.NewPrismaRepository(c, pool, prisma.WithLogger(&l)),
Seed: cf.Seed, Seed: cf.Seed,
}, nil }, nil
} }
func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigFile) (res *server.ServerConfig, err error) { func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigFile) (cleanup func() error, res *server.ServerConfig, err error) {
l := logger.NewStdErr(&cf.Logger, "server") l := logger.NewStdErr(&cf.Logger, "server")
tls, err := loaderutils.LoadServerTLSConfig(&cf.TLS) tls, err := loaderutils.LoadServerTLSConfig(&cf.TLS)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not load TLS config: %w", err) return nil, nil, fmt.Errorf("could not load TLS config: %w", err)
} }
ss, err := cookie.NewUserSessionStore( ss, err := cookie.NewUserSessionStore(
@@ -173,11 +171,10 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("could not create session store: %w", err) return nil, nil, fmt.Errorf("could not create session store: %w", err)
} }
tq := rabbitmq.New( cleanup1, tq := rabbitmq.New(
context.Background(),
rabbitmq.WithURL(cf.TaskQueue.RabbitMQ.URL), rabbitmq.WithURL(cf.TaskQueue.RabbitMQ.URL),
rabbitmq.WithLogger(&l), rabbitmq.WithLogger(&l),
) )
@@ -188,7 +185,7 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("could not create ingestor: %w", err) return nil, nil, fmt.Errorf("could not create ingestor: %w", err)
} }
auth := server.AuthConfig{ auth := server.AuthConfig{
@@ -197,11 +194,11 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
if cf.Auth.Google.Enabled { if cf.Auth.Google.Enabled {
if cf.Auth.Google.ClientID == "" { if cf.Auth.Google.ClientID == "" {
return nil, fmt.Errorf("google client id is required") return nil, nil, fmt.Errorf("google client id is required")
} }
if cf.Auth.Google.ClientSecret == "" { if cf.Auth.Google.ClientSecret == "" {
return nil, fmt.Errorf("google client secret is required") return nil, nil, fmt.Errorf("google client secret is required")
} }
gClient := oauth.NewGoogleClient(&oauth.Config{ gClient := oauth.NewGoogleClient(&oauth.Config{
@@ -217,7 +214,7 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
encryptionSvc, err := loadEncryptionSvc(cf) encryptionSvc, err := loadEncryptionSvc(cf)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not load encryption service: %w", err) return nil, nil, fmt.Errorf("could not load encryption service: %w", err)
} }
// create a new JWT manager // create a new JWT manager
@@ -229,7 +226,7 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("could not create JWT manager: %w", err) return nil, nil, fmt.Errorf("could not create JWT manager: %w", err)
} }
vcsProviders := make(map[vcs.VCSRepositoryKind]vcs.VCSProvider) vcsProviders := make(map[vcs.VCSRepositoryKind]vcs.VCSProvider)
@@ -252,7 +249,7 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
) )
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
githubProvider := github.NewGithubVCSProvider(githubAppConf, dc.Repository, cf.Runtime.ServerURL, encryptionSvc) githubProvider := github.NewGithubVCSProvider(githubAppConf, dc.Repository, cf.Runtime.ServerURL, encryptionSvc)
@@ -267,20 +264,20 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
internalTenant, err := dc.Repository.Tenant().GetTenantBySlug("internal") internalTenant, err := dc.Repository.Tenant().GetTenantBySlug("internal")
if err != nil { if err != nil {
return nil, fmt.Errorf("could not get internal tenant: %w", err) return nil, nil, fmt.Errorf("could not get internal tenant: %w", err)
} }
tokenSuffix, err := encryption.GenerateRandomBytes(4) tokenSuffix, err := encryption.GenerateRandomBytes(4)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not generate token suffix: %w", err) return nil, nil, fmt.Errorf("could not generate token suffix: %w", err)
} }
// generate a token for the internal client // generate a token for the internal client
token, err := auth.JWTManager.GenerateTenantToken(internalTenant.ID, fmt.Sprintf("internal-%s", tokenSuffix)) token, err := auth.JWTManager.GenerateTenantToken(internalTenant.ID, fmt.Sprintf("internal-%s", tokenSuffix))
if err != nil { if err != nil {
return nil, fmt.Errorf("could not generate internal token: %w", err) return nil, nil, fmt.Errorf("could not generate internal token: %w", err)
} }
internalClient, err = client.NewFromConfigFile( internalClient, err = client.NewFromConfigFile(
@@ -291,11 +288,19 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("could not create internal client: %w", err) return nil, nil, fmt.Errorf("could not create internal client: %w", err)
} }
} }
return &server.ServerConfig{ cleanup = func() error {
log.Printf("cleaning up server config")
if err := cleanup1(); err != nil {
return fmt.Errorf("error cleaning up rabbitmq: %w", err)
}
return nil
}
return cleanup, &server.ServerConfig{
Runtime: cf.Runtime, Runtime: cf.Runtime,
Auth: auth, Auth: auth,
Encryption: encryptionSvc, Encryption: encryptionSvc,

View File

@@ -3,6 +3,7 @@ package events
import ( import (
"context" "context"
"fmt" "fmt"
"sync"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
@@ -94,25 +95,41 @@ func New(fs ...EventsControllerOpt) (*EventsControllerImpl, error) {
}, nil }, nil
} }
func (ec *EventsControllerImpl) Start(ctx context.Context) error { func (ec *EventsControllerImpl) Start() (func() error, error) {
taskChan, err := ec.tq.Subscribe(ctx, taskqueue.EVENT_PROCESSING_QUEUE) ctx, cancel := context.WithCancel(context.Background())
cleanupQueue, taskChan, err := ec.tq.Subscribe(taskqueue.EVENT_PROCESSING_QUEUE)
if err != nil { if err != nil {
return err cancel()
return nil, fmt.Errorf("could not subscribe to event processing queue: %w", err)
} }
// TODO: close when ctx is done wg := sync.WaitGroup{}
for task := range taskChan {
go func(task *taskqueue.Task) {
err = ec.handleTask(ctx, task)
if err != nil { go func() {
ec.l.Error().Err(err).Msgf("could not handle event task %s", task.ID) for task := range taskChan {
} wg.Add(1)
}(task) go func(task *taskqueue.Task) {
defer wg.Done()
err = ec.handleTask(ctx, task)
if err != nil {
ec.l.Error().Err(err).Msgf("could not handle event task %s", task.ID)
}
}(task)
}
}()
cleanup := func() error {
cancel()
if err := cleanupQueue(); err != nil {
return fmt.Errorf("could not cleanup event processing queue: %w", err)
}
return nil
} }
return nil return cleanup, nil
} }
func (ec *EventsControllerImpl) handleTask(ctx context.Context, task *taskqueue.Task) error { func (ec *EventsControllerImpl) handleTask(ctx context.Context, task *taskqueue.Task) error {

View File

@@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
"sync"
"time" "time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@@ -102,25 +103,40 @@ func New(fs ...JobsControllerOpt) (*JobsControllerImpl, error) {
}, nil }, nil
} }
func (jc *JobsControllerImpl) Start(ctx context.Context) error { func (jc *JobsControllerImpl) Start() (func() error, error) {
taskChan, err := jc.tq.Subscribe(ctx, taskqueue.JOB_PROCESSING_QUEUE) cleanupQueue, taskChan, err := jc.tq.Subscribe(taskqueue.JOB_PROCESSING_QUEUE)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("could not subscribe to job processing queue: %w", err)
} }
// TODO: close when ctx is done wg := sync.WaitGroup{}
for task := range taskChan {
go func(task *taskqueue.Task) {
err = jc.handleTask(ctx, task)
if err != nil { go func() {
jc.l.Error().Err(err).Msg("could not handle job task") for task := range taskChan {
} wg.Add(1)
}(task) go func(task *taskqueue.Task) {
defer wg.Done()
err = jc.handleTask(context.Background(), task)
if err != nil {
jc.l.Error().Err(err).Msg("could not handle job task")
}
}(task)
}
}()
cleanup := func() error {
if err := cleanupQueue(); err != nil {
return fmt.Errorf("could not cleanup job processing queue: %w", err)
}
wg.Wait()
return nil
} }
return nil return cleanup, nil
} }
func (ec *JobsControllerImpl) handleTask(ctx context.Context, task *taskqueue.Task) error { func (ec *JobsControllerImpl) handleTask(ctx context.Context, task *taskqueue.Task) error {

View File

@@ -3,6 +3,7 @@ package workflows
import ( import (
"context" "context"
"fmt" "fmt"
"sync"
"time" "time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@@ -95,27 +96,40 @@ func New(fs ...WorkflowsControllerOpt) (*WorkflowsControllerImpl, error) {
}, nil }, nil
} }
func (wc *WorkflowsControllerImpl) Start(ctx context.Context) error { func (wc *WorkflowsControllerImpl) Start() (func() error, error) {
wc.l.Debug().Msg("starting workflows controller") wc.l.Debug().Msg("starting workflows controller")
taskChan, err := wc.tq.Subscribe(ctx, taskqueue.WORKFLOW_PROCESSING_QUEUE) cleanupQueue, taskChan, err := wc.tq.Subscribe(taskqueue.WORKFLOW_PROCESSING_QUEUE)
if err != nil { if err != nil {
return err return nil, err
} }
// TODO: close when ctx is done wg := sync.WaitGroup{}
for task := range taskChan {
go func(task *taskqueue.Task) {
err = wc.handleTask(ctx, task)
if err != nil { go func() {
wc.l.Error().Err(err).Msg("could not handle job task") for task := range taskChan {
} wg.Add(1)
}(task) go func(task *taskqueue.Task) {
defer wg.Done()
err = wc.handleTask(context.Background(), task)
if err != nil {
wc.l.Error().Err(err).Msg("could not handle job task")
}
}(task)
}
}()
cleanup := func() error {
if err := cleanupQueue(); err != nil {
return fmt.Errorf("could not cleanup queue: %w", err)
}
wg.Wait()
return nil
} }
return nil return cleanup, nil
} }
func (wc *WorkflowsControllerImpl) handleTask(ctx context.Context, task *taskqueue.Task) error { func (wc *WorkflowsControllerImpl) handleTask(ctx context.Context, task *taskqueue.Task) error {

View File

@@ -22,7 +22,7 @@ import (
type Dispatcher interface { type Dispatcher interface {
contracts.DispatcherServer contracts.DispatcherServer
Start(ctx context.Context) error Start() (func() error, error)
} }
type DispatcherImpl struct { type DispatcherImpl struct {
@@ -122,21 +122,24 @@ func New(fs ...DispatcherOpt) (*DispatcherImpl, error) {
}, nil }, nil
} }
func (d *DispatcherImpl) Start(ctx context.Context) error { func (d *DispatcherImpl) Start() (func() error, error) {
// register the dispatcher by creating a new dispatcher in the database // register the dispatcher by creating a new dispatcher in the database
dispatcher, err := d.repo.Dispatcher().CreateNewDispatcher(&repository.CreateDispatcherOpts{ dispatcher, err := d.repo.Dispatcher().CreateNewDispatcher(&repository.CreateDispatcherOpts{
ID: d.dispatcherId, ID: d.dispatcherId,
}) })
if err != nil { if err != nil {
return err return nil, err
} }
ctx, cancel := context.WithCancel(context.Background())
// subscribe to a task queue with the dispatcher id // subscribe to a task queue with the dispatcher id
taskChan, err := d.tq.Subscribe(ctx, taskqueue.QueueTypeFromDispatcherID(dispatcher.ID)) cleanupQueue, taskChan, err := d.tq.Subscribe(taskqueue.QueueTypeFromDispatcherID(dispatcher.ID))
if err != nil { if err != nil {
return err cancel()
return nil, err
} }
_, err = d.s.NewJob( _, err = d.s.NewJob(
@@ -147,42 +150,64 @@ func (d *DispatcherImpl) Start(ctx context.Context) error {
) )
if err != nil { if err != nil {
return fmt.Errorf("could not schedule heartbeat update: %w", err) cancel()
return nil, fmt.Errorf("could not schedule heartbeat update: %w", err)
} }
d.s.Start() d.s.Start()
for { go func() {
select { for {
case <-ctx.Done(): select {
// drain the existing connections case <-ctx.Done():
d.l.Debug().Msg("draining existing connections") return
case task := <-taskChan:
go func(task *taskqueue.Task) {
err = d.handleTask(ctx, task)
d.workers.Range(func(key, value interface{}) bool { if err != nil {
w := value.(subscribedWorker) d.l.Error().Err(err).Msgf("could not handle dispatcher task %s", task.ID)
}
w.finished <- true }(task)
return true
})
err = d.repo.Dispatcher().Delete(dispatcher.ID)
if err == nil {
d.l.Debug().Msgf("deleted dispatcher %s", dispatcher.ID)
} }
return err
case task := <-taskChan:
go func(task *taskqueue.Task) {
err = d.handleTask(ctx, task)
if err != nil {
d.l.Error().Err(err).Msgf("could not handle dispatcher task %s", task.ID)
}
}(task)
} }
}()
cleanup := func() error {
d.l.Debug().Msgf("dispatcher is shutting down...")
cancel()
if err := cleanupQueue(); err != nil {
return fmt.Errorf("could not cleanup queue: %w", err)
}
// drain the existing connections
d.l.Debug().Msg("draining existing connections")
d.workers.Range(func(key, value interface{}) bool {
w := value.(subscribedWorker)
w.finished <- true
return true
})
err = d.repo.Dispatcher().Delete(dispatcher.ID)
if err != nil {
return fmt.Errorf("could not delete dispatcher: %w", err)
}
d.l.Debug().Msgf("deleted dispatcher %s", dispatcher.ID)
if err := d.s.Shutdown(); err != nil {
return fmt.Errorf("could not shutdown scheduler: %w", err)
}
d.l.Debug().Msgf("dispatcher has shutdown")
return nil
} }
return cleanup, nil
} }
func (d *DispatcherImpl) handleTask(ctx context.Context, task *taskqueue.Task) error { func (d *DispatcherImpl) handleTask(ctx context.Context, task *taskqueue.Task) error {
@@ -348,7 +373,7 @@ func (d *DispatcherImpl) runUpdateHeartbeat(ctx context.Context) func() {
}) })
if err != nil { if err != nil {
d.l.Err(err).Msg("could not update heartbeat") d.l.Err(err).Msg("dispatcher: could not update heartbeat")
} }
} }
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"strconv" "strconv"
"sync"
"time" "time"
"github.com/steebchen/prisma-client-go/runtime/types" "github.com/steebchen/prisma-client-go/runtime/types"
@@ -301,19 +302,27 @@ func (s *DispatcherImpl) SubscribeToWorkflowEvents(request *contracts.SubscribeT
defer cancel() defer cancel()
// subscribe to the task queue for the tenant // subscribe to the task queue for the tenant
taskChan, err := s.tq.Subscribe(ctx, q) cleanupQueue, taskChan, err := s.tq.Subscribe(q)
if err != nil { if err != nil {
return err return err
} }
wg := sync.WaitGroup{}
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
if err := cleanupQueue(); err != nil {
return fmt.Errorf("could not cleanup queue: %w", err)
}
// drain the existing connections // drain the existing connections
wg.Wait()
return nil return nil
case task := <-taskChan: case task := <-taskChan:
wg.Add(1)
go func(task *taskqueue.Task) { go func(task *taskqueue.Task) {
defer wg.Done()
e, err := s.tenantTaskToWorkflowEvent(task, tenant.ID, request.WorkflowRunId) e, err := s.tenantTaskToWorkflowEvent(task, tenant.ID, request.WorkflowRunId)
if err != nil { if err != nil {

View File

@@ -1,7 +1,6 @@
package grpc package grpc
import ( import (
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net" "net"
@@ -14,6 +13,7 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
"github.com/hatchet-dev/hatchet/internal/config/server" "github.com/hatchet-dev/hatchet/internal/config/server"
"github.com/hatchet-dev/hatchet/internal/logger" "github.com/hatchet-dev/hatchet/internal/logger"
@@ -24,8 +24,6 @@ import (
"github.com/hatchet-dev/hatchet/internal/services/grpc/middleware" "github.com/hatchet-dev/hatchet/internal/services/grpc/middleware"
"github.com/hatchet-dev/hatchet/internal/services/ingestor" "github.com/hatchet-dev/hatchet/internal/services/ingestor"
eventcontracts "github.com/hatchet-dev/hatchet/internal/services/ingestor/contracts" eventcontracts "github.com/hatchet-dev/hatchet/internal/services/ingestor/contracts"
"google.golang.org/grpc/status"
) )
type Server struct { type Server struct {
@@ -155,17 +153,17 @@ func NewServer(fs ...ServerOpt) (*Server, error) {
}, nil }, nil
} }
func (s *Server) Start(ctx context.Context) error { func (s *Server) Start() (func() error, error) {
return s.startGRPC(ctx) return s.startGRPC()
} }
func (s *Server) startGRPC(ctx context.Context) error { func (s *Server) startGRPC() (func() error, error) {
s.l.Debug().Msgf("starting grpc server on %s:%d", s.bindAddress, s.port) s.l.Debug().Msgf("starting grpc server on %s:%d", s.bindAddress, s.port)
lis, err := net.Listen("tcp", fmt.Sprintf("%s:%d", s.bindAddress, s.port)) lis, err := net.Listen("tcp", fmt.Sprintf("%s:%d", s.bindAddress, s.port))
if err != nil { if err != nil {
return fmt.Errorf("failed to listen: %w", err) return nil, fmt.Errorf("failed to listen: %w", err)
} }
serverOpts := []grpc.ServerOption{} serverOpts := []grpc.ServerOption{}
@@ -221,6 +219,16 @@ func (s *Server) startGRPC(ctx context.Context) error {
admincontracts.RegisterWorkflowServiceServer(grpcServer, s.admin) admincontracts.RegisterWorkflowServiceServer(grpcServer, s.admin)
} }
// Start listening go func() {
return grpcServer.Serve(lis) if err := grpcServer.Serve(lis); err != nil {
panic(fmt.Errorf("failed to serve: %w", err))
}
}()
cleanup := func() error {
grpcServer.GracefulStop()
return nil
}
return cleanup, nil
} }

View File

@@ -89,25 +89,30 @@ func New(fs ...HeartbeaterOpt) (*HeartbeaterImpl, error) {
}, nil }, nil
} }
func (t *HeartbeaterImpl) Start(ctx context.Context) error { func (t *HeartbeaterImpl) Start() (func() error, error) {
t.l.Debug().Msg("starting heartbeater") t.l.Debug().Msg("starting heartbeater")
_, err := t.s.NewJob( _, err := t.s.NewJob(
gocron.DurationJob(time.Second*5), gocron.DurationJob(time.Second*5),
gocron.NewTask( gocron.NewTask(
t.removeStaleTickers(ctx), t.removeStaleTickers(),
), ),
) )
if err != nil { if err != nil {
return fmt.Errorf("could not schedule ticker removal: %w", err) return nil, fmt.Errorf("could not schedule ticker removal: %w", err)
} }
t.s.Start() t.s.Start()
for range ctx.Done() { cleanup := func() error {
t.l.Debug().Msg("stopping heartbeater") t.l.Debug().Msg("stopping heartbeater")
if err := t.s.Shutdown(); err != nil {
return fmt.Errorf("could not shutdown scheduler: %w", err)
}
t.l.Debug().Msg("heartbeater has shutdown")
return nil
} }
return nil return cleanup, nil
} }

View File

@@ -11,7 +11,7 @@ import (
"github.com/hatchet-dev/hatchet/internal/taskqueue" "github.com/hatchet-dev/hatchet/internal/taskqueue"
) )
func (t *HeartbeaterImpl) removeStaleTickers(ctx context.Context) func() { func (t *HeartbeaterImpl) removeStaleTickers() func() {
return func() { return func() {
t.l.Debug().Msg("removing old tickers") t.l.Debug().Msg("removing old tickers")
@@ -40,7 +40,7 @@ func (t *HeartbeaterImpl) removeStaleTickers(ctx context.Context) func() {
// send a task to the job processing queue that the ticker is removed // send a task to the job processing queue that the ticker is removed
err = t.tq.AddTask( err = t.tq.AddTask(
ctx, context.Background(),
taskqueue.JOB_PROCESSING_QUEUE, taskqueue.JOB_PROCESSING_QUEUE,
tickerRemoved(ticker.ID), tickerRemoved(ticker.ID),
) )

View File

@@ -35,12 +35,8 @@ func (t *TickerImpl) handleScheduleJobRunTimeout(ctx context.Context, task *task
} }
// schedule the timeout // schedule the timeout
childCtx, cancel := context.WithDeadline(context.Background(), timeoutAt) // TODO: ??? make sure this doesn't have any side effects
childCtx, cancel := context.WithDeadline(ctx, timeoutAt)
go func() {
<-childCtx.Done()
t.runJobRunTimeout(metadata.TenantId, payload.JobRunId)
}()
// store the schedule in the step run map // store the schedule in the step run map
t.jobRuns.Store(payload.JobRunId, &timeoutCtx{ t.jobRuns.Store(payload.JobRunId, &timeoutCtx{
@@ -48,6 +44,16 @@ func (t *TickerImpl) handleScheduleJobRunTimeout(ctx context.Context, task *task
cancel: cancel, cancel: cancel,
}) })
go func() {
<-childCtx.Done()
t.runJobRunTimeout(metadata.TenantId, payload.JobRunId)
t.jobRuns.Range(func(key, value interface{}) bool {
v, _ := value.(*timeoutCtx)
v.cancel()
return true
})
}()
return nil return nil
} }

View File

@@ -115,7 +115,9 @@ func New(fs ...TickerOpt) (*TickerImpl, error) {
}, nil }, nil
} }
func (t *TickerImpl) Start(ctx context.Context) error { func (t *TickerImpl) Start() (func() error, error) {
ctx, cancel := context.WithCancel(context.Background())
t.l.Debug().Msgf("starting ticker %s", t.tickerId) t.l.Debug().Msgf("starting ticker %s", t.tickerId)
// register the ticker // register the ticker
@@ -124,14 +126,16 @@ func (t *TickerImpl) Start(ctx context.Context) error {
}) })
if err != nil { if err != nil {
return err cancel()
return nil, err
} }
// subscribe to a task queue with the dispatcher id // subscribe to a task queue with the dispatcher id
taskChan, err := t.tq.Subscribe(ctx, taskqueue.QueueTypeFromTickerID(ticker.ID)) cleanupQueue, taskChan, err := t.tq.Subscribe(taskqueue.QueueTypeFromTickerID(ticker.ID))
if err != nil { if err != nil {
return err cancel()
return nil, err
} }
_, err = t.s.NewJob( _, err = t.s.NewJob(
@@ -142,7 +146,8 @@ func (t *TickerImpl) Start(ctx context.Context) error {
) )
if err != nil { if err != nil {
return fmt.Errorf("could not schedule step run requeue: %w", err) cancel()
return nil, fmt.Errorf("could not schedule step run requeue: %w", err)
} }
_, err = t.s.NewJob( _, err = t.s.NewJob(
@@ -153,7 +158,8 @@ func (t *TickerImpl) Start(ctx context.Context) error {
) )
if err != nil { if err != nil {
return fmt.Errorf("could not schedule get group key run requeue: %w", err) cancel()
return nil, fmt.Errorf("could not schedule get group key run requeue: %w", err)
} }
_, err = t.s.NewJob( _, err = t.s.NewJob(
@@ -165,35 +171,13 @@ func (t *TickerImpl) Start(ctx context.Context) error {
t.s.Start() t.s.Start()
for { wg := sync.WaitGroup{}
select {
case <-ctx.Done():
t.l.Debug().Msg("removing ticker")
// delete the ticker go func() {
err = t.repo.Ticker().Delete(t.tickerId) for task := range taskChan {
wg.Add(1)
if err != nil {
t.l.Err(err).Msg("could not delete ticker")
return err
}
// add the task after the ticker is deleted
err := t.tq.AddTask(
ctx,
taskqueue.JOB_PROCESSING_QUEUE,
tickerRemoved(t.tickerId),
)
if err != nil {
t.l.Err(err).Msg("could not add ticker removed task")
return err
}
// return err
return nil
case task := <-taskChan:
go func(task *taskqueue.Task) { go func(task *taskqueue.Task) {
defer wg.Done()
err = t.handleTask(ctx, task) err = t.handleTask(ctx, task)
if err != nil { if err != nil {
@@ -201,7 +185,47 @@ func (t *TickerImpl) Start(ctx context.Context) error {
} }
}(task) }(task)
} }
}()
cleanup := func() error {
t.l.Debug().Msg("removing ticker")
cancel()
if err := cleanupQueue(); err != nil {
return fmt.Errorf("could not cleanup queue: %w", err)
}
wg.Wait()
// delete the ticker
err = t.repo.Ticker().Delete(t.tickerId)
if err != nil {
t.l.Err(err).Msg("could not delete ticker")
return err
}
// add the task after the ticker is deleted
err = t.tq.AddTask(
ctx,
taskqueue.JOB_PROCESSING_QUEUE,
tickerRemoved(t.tickerId),
)
if err != nil {
t.l.Err(err).Msg("could not add ticker removed task")
return err
}
if err := t.s.Shutdown(); err != nil {
return fmt.Errorf("could not shutdown scheduler: %w", err)
}
return nil
} }
return cleanup, nil
} }
func (t *TickerImpl) handleTask(ctx context.Context, task *taskqueue.Task) error { func (t *TickerImpl) handleTask(ctx context.Context, task *taskqueue.Task) error {

View File

@@ -6,6 +6,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
"sync"
"time" "time"
lru "github.com/hashicorp/golang-lru/v2" lru "github.com/hashicorp/golang-lru/v2"
@@ -70,7 +71,9 @@ func WithURL(url string) TaskQueueImplOpt {
} }
// New creates a new TaskQueueImpl. // New creates a new TaskQueueImpl.
func New(ctx context.Context, fs ...TaskQueueImplOpt) *TaskQueueImpl { func New(fs ...TaskQueueImplOpt) (func() error, *TaskQueueImpl) {
ctx, cancel := context.WithCancel(context.Background())
opts := defaultTaskQueueImplOpts() opts := defaultTaskQueueImplOpts()
for _, f := range fs { for _, f := range fs {
@@ -99,30 +102,40 @@ func New(ctx context.Context, fs ...TaskQueueImplOpt) *TaskQueueImpl {
sub := <-<-sessions sub := <-<-sessions
if _, err := t.initQueue(sub, taskqueue.EVENT_PROCESSING_QUEUE); err != nil { if _, err := t.initQueue(sub, taskqueue.EVENT_PROCESSING_QUEUE); err != nil {
t.l.Debug().Msgf("error initializing queue: %v", err) t.l.Debug().Msgf("error initializing queue: %v", err)
return nil cancel()
return nil, nil
} }
if _, err := t.initQueue(sub, taskqueue.JOB_PROCESSING_QUEUE); err != nil { if _, err := t.initQueue(sub, taskqueue.JOB_PROCESSING_QUEUE); err != nil {
t.l.Debug().Msgf("error initializing queue: %v", err) t.l.Debug().Msgf("error initializing queue: %v", err)
return nil cancel()
return nil, nil
} }
if _, err := t.initQueue(sub, taskqueue.WORKFLOW_PROCESSING_QUEUE); err != nil { if _, err := t.initQueue(sub, taskqueue.WORKFLOW_PROCESSING_QUEUE); err != nil {
t.l.Debug().Msgf("error initializing queue: %v", err) t.l.Debug().Msgf("error initializing queue: %v", err)
return nil cancel()
return nil, nil
} }
if _, err := t.initQueue(sub, taskqueue.SCHEDULING_QUEUE); err != nil { if _, err := t.initQueue(sub, taskqueue.SCHEDULING_QUEUE); err != nil {
t.l.Debug().Msgf("error initializing queue: %v", err) t.l.Debug().Msgf("error initializing queue: %v", err)
return nil cancel()
return nil, nil
} }
// create publisher go func // create publisher go func
go func() { cleanup1 := t.startPublishing()
t.publish()
}()
return t cleanup := func() error {
cancel()
if err := cleanup1(); err != nil {
return fmt.Errorf("error cleaning up rabbitmq publisher: %w", err)
}
return nil
}
return cleanup, t
} }
// AddTask adds a task to the queue. // AddTask adds a task to the queue.
@@ -136,12 +149,12 @@ func (t *TaskQueueImpl) AddTask(ctx context.Context, q taskqueue.Queue, task *ta
} }
// Subscribe subscribes to the task queue. // Subscribe subscribes to the task queue.
func (t *TaskQueueImpl) Subscribe(ctx context.Context, q taskqueue.Queue) (<-chan *taskqueue.Task, error) { func (t *TaskQueueImpl) Subscribe(q taskqueue.Queue) (func() error, <-chan *taskqueue.Task, error) {
t.l.Debug().Msgf("subscribing to queue: %s", q.Name()) t.l.Debug().Msgf("subscribing to queue: %s", q.Name())
tasks := make(chan *taskqueue.Task) tasks := make(chan *taskqueue.Task)
go t.subscribe(ctx, t.identity, q, t.sessions, t.tasks, tasks) cleanup := t.subscribe(t.identity, q, t.sessions, t.tasks, tasks)
return tasks, nil return cleanup, tasks, nil
} }
func (t *TaskQueueImpl) RegisterTenant(ctx context.Context, tenantId string) error { func (t *TaskQueueImpl) RegisterTenant(ctx context.Context, tenantId string) error {
@@ -204,107 +217,148 @@ func (t *TaskQueueImpl) initQueue(sub session, q taskqueue.Queue) (string, error
return name, nil return name, nil
} }
func (t *TaskQueueImpl) publish() { func (t *TaskQueueImpl) startPublishing() func() error {
for session := range t.sessions { ctx, cancel := context.WithCancel(t.ctx)
pub := <-session
for task := range t.tasks { cleanup := func() error {
go func(task *taskWithQueue) { cancel()
body, err := json.Marshal(task) return nil
}
if err != nil { go func() {
t.l.Error().Msgf("error marshaling task queue: %v", err) for session := range t.sessions {
pub := <-session
for {
select {
case <-ctx.Done():
return return
} case task := <-t.tasks:
go func(task *taskWithQueue) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) body, err := json.Marshal(task)
defer cancel()
t.l.Debug().Msgf("publishing task %s to queue %s", task.ID, task.q.Name())
err = pub.PublishWithContext(ctx, "", task.q.Name(), false, false, amqp.Publishing{
Body: body,
})
// TODO: retry failed delivery on the next session
if err != nil {
t.l.Error().Msgf("error publishing task: %v", err)
return
}
// if this is a tenant task, publish to the tenant exchange
if task.TenantID() != "" {
// determine if the tenant exchange exists
if _, ok := t.tenantIdCache.Get(task.TenantID()); !ok {
// register the tenant exchange
err = t.RegisterTenant(ctx, task.TenantID())
if err != nil { if err != nil {
t.l.Error().Msgf("error registering tenant exchange: %v", err) t.l.Error().Msgf("error marshaling task queue: %v", err)
return return
} }
}
t.l.Debug().Msgf("publishing tenant task %s to exchange %s", task.ID, task.TenantID()) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err = pub.PublishWithContext(ctx, task.TenantID(), "", false, false, amqp.Publishing{ t.l.Debug().Msgf("publishing task %s to queue %s", task.ID, task.q.Name())
Body: body,
})
if err != nil { err = pub.PublishWithContext(ctx, "", task.q.Name(), false, false, amqp.Publishing{
t.l.Error().Msgf("error publishing tenant task: %v", err) Body: body,
return })
}
// TODO: retry failed delivery on the next session
if err != nil {
t.l.Error().Msgf("error publishing task: %v", err)
return
}
// if this is a tenant task, publish to the tenant exchange
if task.TenantID() != "" {
// determine if the tenant exchange exists
if _, ok := t.tenantIdCache.Get(task.TenantID()); !ok {
// register the tenant exchange
err = t.RegisterTenant(ctx, task.TenantID())
if err != nil {
t.l.Error().Msgf("error registering tenant exchange: %v", err)
return
}
}
t.l.Debug().Msgf("publishing tenant task %s to exchange %s", task.ID, task.TenantID())
err = pub.PublishWithContext(ctx, task.TenantID(), "", false, false, amqp.Publishing{
Body: body,
})
if err != nil {
t.l.Error().Msgf("error publishing tenant task: %v", err)
return
}
}
t.l.Debug().Msgf("published task %s to queue %s", task.ID, task.q.Name())
}(task)
} }
}
t.l.Debug().Msgf("published task %s to queue %s", task.ID, task.q.Name())
}(task)
} }
} }()
return cleanup
} }
func (t *TaskQueueImpl) subscribe(ctx context.Context, subId string, q taskqueue.Queue, sessions chan chan session, messages chan *taskWithQueue, tasks chan<- *taskqueue.Task) { func (t *TaskQueueImpl) subscribe(subId string, q taskqueue.Queue, sessions chan chan session, messages chan *taskWithQueue, tasks chan<- *taskqueue.Task) func() error {
ctx, cancel := context.WithCancel(context.Background())
sessionCount := 0 sessionCount := 0
for session := range sessions { wg := sync.WaitGroup{}
sessionCount++
sub := <-session
// we initialize the queue here because exclusive queues are bound to the session/connection. however, it's not clear go func() {
// if the exclusive queue will be available to the next session. for session := range sessions {
queueName, err := t.initQueue(sub, q) sessionCount++
sub := <-session
if err != nil { // we initialize the queue here because exclusive queues are bound to the session/connection. however, it's not clear
return // if the exclusive queue will be available to the next session.
} queueName, err := t.initQueue(sub, q)
deliveries, err := sub.Consume(queueName, subId, false, q.Exclusive(), false, false, nil) if err != nil {
return
}
if err != nil { deliveries, err := sub.Consume(queueName, subId, false, q.Exclusive(), false, false, nil)
t.l.Error().Msgf("cannot consume from: %s, %v", queueName, err)
return
}
for msg := range deliveries { if err != nil {
go func(msg amqp.Delivery) { t.l.Error().Msgf("cannot consume from: %s, %v", queueName, err)
task := &taskWithQueue{} return
}
if err := json.Unmarshal(msg.Body, task); err != nil { for {
t.l.Error().Msgf("error unmarshaling message: %v", err) select {
case msg := <-deliveries:
wg.Add(1)
go func(msg amqp.Delivery) {
defer wg.Done()
task := &taskWithQueue{}
if err := json.Unmarshal(msg.Body, task); err != nil {
t.l.Error().Msgf("error unmarshaling message: %v", err)
return
}
t.l.Debug().Msgf("(session: %d) got task: %v", sessionCount, task.ID)
tasks <- task.Task
if err := sub.Ack(msg.DeliveryTag, false); err != nil {
t.l.Error().Msgf("error acknowledging message: %v", err)
return
}
}(msg)
case <-ctx.Done():
return return
} }
}
t.l.Debug().Msgf("(session: %d) got task: %v", sessionCount, task.ID)
tasks <- task.Task
if err := sub.Ack(msg.DeliveryTag, false); err != nil {
t.l.Error().Msgf("error acknowledging message: %v", err)
return
}
}(msg)
} }
}()
cleanup := func() error {
cancel()
t.l.Debug().Msgf("shutting down subscriber: %s", subId)
wg.Wait()
close(tasks)
t.l.Debug().Msgf("successfully shut down subscriber: %s", subId)
return nil
} }
return cleanup
} }
// redial continually connects to the URL, exiting the program when no longer possible // redial continually connects to the URL, exiting the program when no longer possible

View File

@@ -21,9 +21,10 @@ func TestTaskQueueIntegration(t *testing.T) {
url := "amqp://user:password@localhost:5672/" url := "amqp://user:password@localhost:5672/"
// Initialize the task queue implementation // Initialize the task queue implementation
tq := rabbitmq.New(ctx, cleanup, tq := rabbitmq.New(
rabbitmq.WithURL(url), rabbitmq.WithURL(url),
) )
defer cleanup()
require.NotNil(t, tq, "task queue implementation should not be nil") require.NotNil(t, tq, "task queue implementation should not be nil")
@@ -41,7 +42,7 @@ func TestTaskQueueIntegration(t *testing.T) {
assert.NoError(t, err, "adding task to static queue should not error") assert.NoError(t, err, "adding task to static queue should not error")
// Test subscription to the static queue // Test subscription to the static queue
taskChan, err := tq.Subscribe(ctx, staticQueue) cleanupQueue, taskChan, err := tq.Subscribe(staticQueue)
require.NoError(t, err, "subscribing to static queue should not error") require.NoError(t, err, "subscribing to static queue should not error")
select { select {
@@ -64,7 +65,7 @@ func TestTaskQueueIntegration(t *testing.T) {
} }
// Test subscription to the tenant-specific queue // Test subscription to the tenant-specific queue
tenantTaskChan, err := tq.Subscribe(ctx, tenantQueue) cleanupTenantQueue, tenantTaskChan, err := tq.Subscribe(tenantQueue)
require.NoError(t, err, "subscribing to tenant-specific queue should not error") require.NoError(t, err, "subscribing to tenant-specific queue should not error")
// send task to tenant-specific queue after 1 second to give time for subscriber // send task to tenant-specific queue after 1 second to give time for subscriber
@@ -77,9 +78,14 @@ func TestTaskQueueIntegration(t *testing.T) {
select { select {
case receivedTask := <-tenantTaskChan: case receivedTask := <-tenantTaskChan:
assert.Equal(t, task.ID, receivedTask.ID, "received tenant task ID should match sent task ID") assert.Equal(t, task.ID, receivedTask.ID, "received tenant task ID should match sent task ID")
break
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for task from tenant-specific queue") t.Fatal("timed out waiting for task from tenant-specific queue")
break }
if err := cleanupQueue(); err != nil {
t.Fatalf("error cleaning up queue: %v", err)
}
if err := cleanupTenantQueue(); err != nil {
t.Fatalf("error cleaning up queue: %v", err)
} }
} }

View File

@@ -139,7 +139,7 @@ type TaskQueue interface {
AddTask(ctx context.Context, queue Queue, task *Task) error AddTask(ctx context.Context, queue Queue, task *Task) error
// Subscribe subscribes to the task queue. // Subscribe subscribes to the task queue.
Subscribe(ctx context.Context, queueType Queue) (<-chan *Task, error) Subscribe(queueType Queue) (func() error, <-chan *Task, error)
// RegisterTenant registers a new pub/sub mechanism for a tenant. This should be called when a // RegisterTenant registers a new pub/sub mechanism for a tenant. This should be called when a
// new tenant is created. If this is not called, implementors should ensure that there's a check // new tenant is created. If this is not called, implementors should ensure that there's a check

View File

@@ -0,0 +1,45 @@
package testutils
import (
"context"
"log"
"os"
"path"
"path/filepath"
"runtime"
"testing"
"github.com/hatchet-dev/hatchet/cmd/hatchet-engine/engine"
"github.com/hatchet-dev/hatchet/internal/config/loader"
)
func SetupEngine(ctx context.Context, t *testing.T) {
t.Helper()
_, b, _, _ := runtime.Caller(0)
testPath := filepath.Dir(b)
dir := path.Join(testPath, "../..")
log.Printf("dir: %s", dir)
_ = os.Setenv("DATABASE_URL", "postgresql://hatchet:hatchet@127.0.0.1:5431/hatchet")
_ = os.Setenv("SERVER_TLS_CERT_FILE", path.Join(dir, "hack/dev/certs/cluster.pem"))
_ = os.Setenv("SERVER_TLS_KEY_FILE", path.Join(dir, "hack/dev/certs/cluster.key"))
_ = os.Setenv("SERVER_TLS_ROOT_CA_FILE", path.Join(dir, "hack/dev/certs/ca.cert"))
_ = os.Setenv("SERVER_PORT", "8080")
_ = os.Setenv("SERVER_URL", "https://app.dev.hatchet-tools.com")
_ = os.Setenv("SERVER_AUTH_COOKIE_SECRETS", "something something")
_ = os.Setenv("SERVER_AUTH_COOKIE_DOMAIN", "app.dev.hatchet-tools.com")
_ = os.Setenv("SERVER_AUTH_COOKIE_INSECURE", "false")
_ = os.Setenv("SERVER_AUTH_SET_EMAIL_VERIFIED", "true")
_ = os.Setenv("SERVER_LOGGER_LEVEL", "debug")
_ = os.Setenv("SERVER_LOGGER_FORMAT", "console")
_ = os.Setenv("DATABASE_LOGGER_LEVEL", "debug")
_ = os.Setenv("DATABASE_LOGGER_FORMAT", "console")
cf := loader.NewConfigLoader(path.Join(dir, "./generated/"))
if err := engine.Run(ctx, cf); err != nil {
t.Fatalf("engine failure: %s", err.Error())
}
}

View File

@@ -24,7 +24,7 @@ model User {
// the user's oauth providers // the user's oauth providers
oauthProviders UserOAuth[] oauthProviders UserOAuth[]
// The hashed user's password. This is placed in a separate table so that it isn't returned by default. // The hashed user's password. This is placed in a separate table so that it isn't returned by default.
password UserPassword? password UserPassword?
// the user's name // the user's name
@@ -196,7 +196,7 @@ model APIToken {
tenantId String? @db.Uuid tenantId String? @db.Uuid
} }
// Event represents an event in the database. // Event represents an event in the database.
model Event { model Event {
// base fields // base fields
id String @id @unique @default(uuid()) @db.Uuid id String @id @unique @default(uuid()) @db.Uuid
@@ -319,7 +319,7 @@ model WorkflowVersion {
// concurrency limits for the workflow // concurrency limits for the workflow
concurrency WorkflowConcurrency? concurrency WorkflowConcurrency?
// the declared jobs // the declared jobs
jobs Job[] jobs Job[]
// all runs for the workflow // all runs for the workflow
@@ -819,7 +819,7 @@ model StepRun {
// the run output // the run output
output Json? output Json?
// inputSchema is a JSON object which declares a JSON schema for the input data // inputSchema is a JSON object which declares a JSON schema for the input data
inputSchema Json? inputSchema Json?
// when the step should be requeued // when the step should be requeued