feat(engine): standalone tests and engine teardown (#172)
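
The change in one sentence: engine services used to block inside Start(ctx context.Context) error; they now return a cleanup function from Start() (func() error, error), and the new engine.Run collects those cleanups into a Teardown slice that it executes in order once the run context is cancelled. A minimal, self-contained Go sketch of that convention (illustrative names only, not the Hatchet API):

	package main

	import "fmt"

	type teardown struct {
		name string
		fn   func() error
	}

	// startService stands in for ticker.New(...).Start(), events.New(...).Start(), etc.:
	// it brings the service up and hands back a function that shuts it down.
	func startService(name string) (func() error, error) {
		fmt.Printf("started %s\n", name)
		cleanup := func() error {
			fmt.Printf("stopped %s\n", name)
			return nil
		}
		return cleanup, nil
	}

	func main() {
		var teardowns []teardown

		for _, name := range []string{"ticker", "events controller", "grpc server"} {
			cleanup, err := startService(name)
			if err != nil {
				panic(err)
			}
			teardowns = append(teardowns, teardown{name: name, fn: cleanup})
		}

		// on interrupt, the engine walks the slice and tears each service down in order
		for i, t := range teardowns {
			if err := t.fn(); err != nil {
				panic(fmt.Errorf("could not teardown %s (%d/%d): %w", t.name, i+1, len(teardowns), err))
			}
		}
	}
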
.github/workflows/test.yml (vendored, 68 lines changed)
@@ -182,7 +182,73 @@ jobs:
        run: |
          export HATCHET_CLIENT_TOKEN="$(go run ./cmd/hatchet-admin token create --config ./generated/ --tenant-id 707d0855-80ab-4e1f-a156-f1c4546cbf52)"

          go test -tags e2e ./... -race -p 1 -v -failfast
          go test -tags e2e ./... -p 1 -v -failfast

      - name: Teardown
        run: docker compose down

  load:
    runs-on: ubuntu-latest
    timeout-minutes: 30
    env:
      DATABASE_URL: postgresql://hatchet:hatchet@127.0.0.1:5431/hatchet

    steps:
      - uses: actions/checkout@v4

      - name: Install Task
        uses: arduino/setup-task@v1

      - name: Install Protoc
        uses: arduino/setup-protoc@v2
        with:
          version: "25.1"

      - name: Setup Go
        uses: actions/setup-go@v5
        with:
          go-version: "1.21"

      - name: Compose
        run: docker compose up -d

      - name: Go deps
        run: go mod download

      - name: Generate
        run: |
          go run github.com/steebchen/prisma-client-go db push
          task generate-certs
          task generate-local-encryption-keys

      - name: Prepare
        run: |
          cat > .env <<EOF
          DATABASE_URL='postgresql://hatchet:hatchet@127.0.0.1:5431/hatchet'
          SERVER_TLS_CERT_FILE=./hack/dev/certs/cluster.pem
          SERVER_TLS_KEY_FILE=./hack/dev/certs/cluster.key
          SERVER_TLS_ROOT_CA_FILE=./hack/dev/certs/ca.cert
          SERVER_PORT=8080
          SERVER_URL=https://app.dev.hatchet-tools.com
          SERVER_AUTH_COOKIE_SECRETS="something something"
          SERVER_AUTH_COOKIE_DOMAIN=app.dev.hatchet-tools.com
          SERVER_AUTH_COOKIE_INSECURE=false
          SERVER_AUTH_SET_EMAIL_VERIFIED=true
          EOF

      - name: Setup
        run: |
          set -a
          . .env
          set +a

          go run ./cmd/hatchet-admin quickstart --generated-config-dir ./generated/

      - name: Test
        run: |
          export HATCHET_CLIENT_TOKEN="$(go run ./cmd/hatchet-admin token create --config ./generated/ --tenant-id 707d0855-80ab-4e1f-a156-f1c4546cbf52)"

          go test -tags load ./... -p 1 -v -failfast

      - name: Teardown
        run: docker compose down

.gitignore (vendored, 2 lines changed)
@@ -89,3 +89,5 @@ postgres-data
rabbitmq.conf

*encryption-keys
generated/
certs/

@@ -93,6 +93,7 @@ tasks:
      - go mod download
      - cd frontend/app/ && pnpm install
      - cd frontend/docs/ && pnpm install
      - cd typescript-sdk/ && pnpm install
  generate-api:
    cmds:
      - task: generate-api-server

@@ -59,7 +59,12 @@ func runCreateAPIToken() error {
    // read in the local config
    configLoader := loader.NewConfigLoader(configDirectory)

    serverConf, err := configLoader.LoadServerConfig()
    cleanup, serverConf, err := configLoader.LoadServerConfig()
    defer func() {
        if err := cleanup(); err != nil {
            panic(fmt.Errorf("could not cleanup server config: %v", err))
        }
    }()

    if err != nil {
        return err

@@ -59,16 +59,23 @@ func main() {
}

func startServerOrDie(cf *loader.ConfigLoader, interruptCh <-chan interface{}) {
    ctx, cancel := cmdutils.InterruptContextFromChan(interruptCh)
    defer cancel()

    // init the repository
    sc, err := cf.LoadServerConfig()
    cleanup, sc, err := cf.LoadServerConfig()
    defer func() {
        if err := cleanup(); err != nil {
            panic(fmt.Errorf("could not cleanup server config: %v", err))
        }
    }()

    if err != nil {
        panic(err)
    }

    errCh := make(chan error)
    ctx, cancel := cmdutils.InterruptContextFromChan(interruptCh)
    defer cancel()

    wg := sync.WaitGroup{}

    if sc.InternalClient != nil {

cmd/hatchet-engine/engine/run.go (new file, 259 lines)
@@ -0,0 +1,259 @@
package engine

import (
    "context"
    "fmt"

    "github.com/hatchet-dev/hatchet/internal/config/loader"
    "github.com/hatchet-dev/hatchet/internal/services/admin"
    "github.com/hatchet-dev/hatchet/internal/services/controllers/events"
    "github.com/hatchet-dev/hatchet/internal/services/controllers/jobs"
    "github.com/hatchet-dev/hatchet/internal/services/controllers/workflows"
    "github.com/hatchet-dev/hatchet/internal/services/dispatcher"
    "github.com/hatchet-dev/hatchet/internal/services/grpc"
    "github.com/hatchet-dev/hatchet/internal/services/heartbeat"
    "github.com/hatchet-dev/hatchet/internal/services/ingestor"
    "github.com/hatchet-dev/hatchet/internal/services/ticker"
    "github.com/hatchet-dev/hatchet/internal/telemetry"
)

type Teardown struct {
    name string
    fn   func() error
}

func Run(ctx context.Context, cf *loader.ConfigLoader) error {
    serverCleanup, sc, err := cf.LoadServerConfig()
    if err != nil {
        return fmt.Errorf("could not load server config: %w", err)
    }
    var l = sc.Logger

    shutdown, err := telemetry.InitTracer(&telemetry.TracerOpts{
        ServiceName:  sc.OpenTelemetry.ServiceName,
        CollectorURL: sc.OpenTelemetry.CollectorURL,
    })
    if err != nil {
        return fmt.Errorf("could not initialize tracer: %w", err)
    }

    var teardown []Teardown

    if sc.HasService("ticker") {
        t, err := ticker.New(
            ticker.WithTaskQueue(sc.TaskQueue),
            ticker.WithRepository(sc.Repository),
            ticker.WithLogger(sc.Logger),
        )

        if err != nil {
            return fmt.Errorf("could not create ticker: %w", err)
        }

        cleanup, err := t.Start()
        if err != nil {
            return fmt.Errorf("could not start ticker: %w", err)
        }
        teardown = append(teardown, Teardown{
            name: "ticker",
            fn:   cleanup,
        })
    }

    if sc.HasService("eventscontroller") {
        ec, err := events.New(
            events.WithTaskQueue(sc.TaskQueue),
            events.WithRepository(sc.Repository),
            events.WithLogger(sc.Logger),
        )
        if err != nil {
            return fmt.Errorf("could not create events controller: %w", err)
        }

        cleanup, err := ec.Start()
        if err != nil {
            return fmt.Errorf("could not start events controller: %w", err)
        }
        teardown = append(teardown, Teardown{
            name: "events controller",
            fn:   cleanup,
        })
    }

    if sc.HasService("jobscontroller") {
        jc, err := jobs.New(
            jobs.WithTaskQueue(sc.TaskQueue),
            jobs.WithRepository(sc.Repository),
            jobs.WithLogger(sc.Logger),
        )

        if err != nil {
            return fmt.Errorf("could not create jobs controller: %w", err)
        }

        cleanup, err := jc.Start()
        if err != nil {
            return fmt.Errorf("could not start jobs controller: %w", err)
        }
        teardown = append(teardown, Teardown{
            name: "jobs controller",
            fn:   cleanup,
        })
    }

    if sc.HasService("workflowscontroller") {
        wc, err := workflows.New(
            workflows.WithTaskQueue(sc.TaskQueue),
            workflows.WithRepository(sc.Repository),
            workflows.WithLogger(sc.Logger),
        )
        if err != nil {
            return fmt.Errorf("could not create workflows controller: %w", err)
        }

        cleanup, err := wc.Start()
        if err != nil {
            return fmt.Errorf("could not start workflows controller: %w", err)
        }
        teardown = append(teardown, Teardown{
            name: "workflows controller",
            fn:   cleanup,
        })
    }

    if sc.HasService("heartbeater") {
        h, err := heartbeat.New(
            heartbeat.WithTaskQueue(sc.TaskQueue),
            heartbeat.WithRepository(sc.Repository),
            heartbeat.WithLogger(sc.Logger),
        )

        if err != nil {
            return fmt.Errorf("could not create heartbeater: %w", err)
        }

        cleanup, err := h.Start()
        if err != nil {
            return fmt.Errorf("could not start heartbeater: %w", err)
        }
        teardown = append(teardown, Teardown{
            name: "heartbeater",
            fn:   cleanup,
        })
    }

    if sc.HasService("grpc") {
        // create the dispatcher
        d, err := dispatcher.New(
            dispatcher.WithTaskQueue(sc.TaskQueue),
            dispatcher.WithRepository(sc.Repository),
            dispatcher.WithLogger(sc.Logger),
        )
        if err != nil {
            return fmt.Errorf("could not create dispatcher: %w", err)
        }

        dispatcherCleanup, err := d.Start()
        if err != nil {
            return fmt.Errorf("could not start dispatcher: %w", err)
        }

        teardown = append(teardown, Teardown{
            name: "grpc dispatcher",
            fn:   dispatcherCleanup,
        })

        // create the event ingestor
        ei, err := ingestor.NewIngestor(
            ingestor.WithEventRepository(
                sc.Repository.Event(),
            ),
            ingestor.WithTaskQueue(sc.TaskQueue),
        )
        if err != nil {
            return fmt.Errorf("could not create ingestor: %w", err)
        }

        adminSvc, err := admin.NewAdminService(
            admin.WithRepository(sc.Repository),
            admin.WithTaskQueue(sc.TaskQueue),
        )
        if err != nil {
            return fmt.Errorf("could not create admin service: %w", err)
        }

        grpcOpts := []grpc.ServerOpt{
            grpc.WithConfig(sc),
            grpc.WithIngestor(ei),
            grpc.WithDispatcher(d),
            grpc.WithAdmin(adminSvc),
            grpc.WithLogger(sc.Logger),
            grpc.WithTLSConfig(sc.TLSConfig),
            grpc.WithPort(sc.Runtime.GRPCPort),
            grpc.WithBindAddress(sc.Runtime.GRPCBindAddress),
        }

        if sc.Runtime.GRPCInsecure {
            grpcOpts = append(grpcOpts, grpc.WithInsecure())
        }

        // create the grpc server
        s, err := grpc.NewServer(
            grpcOpts...,
        )
        if err != nil {
            return fmt.Errorf("could not create grpc server: %w", err)
        }

        grpcServerCleanup, err := s.Start()
        if err != nil {
            return fmt.Errorf("could not start grpc server: %w", err)
        }

        teardown = append(teardown, Teardown{
            name: "grpc server",
            fn:   grpcServerCleanup,
        })
    }

    teardown = append(teardown, Teardown{
        name: "telemetry",
        fn: func() error {
            return shutdown(ctx)
        },
    })
    teardown = append(teardown, Teardown{
        name: "server",
        fn: func() error {
            return serverCleanup()
        },
    })
    teardown = append(teardown, Teardown{
        name: "database",
        fn: func() error {
            return sc.Disconnect()
        },
    })

    l.Debug().Msgf("engine has started")

    <-ctx.Done()

    l.Debug().Msgf("interrupt received, shutting down")

    l.Debug().Msgf("waiting for all other services to gracefully exit...")
    for i, t := range teardown {
        l.Debug().Msgf("shutting down %s (%d/%d)", t.name, i+1, len(teardown))
        err := t.fn()

        if err != nil {
            return fmt.Errorf("could not teardown %s: %w", t.name, err)
        }
        l.Debug().Msgf("successfully shutdown %s (%d/%d)", t.name, i+1, len(teardown))
    }
    l.Debug().Msgf("all services have successfully gracefully exited")

    l.Debug().Msgf("successfully shutdown")

    return nil
}

@@ -2,22 +2,13 @@ package main

import (
    "fmt"
    "log"
    "os"
    "sync"

    "github.com/spf13/cobra"

    "github.com/hatchet-dev/hatchet/cmd/hatchet-engine/engine"
    "github.com/hatchet-dev/hatchet/internal/config/loader"
    "github.com/hatchet-dev/hatchet/internal/services/admin"
    "github.com/hatchet-dev/hatchet/internal/services/controllers/events"
    "github.com/hatchet-dev/hatchet/internal/services/controllers/jobs"
    "github.com/hatchet-dev/hatchet/internal/services/controllers/workflows"
    "github.com/hatchet-dev/hatchet/internal/services/dispatcher"
    "github.com/hatchet-dev/hatchet/internal/services/grpc"
    "github.com/hatchet-dev/hatchet/internal/services/heartbeat"
    "github.com/hatchet-dev/hatchet/internal/services/ingestor"
    "github.com/hatchet-dev/hatchet/internal/services/ticker"
    "github.com/hatchet-dev/hatchet/internal/telemetry"
    "github.com/hatchet-dev/hatchet/pkg/cmdutils"
)

@@ -35,9 +26,13 @@ var rootCmd = &cobra.Command{
        }

        cf := loader.NewConfigLoader(configDirectory)
        interruptChan := cmdutils.InterruptChan()
        context, cancel := cmdutils.NewInterruptContext()
        defer cancel()

        startEngineOrDie(cf, interruptChan)
        if err := engine.Run(context, cf); err != nil {
            log.Printf("engine failure: %s", err.Error())
            os.Exit(1)
        }
    },
}

@@ -64,240 +59,3 @@ func main() {
        os.Exit(1)
    }
}

func startEngineOrDie(cf *loader.ConfigLoader, interruptCh <-chan interface{}) {
    sc, err := cf.LoadServerConfig()

    if err != nil {
        panic(err)
    }

    errCh := make(chan error)
    ctx, cancel := cmdutils.InterruptContextFromChan(interruptCh)
    wg := sync.WaitGroup{}

    shutdown, err := telemetry.InitTracer(&telemetry.TracerOpts{
        ServiceName:  sc.OpenTelemetry.ServiceName,
        CollectorURL: sc.OpenTelemetry.CollectorURL,
    })

    if err != nil {
        panic(fmt.Sprintf("could not initialize tracer: %s", err))
    }

    defer shutdown(ctx) // nolint: errcheck

    if sc.HasService("grpc") {
        wg.Add(1)

        // create the dispatcher
        d, err := dispatcher.New(
            dispatcher.WithTaskQueue(sc.TaskQueue),
            dispatcher.WithRepository(sc.Repository),
            dispatcher.WithLogger(sc.Logger),
        )

        if err != nil {
            errCh <- err
            return
        }

        go func() {
            defer wg.Done()
            err := d.Start(ctx)

            if err != nil {
                panic(err)
            }
        }()

        // create the event ingestor
        ei, err := ingestor.NewIngestor(
            ingestor.WithEventRepository(
                sc.Repository.Event(),
            ),
            ingestor.WithTaskQueue(sc.TaskQueue),
        )

        if err != nil {
            errCh <- err
            return
        }

        adminSvc, err := admin.NewAdminService(
            admin.WithRepository(sc.Repository),
            admin.WithTaskQueue(sc.TaskQueue),
        )

        if err != nil {
            errCh <- err
            return
        }

        grpcOpts := []grpc.ServerOpt{
            grpc.WithConfig(sc),
            grpc.WithIngestor(ei),
            grpc.WithDispatcher(d),
            grpc.WithAdmin(adminSvc),
            grpc.WithLogger(sc.Logger),
            grpc.WithTLSConfig(sc.TLSConfig),
            grpc.WithPort(sc.Runtime.GRPCPort),
            grpc.WithBindAddress(sc.Runtime.GRPCBindAddress),
        }

        if sc.Runtime.GRPCInsecure {
            grpcOpts = append(grpcOpts, grpc.WithInsecure())
        }

        // create the grpc server
        s, err := grpc.NewServer(
            grpcOpts...,
        )

        if err != nil {
            errCh <- err
            return
        }

        go func() {
            err = s.Start(ctx)

            if err != nil {
                errCh <- err
                return
            }
        }()
    }

    if sc.HasService("eventscontroller") {
        // create separate events controller process
        go func() {
            ec, err := events.New(
                events.WithTaskQueue(sc.TaskQueue),
                events.WithRepository(sc.Repository),
                events.WithLogger(sc.Logger),
            )

            if err != nil {
                errCh <- err
                return
            }

            err = ec.Start(ctx)

            if err != nil {
                errCh <- err
            }
        }()
    }

    if sc.HasService("jobscontroller") {
        // create separate jobs controller process
        go func() {
            jc, err := jobs.New(
                jobs.WithTaskQueue(sc.TaskQueue),
                jobs.WithRepository(sc.Repository),
                jobs.WithLogger(sc.Logger),
            )

            if err != nil {
                errCh <- err
                return
            }

            err = jc.Start(ctx)

            if err != nil {
                errCh <- err
            }
        }()
    }

    if sc.HasService("workflowscontroller") {
        // create separate jobs controller process
        go func() {
            jc, err := workflows.New(
                workflows.WithTaskQueue(sc.TaskQueue),
                workflows.WithRepository(sc.Repository),
                workflows.WithLogger(sc.Logger),
            )

            if err != nil {
                errCh <- err
                return
            }

            err = jc.Start(ctx)

            if err != nil {
                errCh <- err
            }
        }()
    }

    if sc.HasService("ticker") {
        // create a ticker
        go func() {
            t, err := ticker.New(
                ticker.WithTaskQueue(sc.TaskQueue),
                ticker.WithRepository(sc.Repository),
                ticker.WithLogger(sc.Logger),
            )

            if err != nil {
                errCh <- err
                return
            }

            err = t.Start(ctx)

            if err != nil {
                errCh <- err
            }
        }()
    }

    if sc.HasService("heartbeater") {
        go func() {
            h, err := heartbeat.New(
                heartbeat.WithTaskQueue(sc.TaskQueue),
                heartbeat.WithRepository(sc.Repository),
                heartbeat.WithLogger(sc.Logger),
            )

            if err != nil {
                errCh <- err
                return
            }

            err = h.Start(ctx)

            if err != nil {
                errCh <- err
            }
        }()
    }

Loop:
    for {
        select {
        case err := <-errCh:
            fmt.Fprintf(os.Stderr, "%s", err)

            // exit with non-zero exit code
            os.Exit(1) //nolint:gocritic
        case <-interruptCh:
            break Loop
        }
    }

    cancel()

    wg.Wait()

    err = sc.Disconnect()

    if err != nil {
        panic(err)
    }
}

@@ -1,8 +1,11 @@
//go:build e2e
//go:build load

package main

import (
    "context"
    "log"
    "sync"
    "testing"
    "time"

@@ -44,23 +47,48 @@ func TestLoadCLI(t *testing.T) {
            concurrency: 0,
        },
    }}

    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)

    setup := sync.WaitGroup{}

    go func() {
        setup.Add(1)
        log.Printf("setup start")
        testutils.SetupEngine(ctx, t)
        setup.Done()
        log.Printf("setup end")
    }()

    // TODO instead of waiting, figure out when the engine setup is complete
    time.Sleep(10 * time.Second)

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            defer func() {
                time.Sleep(1 * time.Second)

                goleak.VerifyNone(
                    t,
                    goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
                    goleak.IgnoreTopFunction("google.golang.org/grpc/internal/grpcsync.(*CallbackSerializer).run"),
                    goleak.IgnoreTopFunction("internal/poll.runtime_pollWait"),
                    goleak.IgnoreTopFunction("google.golang.org/grpc/internal/transport.(*controlBuffer).get"),
                )
            }()

            if err := do(tt.args.duration, tt.args.eventsPerSecond, tt.args.delay, tt.args.wait, tt.args.concurrency); (err != nil) != tt.wantErr {
                t.Errorf("do() error = %v, wantErr %v", err, tt.wantErr)
            }
        })
    }

    cancel()

    log.Printf("test complete")
    setup.Wait()
    log.Printf("cleanup complete")

    goleak.VerifyNone(
        t,
        // worker
        goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
        goleak.IgnoreTopFunction("google.golang.org/grpc/internal/grpcsync.(*CallbackSerializer).run"),
        goleak.IgnoreTopFunction("internal/poll.runtime_pollWait"),
        goleak.IgnoreTopFunction("google.golang.org/grpc/internal/transport.(*controlBuffer).get"),
        // all engine related packages
        goleak.IgnoreTopFunction("github.com/jackc/pgx/v5/pgxpool.(*Pool).backgroundHealthCheck"),
        goleak.IgnoreTopFunction("github.com/rabbitmq/amqp091-go.(*Connection).heartbeater"),
        goleak.IgnoreTopFunction("github.com/rabbitmq/amqp091-go.(*consumers).buffer"),
        goleak.IgnoreTopFunction("google.golang.org/grpc/internal/transport.(*http2Server).keepalive"),
    )
}

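The TODO in the hunk above (sleeping a fixed 10 seconds until engine setup completes) could be resolved by signalling readiness explicitly. A hypothetical drop-in sketch for that spot, assuming a SetupEngineWithReady helper that closes a channel once the engine accepts work (no such helper exists in this commit); it also moves setup.Add(1) before the goroutine starts, the safe ordering for sync.WaitGroup:

	// ready is closed by the (hypothetical) setup helper once the engine is up
	ready := make(chan struct{})

	setup.Add(1)
	go func() {
		defer setup.Done()
		testutils.SetupEngineWithReady(ctx, t, ready) // hypothetical variant of SetupEngine
	}()

	select {
	case <-ready:
		// engine is up; proceed to run the load tests
	case <-time.After(30 * time.Second):
		t.Fatal("engine did not become ready in time")
	}
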
@@ -14,7 +14,6 @@ import (
)

func TestMiddleware(t *testing.T) {
    t.Skip()
    testutils.Prepare(t)

    ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)

go.mod (2 lines changed)
@@ -13,6 +13,7 @@ require (
    github.com/gorilla/sessions v1.2.2
    github.com/hashicorp/go-multierror v1.1.1
    github.com/hashicorp/golang-lru/v2 v2.0.7
    github.com/jackc/pgx-zerolog v0.0.0-20230315001418-f978528409eb
    github.com/jackc/pgx/v5 v5.5.0
    github.com/joho/godotenv v1.5.1
    github.com/labstack/echo/v4 v4.11.3
@@ -100,7 +101,6 @@ require (
    github.com/hashicorp/errwrap v1.0.0 // indirect
    github.com/hashicorp/hcl v1.0.0 // indirect
    github.com/invopop/jsonschema v0.12.0
    github.com/jackc/pgx-zerolog v0.0.0-20230315001418-f978528409eb
    github.com/magiconair/properties v1.8.7 // indirect
    github.com/mattn/go-colorable v0.1.13 // indirect
    github.com/mattn/go-isatty v0.0.20 // indirect

@@ -5,6 +5,7 @@ package loader
import (
    "context"
    "fmt"
    "log"
    "os"
    "path/filepath"
    "strings"
@@ -17,6 +18,7 @@ import (
    "github.com/hatchet-dev/hatchet/internal/auth/cookie"
    "github.com/hatchet-dev/hatchet/internal/auth/oauth"
    "github.com/hatchet-dev/hatchet/internal/auth/token"
    clientconfig "github.com/hatchet-dev/hatchet/internal/config/client"
    "github.com/hatchet-dev/hatchet/internal/config/database"
    "github.com/hatchet-dev/hatchet/internal/config/loader/loaderutils"
    "github.com/hatchet-dev/hatchet/internal/config/server"
@@ -30,8 +32,6 @@ import (
    "github.com/hatchet-dev/hatchet/internal/taskqueue/rabbitmq"
    "github.com/hatchet-dev/hatchet/internal/validator"
    "github.com/hatchet-dev/hatchet/pkg/client"

    clientconfig "github.com/hatchet-dev/hatchet/internal/config/client"
)

// LoadDatabaseConfigFile loads the database config file via viper
@@ -81,24 +81,24 @@ func (c *ConfigLoader) LoadDatabaseConfig() (res *database.Config, err error) {
}

// LoadServerConfig loads the server configuration
func (c *ConfigLoader) LoadServerConfig() (res *server.ServerConfig, err error) {
func (c *ConfigLoader) LoadServerConfig() (cleanup func() error, res *server.ServerConfig, err error) {
    log.Printf("Loading server config from %s", c.directory)
    sharedFilePath := filepath.Join(c.directory, "server.yaml")
    configFileBytes, err := loaderutils.GetConfigBytes(sharedFilePath)
    log.Printf("Shared file path: %s", sharedFilePath)

    configFileBytes, err := loaderutils.GetConfigBytes(sharedFilePath)
    if err != nil {
        return nil, err
        return nil, nil, err
    }

    dc, err := c.LoadDatabaseConfig()

    if err != nil {
        return nil, err
        return nil, nil, err
    }

    cf, err := LoadServerConfigFile(configFileBytes...)

    if err != nil {
        return nil, err
        return nil, nil, err
    }

    return GetServerConfigFromConfigfile(dc, cf)
@@ -119,16 +119,15 @@ func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile) (res *database.Con

    os.Setenv("DATABASE_URL", databaseUrl)

    client := db.NewClient(
    c := db.NewClient(
        // db.WithDatasourceURL(databaseUrl),
    )

    if err := client.Prisma.Connect(); err != nil {
    if err := c.Prisma.Connect(); err != nil {
        return nil, err
    }

    config, err := pgxpool.ParseConfig(databaseUrl)

    if err != nil {
        return nil, err
    }
@@ -143,25 +142,24 @@ func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile) (res *database.Con
    config.MaxConns = 20

    pool, err := pgxpool.NewWithConfig(context.Background(), config)

    if err != nil {
        return nil, fmt.Errorf("could not connect to database: %w", err)
    }

    return &database.Config{
        Disconnect: client.Prisma.Disconnect,
        Repository: prisma.NewPrismaRepository(client, pool, prisma.WithLogger(&l)),
        Disconnect: c.Prisma.Disconnect,
        Repository: prisma.NewPrismaRepository(c, pool, prisma.WithLogger(&l)),
        Seed:       cf.Seed,
    }, nil
}

func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigFile) (res *server.ServerConfig, err error) {
func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigFile) (cleanup func() error, res *server.ServerConfig, err error) {
    l := logger.NewStdErr(&cf.Logger, "server")

    tls, err := loaderutils.LoadServerTLSConfig(&cf.TLS)

    if err != nil {
        return nil, fmt.Errorf("could not load TLS config: %w", err)
        return nil, nil, fmt.Errorf("could not load TLS config: %w", err)
    }

    ss, err := cookie.NewUserSessionStore(
@@ -173,11 +171,10 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
    )

    if err != nil {
        return nil, fmt.Errorf("could not create session store: %w", err)
        return nil, nil, fmt.Errorf("could not create session store: %w", err)
    }

    tq := rabbitmq.New(
        context.Background(),
    cleanup1, tq := rabbitmq.New(
        rabbitmq.WithURL(cf.TaskQueue.RabbitMQ.URL),
        rabbitmq.WithLogger(&l),
    )
@@ -188,7 +185,7 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
    )

    if err != nil {
        return nil, fmt.Errorf("could not create ingestor: %w", err)
        return nil, nil, fmt.Errorf("could not create ingestor: %w", err)
    }

    auth := server.AuthConfig{
@@ -197,11 +194,11 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF

    if cf.Auth.Google.Enabled {
        if cf.Auth.Google.ClientID == "" {
            return nil, fmt.Errorf("google client id is required")
            return nil, nil, fmt.Errorf("google client id is required")
        }

        if cf.Auth.Google.ClientSecret == "" {
            return nil, fmt.Errorf("google client secret is required")
            return nil, nil, fmt.Errorf("google client secret is required")
        }

        gClient := oauth.NewGoogleClient(&oauth.Config{
@@ -217,7 +214,7 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
    encryptionSvc, err := loadEncryptionSvc(cf)

    if err != nil {
        return nil, fmt.Errorf("could not load encryption service: %w", err)
        return nil, nil, fmt.Errorf("could not load encryption service: %w", err)
    }

    // create a new JWT manager
@@ -229,7 +226,7 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
    })

    if err != nil {
        return nil, fmt.Errorf("could not create JWT manager: %w", err)
        return nil, nil, fmt.Errorf("could not create JWT manager: %w", err)
    }

    vcsProviders := make(map[vcs.VCSRepositoryKind]vcs.VCSProvider)
@@ -252,7 +249,7 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
    )

    if err != nil {
        return nil, err
        return nil, nil, err
    }

    githubProvider := github.NewGithubVCSProvider(githubAppConf, dc.Repository, cf.Runtime.ServerURL, encryptionSvc)
@@ -267,20 +264,20 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
    internalTenant, err := dc.Repository.Tenant().GetTenantBySlug("internal")

    if err != nil {
        return nil, fmt.Errorf("could not get internal tenant: %w", err)
        return nil, nil, fmt.Errorf("could not get internal tenant: %w", err)
    }

    tokenSuffix, err := encryption.GenerateRandomBytes(4)

    if err != nil {
        return nil, fmt.Errorf("could not generate token suffix: %w", err)
        return nil, nil, fmt.Errorf("could not generate token suffix: %w", err)
    }

    // generate a token for the internal client
    token, err := auth.JWTManager.GenerateTenantToken(internalTenant.ID, fmt.Sprintf("internal-%s", tokenSuffix))

    if err != nil {
        return nil, fmt.Errorf("could not generate internal token: %w", err)
        return nil, nil, fmt.Errorf("could not generate internal token: %w", err)
    }

    internalClient, err = client.NewFromConfigFile(
@@ -291,11 +288,19 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
    )

    if err != nil {
        return nil, fmt.Errorf("could not create internal client: %w", err)
        return nil, nil, fmt.Errorf("could not create internal client: %w", err)
    }
    }

    return &server.ServerConfig{
    cleanup = func() error {
        log.Printf("cleaning up server config")
        if err := cleanup1(); err != nil {
            return fmt.Errorf("error cleaning up rabbitmq: %w", err)
        }
        return nil
    }

    return cleanup, &server.ServerConfig{
        Runtime:    cf.Runtime,
        Auth:       auth,
        Encryption: encryptionSvc,

@@ -3,6 +3,7 @@ package events
import (
    "context"
    "fmt"
    "sync"

    "github.com/rs/zerolog"
    "golang.org/x/sync/errgroup"
@@ -94,25 +95,41 @@ func New(fs ...EventsControllerOpt) (*EventsControllerImpl, error) {
    }, nil
}

func (ec *EventsControllerImpl) Start(ctx context.Context) error {
    taskChan, err := ec.tq.Subscribe(ctx, taskqueue.EVENT_PROCESSING_QUEUE)
func (ec *EventsControllerImpl) Start() (func() error, error) {
    ctx, cancel := context.WithCancel(context.Background())

    cleanupQueue, taskChan, err := ec.tq.Subscribe(taskqueue.EVENT_PROCESSING_QUEUE)

    if err != nil {
        return err
        cancel()
        return nil, fmt.Errorf("could not subscribe to event processing queue: %w", err)
    }

    // TODO: close when ctx is done
    for task := range taskChan {
        go func(task *taskqueue.Task) {
            err = ec.handleTask(ctx, task)
    wg := sync.WaitGroup{}

            if err != nil {
                ec.l.Error().Err(err).Msgf("could not handle event task %s", task.ID)
            }
        }(task)
    go func() {
        for task := range taskChan {
            wg.Add(1)
            go func(task *taskqueue.Task) {
                defer wg.Done()
                err = ec.handleTask(ctx, task)

                if err != nil {
                    ec.l.Error().Err(err).Msgf("could not handle event task %s", task.ID)
                }
            }(task)
        }
    }()

    cleanup := func() error {
        cancel()
        if err := cleanupQueue(); err != nil {
            return fmt.Errorf("could not cleanup event processing queue: %w", err)
        }
        return nil
    }

    return nil
    return cleanup, nil
}

func (ec *EventsControllerImpl) handleTask(ctx context.Context, task *taskqueue.Task) error {

@@ -6,6 +6,7 @@ import (
    "errors"
    "fmt"
    "strconv"
    "sync"
    "time"

    "github.com/rs/zerolog"
@@ -102,25 +103,40 @@ func New(fs ...JobsControllerOpt) (*JobsControllerImpl, error) {
    }, nil
}

func (jc *JobsControllerImpl) Start(ctx context.Context) error {
    taskChan, err := jc.tq.Subscribe(ctx, taskqueue.JOB_PROCESSING_QUEUE)
func (jc *JobsControllerImpl) Start() (func() error, error) {
    cleanupQueue, taskChan, err := jc.tq.Subscribe(taskqueue.JOB_PROCESSING_QUEUE)

    if err != nil {
        return err
        return nil, fmt.Errorf("could not subscribe to job processing queue: %w", err)
    }

    // TODO: close when ctx is done
    for task := range taskChan {
        go func(task *taskqueue.Task) {
            err = jc.handleTask(ctx, task)
    wg := sync.WaitGroup{}

            if err != nil {
                jc.l.Error().Err(err).Msg("could not handle job task")
            }
        }(task)
    go func() {
        for task := range taskChan {
            wg.Add(1)
            go func(task *taskqueue.Task) {
                defer wg.Done()
                err = jc.handleTask(context.Background(), task)

                if err != nil {
                    jc.l.Error().Err(err).Msg("could not handle job task")
                }
            }(task)
        }
    }()

    cleanup := func() error {
        if err := cleanupQueue(); err != nil {
            return fmt.Errorf("could not cleanup job processing queue: %w", err)
        }

        wg.Wait()

        return nil
    }

    return nil
    return cleanup, nil
}

func (ec *JobsControllerImpl) handleTask(ctx context.Context, task *taskqueue.Task) error {

@@ -3,6 +3,7 @@ package workflows
import (
    "context"
    "fmt"
    "sync"
    "time"

    "github.com/rs/zerolog"
@@ -95,27 +96,40 @@ func New(fs ...WorkflowsControllerOpt) (*WorkflowsControllerImpl, error) {
    }, nil
}

func (wc *WorkflowsControllerImpl) Start(ctx context.Context) error {
func (wc *WorkflowsControllerImpl) Start() (func() error, error) {
    wc.l.Debug().Msg("starting workflows controller")

    taskChan, err := wc.tq.Subscribe(ctx, taskqueue.WORKFLOW_PROCESSING_QUEUE)
    cleanupQueue, taskChan, err := wc.tq.Subscribe(taskqueue.WORKFLOW_PROCESSING_QUEUE)

    if err != nil {
        return err
        return nil, err
    }

    // TODO: close when ctx is done
    for task := range taskChan {
        go func(task *taskqueue.Task) {
            err = wc.handleTask(ctx, task)
    wg := sync.WaitGroup{}

            if err != nil {
                wc.l.Error().Err(err).Msg("could not handle job task")
            }
        }(task)
    go func() {
        for task := range taskChan {
            wg.Add(1)
            go func(task *taskqueue.Task) {
                defer wg.Done()
                err = wc.handleTask(context.Background(), task)

                if err != nil {
                    wc.l.Error().Err(err).Msg("could not handle job task")
                }
            }(task)
        }
    }()

    cleanup := func() error {
        if err := cleanupQueue(); err != nil {
            return fmt.Errorf("could not cleanup queue: %w", err)
        }
        wg.Wait()
        return nil
    }

    return nil
    return cleanup, nil
}

func (wc *WorkflowsControllerImpl) handleTask(ctx context.Context, task *taskqueue.Task) error {

@@ -22,7 +22,7 @@ import (

type Dispatcher interface {
    contracts.DispatcherServer
    Start(ctx context.Context) error
    Start() (func() error, error)
}

type DispatcherImpl struct {
@@ -122,21 +122,24 @@ func New(fs ...DispatcherOpt) (*DispatcherImpl, error) {
    }, nil
}

func (d *DispatcherImpl) Start(ctx context.Context) error {
func (d *DispatcherImpl) Start() (func() error, error) {
    // register the dispatcher by creating a new dispatcher in the database
    dispatcher, err := d.repo.Dispatcher().CreateNewDispatcher(&repository.CreateDispatcherOpts{
        ID: d.dispatcherId,
    })

    if err != nil {
        return err
        return nil, err
    }

    ctx, cancel := context.WithCancel(context.Background())

    // subscribe to a task queue with the dispatcher id
    taskChan, err := d.tq.Subscribe(ctx, taskqueue.QueueTypeFromDispatcherID(dispatcher.ID))
    cleanupQueue, taskChan, err := d.tq.Subscribe(taskqueue.QueueTypeFromDispatcherID(dispatcher.ID))

    if err != nil {
        return err
        cancel()
        return nil, err
    }

    _, err = d.s.NewJob(
@@ -147,42 +150,64 @@ func (d *DispatcherImpl) Start(ctx context.Context) error {
    )

    if err != nil {
        return fmt.Errorf("could not schedule heartbeat update: %w", err)
        cancel()
        return nil, fmt.Errorf("could not schedule heartbeat update: %w", err)
    }

    d.s.Start()

    for {
        select {
        case <-ctx.Done():
            // drain the existing connections
            d.l.Debug().Msg("draining existing connections")
    go func() {
        for {
            select {
            case <-ctx.Done():
                return
            case task := <-taskChan:
                go func(task *taskqueue.Task) {
                    err = d.handleTask(ctx, task)

            d.workers.Range(func(key, value interface{}) bool {
                w := value.(subscribedWorker)

                w.finished <- true

                return true
            })

            err = d.repo.Dispatcher().Delete(dispatcher.ID)

            if err == nil {
                d.l.Debug().Msgf("deleted dispatcher %s", dispatcher.ID)
                    if err != nil {
                        d.l.Error().Err(err).Msgf("could not handle dispatcher task %s", task.ID)
                    }
                }(task)
            }

            return err
        case task := <-taskChan:
            go func(task *taskqueue.Task) {
                err = d.handleTask(ctx, task)

                if err != nil {
                    d.l.Error().Err(err).Msgf("could not handle dispatcher task %s", task.ID)
                }
            }(task)
        }
    }()

    cleanup := func() error {
        d.l.Debug().Msgf("dispatcher is shutting down...")
        cancel()

        if err := cleanupQueue(); err != nil {
            return fmt.Errorf("could not cleanup queue: %w", err)
        }

        // drain the existing connections
        d.l.Debug().Msg("draining existing connections")

        d.workers.Range(func(key, value interface{}) bool {
            w := value.(subscribedWorker)

            w.finished <- true

            return true
        })

        err = d.repo.Dispatcher().Delete(dispatcher.ID)
        if err != nil {
            return fmt.Errorf("could not delete dispatcher: %w", err)
        }

        d.l.Debug().Msgf("deleted dispatcher %s", dispatcher.ID)

        if err := d.s.Shutdown(); err != nil {
            return fmt.Errorf("could not shutdown scheduler: %w", err)
        }

        d.l.Debug().Msgf("dispatcher has shutdown")
        return nil
    }

    return cleanup, nil
}

func (d *DispatcherImpl) handleTask(ctx context.Context, task *taskqueue.Task) error {
@@ -348,7 +373,7 @@ func (d *DispatcherImpl) runUpdateHeartbeat(ctx context.Context) func() {
        })

        if err != nil {
            d.l.Err(err).Msg("could not update heartbeat")
            d.l.Err(err).Msg("dispatcher: could not update heartbeat")
        }
    }
}

@@ -4,6 +4,7 @@ import (
    "context"
    "fmt"
    "strconv"
    "sync"
    "time"

    "github.com/steebchen/prisma-client-go/runtime/types"
@@ -301,19 +302,27 @@ func (s *DispatcherImpl) SubscribeToWorkflowEvents(request *contracts.SubscribeT
    defer cancel()

    // subscribe to the task queue for the tenant
    taskChan, err := s.tq.Subscribe(ctx, q)
    cleanupQueue, taskChan, err := s.tq.Subscribe(q)

    if err != nil {
        return err
    }

    wg := sync.WaitGroup{}

    for {
        select {
        case <-ctx.Done():
            if err := cleanupQueue(); err != nil {
                return fmt.Errorf("could not cleanup queue: %w", err)
            }
            // drain the existing connections
            wg.Wait()
            return nil
        case task := <-taskChan:
            wg.Add(1)
            go func(task *taskqueue.Task) {
                defer wg.Done()
                e, err := s.tenantTaskToWorkflowEvent(task, tenant.ID, request.WorkflowRunId)

                if err != nil {

@@ -1,7 +1,6 @@
package grpc

import (
    "context"
    "crypto/tls"
    "fmt"
    "net"
@@ -14,6 +13,7 @@ import (
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/credentials"
    "google.golang.org/grpc/credentials/insecure"
    "google.golang.org/grpc/status"

    "github.com/hatchet-dev/hatchet/internal/config/server"
    "github.com/hatchet-dev/hatchet/internal/logger"
@@ -24,8 +24,6 @@ import (
    "github.com/hatchet-dev/hatchet/internal/services/grpc/middleware"
    "github.com/hatchet-dev/hatchet/internal/services/ingestor"
    eventcontracts "github.com/hatchet-dev/hatchet/internal/services/ingestor/contracts"

    "google.golang.org/grpc/status"
)

type Server struct {
@@ -155,17 +153,17 @@ func NewServer(fs ...ServerOpt) (*Server, error) {
    }, nil
}

func (s *Server) Start(ctx context.Context) error {
    return s.startGRPC(ctx)
func (s *Server) Start() (func() error, error) {
    return s.startGRPC()
}

func (s *Server) startGRPC(ctx context.Context) error {
func (s *Server) startGRPC() (func() error, error) {
    s.l.Debug().Msgf("starting grpc server on %s:%d", s.bindAddress, s.port)

    lis, err := net.Listen("tcp", fmt.Sprintf("%s:%d", s.bindAddress, s.port))

    if err != nil {
        return fmt.Errorf("failed to listen: %w", err)
        return nil, fmt.Errorf("failed to listen: %w", err)
    }

    serverOpts := []grpc.ServerOption{}
@@ -221,6 +219,16 @@ func (s *Server) startGRPC(ctx context.Context) error {
        admincontracts.RegisterWorkflowServiceServer(grpcServer, s.admin)
    }

    // Start listening
    return grpcServer.Serve(lis)
    go func() {
        if err := grpcServer.Serve(lis); err != nil {
            panic(fmt.Errorf("failed to serve: %w", err))
        }
    }()

    cleanup := func() error {
        grpcServer.GracefulStop()
        return nil
    }

    return cleanup, nil
}

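The new cleanup returned by startGRPC calls grpcServer.GracefulStop(), which blocks until every in-flight RPC completes. If teardown needs an upper bound, a common pattern (not part of this commit) is to race the graceful drain against a timer and fall back to a hard Stop; a sketch of what that cleanup could look like, assuming a "time" import:

	cleanup := func() error {
		done := make(chan struct{})

		go func() {
			grpcServer.GracefulStop() // blocks until all in-flight RPCs complete
			close(done)
		}()

		select {
		case <-done:
			// drained cleanly
		case <-time.After(30 * time.Second):
			grpcServer.Stop() // hard stop after the drain deadline
		}

		return nil
	}
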
@@ -89,25 +89,30 @@ func New(fs ...HeartbeaterOpt) (*HeartbeaterImpl, error) {
    }, nil
}

func (t *HeartbeaterImpl) Start(ctx context.Context) error {
func (t *HeartbeaterImpl) Start() (func() error, error) {
    t.l.Debug().Msg("starting heartbeater")

    _, err := t.s.NewJob(
        gocron.DurationJob(time.Second*5),
        gocron.NewTask(
            t.removeStaleTickers(ctx),
            t.removeStaleTickers(),
        ),
    )

    if err != nil {
        return fmt.Errorf("could not schedule ticker removal: %w", err)
        return nil, fmt.Errorf("could not schedule ticker removal: %w", err)
    }

    t.s.Start()

    for range ctx.Done() {
    cleanup := func() error {
        t.l.Debug().Msg("stopping heartbeater")
        if err := t.s.Shutdown(); err != nil {
            return fmt.Errorf("could not shutdown scheduler: %w", err)
        }
        t.l.Debug().Msg("heartbeater has shutdown")
        return nil
    }

    return nil
    return cleanup, nil
}

@@ -11,7 +11,7 @@ import (
    "github.com/hatchet-dev/hatchet/internal/taskqueue"
)

func (t *HeartbeaterImpl) removeStaleTickers(ctx context.Context) func() {
func (t *HeartbeaterImpl) removeStaleTickers() func() {
    return func() {
        t.l.Debug().Msg("removing old tickers")

@@ -40,7 +40,7 @@ func (t *HeartbeaterImpl) removeStaleTickers(ctx context.Context) func() {

        // send a task to the job processing queue that the ticker is removed
        err = t.tq.AddTask(
            ctx,
            context.Background(),
            taskqueue.JOB_PROCESSING_QUEUE,
            tickerRemoved(ticker.ID),
        )

@@ -35,12 +35,8 @@ func (t *TickerImpl) handleScheduleJobRunTimeout(ctx context.Context, task *task
    }

    // schedule the timeout
    childCtx, cancel := context.WithDeadline(context.Background(), timeoutAt)

    go func() {
        <-childCtx.Done()
        t.runJobRunTimeout(metadata.TenantId, payload.JobRunId)
    }()
    // TODO: ??? make sure this doesn't have any side effects
    childCtx, cancel := context.WithDeadline(ctx, timeoutAt)

    // store the schedule in the step run map
    t.jobRuns.Store(payload.JobRunId, &timeoutCtx{
@@ -48,6 +44,16 @@ func (t *TickerImpl) handleScheduleJobRunTimeout(ctx context.Context, task *task
        cancel: cancel,
    })

    go func() {
        <-childCtx.Done()
        t.runJobRunTimeout(metadata.TenantId, payload.JobRunId)
        t.jobRuns.Range(func(key, value interface{}) bool {
            v, _ := value.(*timeoutCtx)
            v.cancel()
            return true
        })
    }()

    return nil
}

@@ -115,7 +115,9 @@ func New(fs ...TickerOpt) (*TickerImpl, error) {
    }, nil
}

func (t *TickerImpl) Start(ctx context.Context) error {
func (t *TickerImpl) Start() (func() error, error) {
    ctx, cancel := context.WithCancel(context.Background())

    t.l.Debug().Msgf("starting ticker %s", t.tickerId)

    // register the ticker
@@ -124,14 +126,16 @@ func (t *TickerImpl) Start(ctx context.Context) error {
    })

    if err != nil {
        return err
        cancel()
        return nil, err
    }

    // subscribe to a task queue with the dispatcher id
    taskChan, err := t.tq.Subscribe(ctx, taskqueue.QueueTypeFromTickerID(ticker.ID))
    cleanupQueue, taskChan, err := t.tq.Subscribe(taskqueue.QueueTypeFromTickerID(ticker.ID))

    if err != nil {
        return err
        cancel()
        return nil, err
    }

    _, err = t.s.NewJob(
@@ -142,7 +146,8 @@ func (t *TickerImpl) Start(ctx context.Context) error {
    )

    if err != nil {
        return fmt.Errorf("could not schedule step run requeue: %w", err)
        cancel()
        return nil, fmt.Errorf("could not schedule step run requeue: %w", err)
    }

    _, err = t.s.NewJob(
@@ -153,7 +158,8 @@ func (t *TickerImpl) Start(ctx context.Context) error {
    )

    if err != nil {
        return fmt.Errorf("could not schedule get group key run requeue: %w", err)
        cancel()
        return nil, fmt.Errorf("could not schedule get group key run requeue: %w", err)
    }

    _, err = t.s.NewJob(
@@ -165,35 +171,13 @@ func (t *TickerImpl) Start(ctx context.Context) error {

    t.s.Start()

    for {
        select {
        case <-ctx.Done():
            t.l.Debug().Msg("removing ticker")
    wg := sync.WaitGroup{}

            // delete the ticker
            err = t.repo.Ticker().Delete(t.tickerId)

            if err != nil {
                t.l.Err(err).Msg("could not delete ticker")
                return err
            }

            // add the task after the ticker is deleted
            err := t.tq.AddTask(
                ctx,
                taskqueue.JOB_PROCESSING_QUEUE,
                tickerRemoved(t.tickerId),
            )

            if err != nil {
                t.l.Err(err).Msg("could not add ticker removed task")
                return err
            }

            // return err
            return nil
        case task := <-taskChan:
    go func() {
        for task := range taskChan {
            wg.Add(1)
            go func(task *taskqueue.Task) {
                defer wg.Done()
                err = t.handleTask(ctx, task)

                if err != nil {
@@ -201,7 +185,47 @@ func (t *TickerImpl) Start(ctx context.Context) error {
                }
            }(task)
        }
    }()

    cleanup := func() error {
        t.l.Debug().Msg("removing ticker")

        cancel()

        if err := cleanupQueue(); err != nil {
            return fmt.Errorf("could not cleanup queue: %w", err)
        }

        wg.Wait()

        // delete the ticker
        err = t.repo.Ticker().Delete(t.tickerId)

        if err != nil {
            t.l.Err(err).Msg("could not delete ticker")
            return err
        }

        // add the task after the ticker is deleted
        err = t.tq.AddTask(
            ctx,
            taskqueue.JOB_PROCESSING_QUEUE,
            tickerRemoved(t.tickerId),
        )

        if err != nil {
            t.l.Err(err).Msg("could not add ticker removed task")
            return err
        }

        if err := t.s.Shutdown(); err != nil {
            return fmt.Errorf("could not shutdown scheduler: %w", err)
        }

        return nil
    }

    return cleanup, nil
}

func (t *TickerImpl) handleTask(ctx context.Context, task *taskqueue.Task) error {

@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
lru "github.com/hashicorp/golang-lru/v2"
|
||||
@@ -70,7 +71,9 @@ func WithURL(url string) TaskQueueImplOpt {
|
||||
}
|
||||
|
||||
// New creates a new TaskQueueImpl.
|
||||
func New(ctx context.Context, fs ...TaskQueueImplOpt) *TaskQueueImpl {
|
||||
func New(fs ...TaskQueueImplOpt) (func() error, *TaskQueueImpl) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
opts := defaultTaskQueueImplOpts()
|
||||
|
||||
for _, f := range fs {
|
||||
@@ -99,30 +102,40 @@ func New(ctx context.Context, fs ...TaskQueueImplOpt) *TaskQueueImpl {
|
||||
sub := <-<-sessions
|
||||
if _, err := t.initQueue(sub, taskqueue.EVENT_PROCESSING_QUEUE); err != nil {
|
||||
t.l.Debug().Msgf("error initializing queue: %v", err)
|
||||
return nil
|
||||
cancel()
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if _, err := t.initQueue(sub, taskqueue.JOB_PROCESSING_QUEUE); err != nil {
|
||||
t.l.Debug().Msgf("error initializing queue: %v", err)
|
||||
return nil
|
||||
cancel()
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if _, err := t.initQueue(sub, taskqueue.WORKFLOW_PROCESSING_QUEUE); err != nil {
|
||||
t.l.Debug().Msgf("error initializing queue: %v", err)
|
||||
return nil
|
||||
cancel()
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if _, err := t.initQueue(sub, taskqueue.SCHEDULING_QUEUE); err != nil {
|
||||
t.l.Debug().Msgf("error initializing queue: %v", err)
|
||||
return nil
|
||||
cancel()
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// create publisher go func
|
||||
go func() {
|
||||
t.publish()
|
||||
}()
|
||||
cleanup1 := t.startPublishing()
|
||||
|
||||
return t
|
||||
cleanup := func() error {
|
||||
cancel()
|
||||
if err := cleanup1(); err != nil {
|
||||
return fmt.Errorf("error cleaning up rabbitmq publisher: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return cleanup, t
|
||||
}
|
||||
|
||||
// AddTask adds a task to the queue.
|
||||
@@ -136,12 +149,12 @@ func (t *TaskQueueImpl) AddTask(ctx context.Context, q taskqueue.Queue, task *ta
|
||||
}
|
||||
|
||||
// Subscribe subscribes to the task queue.
|
||||
func (t *TaskQueueImpl) Subscribe(ctx context.Context, q taskqueue.Queue) (<-chan *taskqueue.Task, error) {
|
||||
func (t *TaskQueueImpl) Subscribe(q taskqueue.Queue) (func() error, <-chan *taskqueue.Task, error) {
|
||||
t.l.Debug().Msgf("subscribing to queue: %s", q.Name())
|
||||
|
||||
tasks := make(chan *taskqueue.Task)
|
||||
go t.subscribe(ctx, t.identity, q, t.sessions, t.tasks, tasks)
|
||||
return tasks, nil
|
||||
cleanup := t.subscribe(t.identity, q, t.sessions, t.tasks, tasks)
|
||||
return cleanup, tasks, nil
|
||||
}
|
||||
|
||||
func (t *TaskQueueImpl) RegisterTenant(ctx context.Context, tenantId string) error {
|
||||
@@ -204,107 +217,148 @@ func (t *TaskQueueImpl) initQueue(sub session, q taskqueue.Queue) (string, error
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func (t *TaskQueueImpl) publish() {
|
||||
for session := range t.sessions {
|
||||
pub := <-session
|
||||
func (t *TaskQueueImpl) startPublishing() func() error {
|
||||
ctx, cancel := context.WithCancel(t.ctx)
|
||||
|
||||
for task := range t.tasks {
|
||||
go func(task *taskWithQueue) {
|
||||
body, err := json.Marshal(task)
|
||||
cleanup := func() error {
|
||||
cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.l.Error().Msgf("error marshaling task queue: %v", err)
|
||||
go func() {
|
||||
for session := range t.sessions {
|
||||
pub := <-session
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.l.Debug().Msgf("publishing task %s to queue %s", task.ID, task.q.Name())
|
||||
|
||||
err = pub.PublishWithContext(ctx, "", task.q.Name(), false, false, amqp.Publishing{
|
||||
Body: body,
|
||||
})
|
||||
|
||||
// TODO: retry failed delivery on the next session
|
||||
if err != nil {
|
||||
t.l.Error().Msgf("error publishing task: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// if this is a tenant task, publish to the tenant exchange
|
||||
if task.TenantID() != "" {
|
||||
// determine if the tenant exchange exists
|
||||
if _, ok := t.tenantIdCache.Get(task.TenantID()); !ok {
|
||||
// register the tenant exchange
|
||||
err = t.RegisterTenant(ctx, task.TenantID())
|
||||
case task := <-t.tasks:
|
||||
go func(task *taskWithQueue) {
|
||||
body, err := json.Marshal(task)
|
||||
|
||||
if err != nil {
|
||||
t.l.Error().Msgf("error registering tenant exchange: %v", err)
|
||||
t.l.Error().Msgf("error marshaling task queue: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
t.l.Debug().Msgf("publishing tenant task %s to exchange %s", task.ID, task.TenantID())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = pub.PublishWithContext(ctx, task.TenantID(), "", false, false, amqp.Publishing{
|
||||
Body: body,
|
||||
})
|
||||
t.l.Debug().Msgf("publishing task %s to queue %s", task.ID, task.q.Name())
|
||||
|
||||
if err != nil {
|
||||
t.l.Error().Msgf("error publishing tenant task: %v", err)
|
||||
return
|
||||
}
|
||||
err = pub.PublishWithContext(ctx, "", task.q.Name(), false, false, amqp.Publishing{
|
||||
Body: body,
|
||||
})
|
||||
|
||||
// TODO: retry failed delivery on the next session
|
||||
if err != nil {
|
||||
t.l.Error().Msgf("error publishing task: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// if this is a tenant task, publish to the tenant exchange
|
||||
if task.TenantID() != "" {
|
||||
// determine if the tenant exchange exists
|
||||
if _, ok := t.tenantIdCache.Get(task.TenantID()); !ok {
|
||||
// register the tenant exchange
|
||||
err = t.RegisterTenant(ctx, task.TenantID())
|
||||
|
||||
if err != nil {
|
||||
t.l.Error().Msgf("error registering tenant exchange: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
t.l.Debug().Msgf("publishing tenant task %s to exchange %s", task.ID, task.TenantID())
|
||||
|
||||
err = pub.PublishWithContext(ctx, task.TenantID(), "", false, false, amqp.Publishing{
|
||||
Body: body,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.l.Error().Msgf("error publishing tenant task: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
t.l.Debug().Msgf("published task %s to queue %s", task.ID, task.q.Name())
|
||||
}(task)
|
||||
}
|
||||
|
||||
t.l.Debug().Msgf("published task %s to queue %s", task.ID, task.q.Name())
|
||||
}(task)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return cleanup
|
||||
}
func (t *TaskQueueImpl) subscribe(ctx context.Context, subId string, q taskqueue.Queue, sessions chan chan session, messages chan *taskWithQueue, tasks chan<- *taskqueue.Task) {
func (t *TaskQueueImpl) subscribe(subId string, q taskqueue.Queue, sessions chan chan session, messages chan *taskWithQueue, tasks chan<- *taskqueue.Task) func() error {
	ctx, cancel := context.WithCancel(context.Background())

	sessionCount := 0

	wg := sync.WaitGroup{}

	go func() {
		for session := range sessions {
			sessionCount++
			sub := <-session

			// we initialize the queue here because exclusive queues are bound to the session/connection. however, it's not clear
			// if the exclusive queue will be available to the next session.
			queueName, err := t.initQueue(sub, q)

			if err != nil {
				return
			}

			deliveries, err := sub.Consume(queueName, subId, false, q.Exclusive(), false, false, nil)

			if err != nil {
				t.l.Error().Msgf("cannot consume from: %s, %v", queueName, err)
				return
			}

			for {
				select {
				case msg := <-deliveries:
					wg.Add(1)
					go func(msg amqp.Delivery) {
						defer wg.Done()
						task := &taskWithQueue{}

						if err := json.Unmarshal(msg.Body, task); err != nil {
							t.l.Error().Msgf("error unmarshaling message: %v", err)
							return
						}

						t.l.Debug().Msgf("(session: %d) got task: %v", sessionCount, task.ID)

						tasks <- task.Task

						if err := sub.Ack(msg.DeliveryTag, false); err != nil {
							t.l.Error().Msgf("error acknowledging message: %v", err)
							return
						}
					}(msg)
				case <-ctx.Done():
					return
				}
			}
		}
	}()

	cleanup := func() error {
		cancel()

		t.l.Debug().Msgf("shutting down subscriber: %s", subId)
		wg.Wait()
		close(tasks)
		t.l.Debug().Msgf("successfully shut down subscriber: %s", subId)
		return nil
	}

	return cleanup
}
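The ordering inside cleanup is deliberate: cancel() stops intake, wg.Wait() drains handlers that are still forwarding tasks, and only then is the tasks channel closed, so no goroutine can send on a closed channel. A stripped-down, runnable sketch of that drain-before-close discipline, with illustrative names:

package main

import (
	"context"
	"fmt"
	"sync"
)

// startSubscriber forwards messages from in to the returned channel; its
// cleanup cancels intake, waits for in-flight handlers, then closes out.
func startSubscriber(in <-chan string) (func() error, <-chan string) {
	ctx, cancel := context.WithCancel(context.Background())
	out := make(chan string)
	wg := sync.WaitGroup{}

	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case msg := <-in:
				wg.Add(1)
				go func(msg string) {
					defer wg.Done()
					out <- msg // forward to the consumer
				}(msg)
			}
		}
	}()

	cleanup := func() error {
		cancel()   // stop accepting new deliveries
		wg.Wait()  // let in-flight handlers finish sending
		close(out) // now no writer can race the close
		return nil
	}

	return cleanup, out
}

func main() {
	in := make(chan string, 1)
	cleanup, out := startSubscriber(in)

	in <- "hello"
	fmt.Println("received:", <-out)

	if err := cleanup(); err != nil {
		panic(err)
	}
}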
// redial continually connects to the URL, exiting the program when no longer possible
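The redial loop's body falls outside this hunk. As a hedged sketch of what such a loop typically looks like with github.com/rabbitmq/amqp091-go — the one-second retry delay and logging are assumptions, and the real implementation also rebuilds sessions on top of the connection:

package main

import (
	"context"
	"log"
	"time"

	amqp "github.com/rabbitmq/amqp091-go"
)

// redial keeps a connection alive: dial, wait for the connection to drop,
// then dial again, until the context is canceled.
func redial(ctx context.Context, url string) {
	for {
		conn, err := amqp.Dial(url)

		if err != nil {
			log.Printf("cannot dial %s: %v, retrying", url, err)

			select {
			case <-ctx.Done():
				return
			case <-time.After(time.Second): // assumed backoff
				continue
			}
		}

		select {
		case <-ctx.Done():
			_ = conn.Close()
			return
		case closed := <-conn.NotifyClose(make(chan *amqp.Error, 1)):
			log.Printf("connection closed: %v, redialing", closed)
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	go redial(ctx, "amqp://user:password@localhost:5672/")
	time.Sleep(5 * time.Second)
}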
@@ -21,9 +21,10 @@ func TestTaskQueueIntegration(t *testing.T) {
|
||||
url := "amqp://user:password@localhost:5672/"
|
||||
|
||||
// Initialize the task queue implementation
|
||||
tq := rabbitmq.New(ctx,
|
||||
cleanup, tq := rabbitmq.New(
|
||||
rabbitmq.WithURL(url),
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
require.NotNil(t, tq, "task queue implementation should not be nil")
|
||||
|
||||
@@ -41,7 +42,7 @@ func TestTaskQueueIntegration(t *testing.T) {
assert.NoError(t, err, "adding task to static queue should not error")

// Test subscription to the static queue
taskChan, err := tq.Subscribe(ctx, staticQueue)
cleanupQueue, taskChan, err := tq.Subscribe(staticQueue)
require.NoError(t, err, "subscribing to static queue should not error")

select {
@@ -64,7 +65,7 @@ func TestTaskQueueIntegration(t *testing.T) {
}

// Test subscription to the tenant-specific queue
tenantTaskChan, err := tq.Subscribe(ctx, tenantQueue)
cleanupTenantQueue, tenantTaskChan, err := tq.Subscribe(tenantQueue)
require.NoError(t, err, "subscribing to tenant-specific queue should not error")

// send task to tenant-specific queue after 1 second to give time for subscriber
@@ -77,9 +78,14 @@ func TestTaskQueueIntegration(t *testing.T) {
select {
case receivedTask := <-tenantTaskChan:
assert.Equal(t, task.ID, receivedTask.ID, "received tenant task ID should match sent task ID")
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for task from tenant-specific queue")
}

if err := cleanupQueue(); err != nil {
t.Fatalf("error cleaning up queue: %v", err)
}
if err := cleanupTenantQueue(); err != nil {
t.Fatalf("error cleaning up queue: %v", err)
}
}
@@ -139,7 +139,7 @@ type TaskQueue interface {
AddTask(ctx context.Context, queue Queue, task *Task) error

// Subscribe subscribes to the task queue.
Subscribe(ctx context.Context, queueType Queue) (<-chan *Task, error)
Subscribe(queueType Queue) (func() error, <-chan *Task, error)

// RegisterTenant registers a new pub/sub mechanism for a tenant. This should be called when a
// new tenant is created. If this is not called, implementors should ensure that there's a check
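Callers of the revised interface receive the cleanup function first and are responsible for invoking it once they are done consuming. A hypothetical helper against the taskqueue package declared above — consumeOne is not part of the codebase:

// consumeOne subscribes, takes a single task off the queue, then tears
// the subscription down (illustrative only).
func consumeOne(tq taskqueue.TaskQueue, q taskqueue.Queue) (*taskqueue.Task, error) {
	cleanup, taskChan, err := tq.Subscribe(q)

	if err != nil {
		return nil, err
	}

	task := <-taskChan

	if err := cleanup(); err != nil {
		return nil, err
	}

	return task, nil
}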
45
internal/testutils/setup.go
Normal file
@@ -0,0 +1,45 @@
package testutils

import (
	"context"
	"log"
	"os"
	"path"
	"path/filepath"
	"runtime"
	"testing"

	"github.com/hatchet-dev/hatchet/cmd/hatchet-engine/engine"
	"github.com/hatchet-dev/hatchet/internal/config/loader"
)

func SetupEngine(ctx context.Context, t *testing.T) {
	t.Helper()

	_, b, _, _ := runtime.Caller(0)
	testPath := filepath.Dir(b)
	dir := path.Join(testPath, "../..")

	log.Printf("dir: %s", dir)

	_ = os.Setenv("DATABASE_URL", "postgresql://hatchet:hatchet@127.0.0.1:5431/hatchet")
	_ = os.Setenv("SERVER_TLS_CERT_FILE", path.Join(dir, "hack/dev/certs/cluster.pem"))
	_ = os.Setenv("SERVER_TLS_KEY_FILE", path.Join(dir, "hack/dev/certs/cluster.key"))
	_ = os.Setenv("SERVER_TLS_ROOT_CA_FILE", path.Join(dir, "hack/dev/certs/ca.cert"))
	_ = os.Setenv("SERVER_PORT", "8080")
	_ = os.Setenv("SERVER_URL", "https://app.dev.hatchet-tools.com")
	_ = os.Setenv("SERVER_AUTH_COOKIE_SECRETS", "something something")
	_ = os.Setenv("SERVER_AUTH_COOKIE_DOMAIN", "app.dev.hatchet-tools.com")
	_ = os.Setenv("SERVER_AUTH_COOKIE_INSECURE", "false")
	_ = os.Setenv("SERVER_AUTH_SET_EMAIL_VERIFIED", "true")
	_ = os.Setenv("SERVER_LOGGER_LEVEL", "debug")
	_ = os.Setenv("SERVER_LOGGER_FORMAT", "console")
	_ = os.Setenv("DATABASE_LOGGER_LEVEL", "debug")
	_ = os.Setenv("DATABASE_LOGGER_FORMAT", "console")

	cf := loader.NewConfigLoader(path.Join(dir, "./generated/"))

	if err := engine.Run(ctx, cf); err != nil {
		t.Fatalf("engine failure: %s", err.Error())
	}
}
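Because engine.Run blocks until its context is canceled, callers drive teardown by canceling. A hypothetical caller — the package, test name, and ten-second window below are assumptions, mirroring how the CI jobs above give the engine a bounded lifetime:

package engine_test

import (
	"context"
	"testing"
	"time"

	"github.com/hatchet-dev/hatchet/internal/testutils"
)

// TestEngineStartup runs the engine for a fixed window, then cancels the
// context so engine.Run returns and SetupEngine completes (illustrative).
func TestEngineStartup(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	go func() {
		time.Sleep(10 * time.Second) // assumed test window
		cancel()
	}()

	testutils.SetupEngine(ctx, t)
}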
@@ -24,7 +24,7 @@ model User {
// the user's oauth providers
oauthProviders UserOAuth[]

// The hashed user's password. This is placed in a separate table so that it isn't returned by default.
password UserPassword?

// the user's name
@@ -196,7 +196,7 @@ model APIToken {
tenantId String? @db.Uuid
}

// Event represents an event in the database.
model Event {
// base fields
id String @id @unique @default(uuid()) @db.Uuid
@@ -319,7 +319,7 @@ model WorkflowVersion {
// concurrency limits for the workflow
concurrency WorkflowConcurrency?

// the declared jobs
jobs Job[]

// all runs for the workflow
@@ -819,7 +819,7 @@ model StepRun {
// the run output
output Json?

// inputSchema is a JSON object which declares a JSON schema for the input data
inputSchema Json?

// when the step should be requeued