refactor: separate api and engine repositories, change ticker logic (#281)

* refactor: separate api and engine repositories, change ticker logic

* fix: nil error blocks

* fix: run migration on load test

* fix: generate db package in load test

* fix: test.yml

* fix: add pnpm to load test

* fix: don't lock CTEs with columns that don't get updated

* fix: update heartbeat for worker every 4 seconds, not 5

* chore: remove dead code

* chore: update python sdk

* chore: add back telemetry attributes
This commit is contained in:
abelanger5
2024-03-21 11:10:34 -07:00
committed by GitHub
parent f82cfb4eef
commit 092f54c64f
151 changed files with 4365 additions and 5451 deletions
+9 -2
View File
@@ -113,7 +113,7 @@ jobs:
- name: Generate
run: |
go run github.com/steebchen/prisma-client-go db push
task generate-all
task generate-certs
task generate-local-encryption-keys
@@ -235,6 +235,12 @@ jobs:
with:
go-version: "1.21"
- name: Setup pnpm
uses: pnpm/action-setup@v2
with:
version: 8
run_install: false
- name: Compose
run: docker compose up -d
@@ -243,7 +249,8 @@ jobs:
- name: Generate
run: |
go run github.com/steebchen/prisma-client-go db push
go run github.com/steebchen/prisma-client-go migrate deploy
task generate-all
task generate-certs
task generate-local-encryption-keys
-15
View File
@@ -7,8 +7,6 @@ import "google/protobuf/timestamp.proto";
service EventsService {
rpc Push(PushEventRequest) returns (Event) {}
rpc List(ListEventRequest) returns (ListEventResponse) {}
rpc ReplaySingleEvent(ReplayEventRequest) returns (Event) {}
rpc PutLog(PutLogRequest) returns (PutLogResponse) {}
@@ -61,19 +59,6 @@ message PushEventRequest {
google.protobuf.Timestamp eventTimestamp = 3;
}
message ListEventRequest {
// (optional) the number of events to skip
int32 offset = 1;
// (optional) the key for the event
string key = 2;
}
message ListEventResponse {
// the events
repeated Event events = 1;
}
message ReplayEventRequest {
// the event id to replay
string eventId = 1;
+2 -76
View File
@@ -3,17 +3,12 @@ syntax = "proto3";
option go_package = "github.com/hatchet-dev/hatchet/internal/services/admin/contracts";
import "google/protobuf/timestamp.proto";
import "google/protobuf/wrappers.proto"; // For optional fields
// WorkflowService represents a set of RPCs for managing workflows.
service WorkflowService {
rpc ListWorkflows(ListWorkflowsRequest) returns (ListWorkflowsResponse);
rpc PutWorkflow(PutWorkflowRequest) returns (WorkflowVersion);
rpc ScheduleWorkflow(ScheduleWorkflowRequest) returns (WorkflowVersion);
rpc TriggerWorkflow(TriggerWorkflowRequest) returns (TriggerWorkflowResponse);
rpc GetWorkflowByName(GetWorkflowByNameRequest) returns (Workflow);
rpc ListWorkflowsForEvent(ListWorkflowsForEventRequest) returns (ListWorkflowsResponse);
rpc DeleteWorkflow(DeleteWorkflowRequest) returns (Workflow);
}
message PutWorkflowRequest {
@@ -69,34 +64,13 @@ message CreateWorkflowStepOpts {
message ListWorkflowsRequest {}
message ScheduleWorkflowRequest {
string workflow_id = 1;
string name = 1;
repeated google.protobuf.Timestamp schedules = 2;
// (optional) the input data for the workflow
string input = 3;
}
// ListWorkflowsResponse is the response for ListWorkflows.
message ListWorkflowsResponse {
repeated Workflow workflows = 1;
}
// ListWorkflowsForEventRequest is the request for ListWorkflowsForEvent.
message ListWorkflowsForEventRequest {
string event_key = 1;
}
// Workflow represents the Workflow model.
message Workflow {
string id = 1;
google.protobuf.Timestamp created_at = 2;
google.protobuf.Timestamp updated_at = 3;
string tenant_id = 5;
string name = 6;
google.protobuf.StringValue description = 7; // Optional
repeated WorkflowVersion versions = 8;
}
// WorkflowVersion represents the WorkflowVersion model.
message WorkflowVersion {
string id = 1;
@@ -105,19 +79,6 @@ message WorkflowVersion {
string version = 5;
int32 order = 6;
string workflow_id = 7;
WorkflowTriggers triggers = 8;
repeated Job jobs = 9;
}
// WorkflowTriggers represents the WorkflowTriggers model.
message WorkflowTriggers {
string id = 1;
google.protobuf.Timestamp created_at = 2;
google.protobuf.Timestamp updated_at = 3;
string workflow_version_id = 5;
string tenant_id = 6;
repeated WorkflowTriggerEventRef events = 7;
repeated WorkflowTriggerCronRef crons = 8;
}
// WorkflowTriggerEventRef represents the WorkflowTriggerEventRef model.
@@ -132,41 +93,6 @@ message WorkflowTriggerCronRef {
string cron = 2;
}
// Job represents the Job model.
message Job {
string id = 1;
google.protobuf.Timestamp created_at = 2;
google.protobuf.Timestamp updated_at = 3;
string tenant_id = 5;
string workflow_version_id = 6;
string name = 7;
google.protobuf.StringValue description = 8; // Optional
repeated Step steps = 9;
google.protobuf.StringValue timeout = 10; // Optional
}
// Step represents the Step model.
message Step {
string id = 1;
google.protobuf.Timestamp created_at = 2;
google.protobuf.Timestamp updated_at = 3;
google.protobuf.StringValue readable_id = 5; // Optional
string tenant_id = 6;
string job_id = 7;
string action = 8;
google.protobuf.StringValue timeout = 9; // Optional
repeated string parents = 10;
repeated string children = 11;
}
message DeleteWorkflowRequest {
string workflow_id = 1;
}
message GetWorkflowByNameRequest {
string name = 1;
}
message TriggerWorkflowRequest {
string name = 1;
+1 -1
View File
@@ -135,7 +135,7 @@ func (a *AuthN) handleCookieAuth(c echo.Context) error {
return forbidden
}
user, err := a.config.Repository.User().GetUserByID(userID)
user, err := a.config.APIRepository.User().GetUserByID(userID)
if err != nil {
a.l.Debug().Err(err).Msg("error getting user by id")
+1 -1
View File
@@ -74,7 +74,7 @@ func (a *AuthZ) handleCookieAuth(c echo.Context, r *middleware.RouteInfo) error
}
// check if the user is a member of the tenant
tenantMember, err := a.config.Repository.Tenant().GetTenantMemberByUserID(tenant.ID, user.ID)
tenantMember, err := a.config.APIRepository.Tenant().GetTenantMemberByUserID(tenant.ID, user.ID)
if err != nil {
a.l.Debug().Err(err).Msgf("error getting tenant member")
+1 -1
View File
@@ -11,7 +11,7 @@ import (
func (a *APITokenService) ApiTokenList(ctx echo.Context, request gen.ApiTokenListRequestObject) (gen.ApiTokenListResponseObject, error) {
tenant := ctx.Get("tenant").(*db.TenantModel)
tokens, err := a.config.Repository.APIToken().ListAPITokensByTenant(tenant.ID)
tokens, err := a.config.APIRepository.APIToken().ListAPITokensByTenant(tenant.ID)
if err != nil {
return nil, err
+1 -1
View File
@@ -10,7 +10,7 @@ import (
func (a *APITokenService) ApiTokenUpdateRevoke(ctx echo.Context, request gen.ApiTokenUpdateRevokeRequestObject) (gen.ApiTokenUpdateRevokeResponseObject, error) {
apiToken := ctx.Get("api-token").(*db.APITokenModel)
err := a.config.Repository.APIToken().RevokeAPIToken(apiToken.ID)
err := a.config.APIRepository.APIToken().RevokeAPIToken(apiToken.ID)
if err != nil {
return nil, err
+1 -1
View File
@@ -63,7 +63,7 @@ func (t *EventService) EventList(ctx echo.Context, request gen.EventListRequestO
listOpts.WorkflowRunStatus = statuses
}
listRes, err := t.config.Repository.Event().ListEvents(tenant.ID, listOpts)
listRes, err := t.config.APIRepository.Event().ListEvents(tenant.ID, listOpts)
if err != nil {
return nil, err
+1 -1
View File
@@ -10,7 +10,7 @@ import (
func (t *EventService) EventKeyList(ctx echo.Context, request gen.EventKeyListRequestObject) (gen.EventKeyListResponseObject, error) {
tenant := ctx.Get("tenant").(*db.TenantModel)
eventKeys, err := t.config.Repository.Event().ListEventKeys(tenant.ID)
eventKeys, err := t.config.APIRepository.Event().ListEventKeys(tenant.ID)
if err != nil {
return nil, err
+11 -4
View File
@@ -7,6 +7,7 @@ import (
"github.com/hatchet-dev/hatchet/api/v1/server/oas/gen"
"github.com/hatchet-dev/hatchet/api/v1/server/oas/transformers"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/db"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/sqlchelpers"
)
func (t *EventService) EventUpdateReplay(ctx echo.Context, request gen.EventUpdateReplayRequestObject) (gen.EventUpdateReplayResponseObject, error) {
@@ -18,32 +19,38 @@ func (t *EventService) EventUpdateReplay(ctx echo.Context, request gen.EventUpda
eventIds[i] = request.Body.EventIds[i].String()
}
events, err := t.config.Repository.Event().ListEventsById(tenant.ID, eventIds)
events, err := t.config.EngineRepository.Event().ListEventsByIds(tenant.ID, eventIds)
if err != nil {
return nil, err
}
newEvents := make([]db.EventModel, len(events))
newEventIds := make([]string, len(events))
var allErrs error
for i := range events {
event := events[i]
newEvent, err := t.config.Ingestor.IngestReplayedEvent(ctx.Request().Context(), tenant.ID, &event)
newEvent, err := t.config.Ingestor.IngestReplayedEvent(ctx.Request().Context(), tenant.ID, event)
if err != nil {
allErrs = multierror.Append(allErrs, err)
}
newEvents[i] = *newEvent
newEventIds[i] = sqlchelpers.UUIDToStr(newEvent.ID)
}
if allErrs != nil {
return nil, allErrs
}
newEvents, err := t.config.APIRepository.Event().ListEventsById(tenant.ID, newEventIds)
if err != nil {
return nil, err
}
rows := make([]gen.Event, len(newEvents))
for i := range newEvents {
@@ -15,7 +15,7 @@ import (
func (g *GithubAppService) GithubUpdateTenantWebhook(ctx echo.Context, req gen.GithubUpdateTenantWebhookRequestObject) (gen.GithubUpdateTenantWebhookResponseObject, error) {
webhookId := req.Webhook.String()
webhook, err := g.config.Repository.Github().ReadGithubWebhookById(webhookId)
webhook, err := g.config.APIRepository.Github().ReadGithubWebhookById(webhookId)
if err != nil {
return nil, err
@@ -53,13 +53,13 @@ func (g *GithubAppService) GithubUpdateTenantWebhook(ctx echo.Context, req gen.G
func (g *GithubAppService) processPullRequestEvent(tenantId string, event *githubsdk.PullRequestEvent, r *http.Request) error {
pr := github.ToVCSRepositoryPullRequest(*event.GetRepo().GetOwner().Login, event.GetRepo().GetName(), event.GetPullRequest())
dbPR, err := g.config.Repository.Github().GetPullRequest(tenantId, pr.GetRepoOwner(), pr.GetRepoName(), int(pr.GetPRNumber()))
dbPR, err := g.config.APIRepository.Github().GetPullRequest(tenantId, pr.GetRepoOwner(), pr.GetRepoName(), int(pr.GetPRNumber()))
if err != nil {
return err
}
_, err = g.config.Repository.Github().UpdatePullRequest(tenantId, dbPR.ID, &repository.UpdatePullRequestOpts{
_, err = g.config.APIRepository.Github().UpdatePullRequest(tenantId, dbPR.ID, &repository.UpdatePullRequestOpts{
HeadBranch: repository.StringPtr(pr.GetHeadBranch()),
BaseBranch: repository.StringPtr(pr.GetBaseBranch()),
Title: repository.StringPtr(pr.GetTitle()),
+1 -1
View File
@@ -47,7 +47,7 @@ func GetGithubAppClientFromRequest(ctx echo.Context, config *server.ServerConfig
user := ctx.Get("user").(*db.UserModel)
gai := ctx.Get("gh-installation").(*db.GithubAppInstallationModel)
if canAccess, err := config.Repository.Github().CanUserAccessInstallation(gai.ID, user.ID); err != nil || !canAccess {
if canAccess, err := config.APIRepository.Github().CanUserAccessInstallation(gai.ID, user.ID); err != nil || !canAccess {
respErr := apierrors.NewAPIErrors("User does not have access to the installation")
return nil, &respErr
}
@@ -11,7 +11,7 @@ import (
func (g *GithubAppService) GithubAppListInstallations(ctx echo.Context, req gen.GithubAppListInstallationsRequestObject) (gen.GithubAppListInstallationsResponseObject, error) {
user := ctx.Get("user").(*db.UserModel)
gais, err := g.config.Repository.Github().ListGithubAppInstallationsByUserID(user.ID)
gais, err := g.config.APIRepository.Github().ListGithubAppInstallationsByUserID(user.ID)
if err != nil {
return nil, err
@@ -67,7 +67,7 @@ func (g *GithubAppService) UserUpdateGithubOauthCallback(ctx echo.Context, _ gen
}
// upsert in database
_, err = g.config.Repository.Github().UpsertGithubAppOAuth(user.ID, &repository.CreateGithubAppOAuthOpts{
_, err = g.config.APIRepository.Github().UpsertGithubAppOAuth(user.ID, &repository.CreateGithubAppOAuthOpts{
GithubUserID: int(*githubUser.ID),
AccessToken: accessTokenEncrypted,
RefreshToken: &refreshTokenEncrypted,
+6 -6
View File
@@ -57,17 +57,17 @@ func (g *GithubAppService) GithubUpdateGlobalWebhook(ctx echo.Context, req gen.G
func (g *GithubAppService) handleInstallationEvent(senderID int64, i *githubsdk.Installation) error {
// make sure the sender exists in the database
gao, err := g.config.Repository.Github().ReadGithubAppOAuthByGithubUserID(int(senderID))
gao, err := g.config.APIRepository.Github().ReadGithubAppOAuthByGithubUserID(int(senderID))
if err != nil {
return err
}
_, err = g.config.Repository.Github().ReadGithubAppInstallationByInstallationAndAccountID(int(*i.ID), int(*i.Account.ID))
_, err = g.config.APIRepository.Github().ReadGithubAppInstallationByInstallationAndAccountID(int(*i.ID), int(*i.Account.ID))
if err != nil && errors.Is(err, db.ErrNotFound) {
// insert account/installation pair into database
_, err := g.config.Repository.Github().CreateInstallation(gao.GithubUserID, &repository.CreateInstallationOpts{
_, err := g.config.APIRepository.Github().CreateInstallation(gao.GithubUserID, &repository.CreateInstallationOpts{
InstallationID: int(*i.ID),
AccountID: int(*i.Account.ID),
AccountName: *i.Account.Login,
@@ -83,19 +83,19 @@ func (g *GithubAppService) handleInstallationEvent(senderID int64, i *githubsdk.
}
// associate the github user id with this installation in the database
_, err = g.config.Repository.Github().AddGithubUserIdToInstallation(int(*i.ID), int(*i.Account.ID), gao.GithubUserID)
_, err = g.config.APIRepository.Github().AddGithubUserIdToInstallation(int(*i.ID), int(*i.Account.ID), gao.GithubUserID)
return err
}
func (g *GithubAppService) handleDeletionEvent(i *githubsdk.Installation) error {
_, err := g.config.Repository.Github().ReadGithubAppInstallationByInstallationAndAccountID(int(*i.ID), int(*i.Account.ID))
_, err := g.config.APIRepository.Github().ReadGithubAppInstallationByInstallationAndAccountID(int(*i.ID), int(*i.Account.ID))
if err != nil {
return err
}
_, err = g.config.Repository.Github().DeleteInstallation(int(*i.ID), int(*i.Account.ID))
_, err = g.config.APIRepository.Github().DeleteInstallation(int(*i.ID), int(*i.Account.ID))
if err != nil {
return err
+2 -2
View File
@@ -34,7 +34,7 @@ func (i *IngestorsService) SnsUpdate(ctx echo.Context, req gen.SnsUpdateRequestO
tenantId := req.Tenant.String()
// verify that the tenant and the topic ARN are set in the database
snsInt, err := i.config.Repository.SNS().GetSNSIntegration(tenantId, payload.TopicArn)
snsInt, err := i.config.APIRepository.SNS().GetSNSIntegration(tenantId, payload.TopicArn)
if err != nil {
return nil, err
@@ -58,7 +58,7 @@ func (i *IngestorsService) SnsUpdate(ctx echo.Context, req gen.SnsUpdateRequestO
return nil, err
}
default:
_, err := i.config.Ingestor.IngestEvent(ctx.Request().Context(), req.Tenant.String(), req.Event, payload)
_, err := i.config.Ingestor.IngestEvent(ctx.Request().Context(), req.Tenant.String(), req.Event, body)
if err != nil {
return nil, err
@@ -25,7 +25,7 @@ func (i *IngestorsService) SnsCreate(ctx echo.Context, req gen.SnsCreateRequestO
}
// create the SNS integration
snsIntegration, err := i.config.Repository.SNS().CreateSNSIntegration(tenant.ID, opts)
snsIntegration, err := i.config.APIRepository.SNS().CreateSNSIntegration(tenant.ID, opts)
if err != nil {
return nil, err
@@ -13,7 +13,7 @@ func (i *IngestorsService) SnsDelete(ctx echo.Context, req gen.SnsDeleteRequestO
sns := ctx.Get("sns").(*db.SNSIntegrationModel)
// create the SNS integration
err := i.config.Repository.SNS().DeleteSNSIntegration(tenant.ID, sns.ID)
err := i.config.APIRepository.SNS().DeleteSNSIntegration(tenant.ID, sns.ID)
if err != nil {
return nil, err
+1 -1
View File
@@ -13,7 +13,7 @@ func (i *IngestorsService) SnsList(ctx echo.Context, req gen.SnsListRequestObjec
tenant := ctx.Get("tenant").(*db.TenantModel)
// create the SNS integration
snsIntegrations, err := i.config.Repository.SNS().ListSNSIntegrations(tenant.ID)
snsIntegrations, err := i.config.APIRepository.SNS().ListSNSIntegrations(tenant.ID)
if err != nil {
return nil, err
+1 -1
View File
@@ -57,7 +57,7 @@ func (t *LogService) LogLineList(ctx echo.Context, request gen.LogLineListReques
listOpts.Offset = &offset
}
listRes, err := t.config.Repository.Log().ListLogLines(tenant.ID, listOpts)
listRes, err := t.config.APIRepository.Log().ListLogLines(tenant.ID, listOpts)
if err != nil {
return nil, err
+1 -1
View File
@@ -13,7 +13,7 @@ func (u *MetadataService) LivenessGet(ctx echo.Context, request gen.LivenessGetR
}
func (u *MetadataService) ReadinessGet(ctx echo.Context, request gen.ReadinessGetRequestObject) (gen.ReadinessGetResponseObject, error) {
if !u.config.Repository.Health().IsHealthy() {
if !u.config.APIRepository.Health().IsHealthy() {
return nil, fmt.Errorf("repository is not healthy")
}
+4 -4
View File
@@ -26,7 +26,7 @@ func (t *StepRunService) StepRunUpdateRerun(ctx echo.Context, request gen.StepRu
sixSecAgo := time.Now().Add(-6 * time.Second)
workers, err := t.config.Repository.Worker().ListWorkers(tenant.ID, &repository.ListWorkersOpts{
workers, err := t.config.APIRepository.Worker().ListWorkers(tenant.ID, &repository.ListWorkersOpts{
Action: &action,
LastHeartbeatAfter: &sixSecAgo,
Assignable: repository.BoolPtr(true),
@@ -64,13 +64,13 @@ func (t *StepRunService) StepRunUpdateRerun(ctx echo.Context, request gen.StepRu
}
// set the job run and workflow run to running status
err = t.config.Repository.JobRun().SetJobRunStatusRunning(tenant.ID, stepRun.JobRunID)
err = t.config.APIRepository.JobRun().SetJobRunStatusRunning(tenant.ID, stepRun.JobRunID)
if err != nil {
return nil, err
}
engineStepRun, err := t.config.Repository.StepRun().GetStepRunForEngine(tenant.ID, stepRun.ID)
engineStepRun, err := t.config.EngineRepository.StepRun().GetStepRunForEngine(tenant.ID, stepRun.ID)
if err != nil {
return nil, fmt.Errorf("could not get step run for engine: %w", err)
@@ -89,7 +89,7 @@ func (t *StepRunService) StepRunUpdateRerun(ctx echo.Context, request gen.StepRu
// wait for a short period of time
for i := 0; i < 5; i++ {
newStepRun, err := t.config.Repository.StepRun().GetStepRunById(tenant.ID, stepRun.ID)
newStepRun, err := t.config.APIRepository.StepRun().GetStepRunById(tenant.ID, stepRun.ID)
if err != nil {
return nil, fmt.Errorf("could not get step run: %w", err)
+3 -3
View File
@@ -23,7 +23,7 @@ func (t *TenantService) TenantCreate(ctx echo.Context, request gen.TenantCreateR
}
// determine if a tenant with the slug already exists
existingTenant, err := t.config.Repository.Tenant().GetTenantBySlug(request.Body.Slug)
existingTenant, err := t.config.APIRepository.Tenant().GetTenantBySlug(request.Body.Slug)
if err != nil && !errors.Is(err, db.ErrNotFound) {
return nil, err
@@ -42,14 +42,14 @@ func (t *TenantService) TenantCreate(ctx echo.Context, request gen.TenantCreateR
}
// write the user to the db
tenant, err := t.config.Repository.Tenant().CreateTenant(createOpts)
tenant, err := t.config.APIRepository.Tenant().CreateTenant(createOpts)
if err != nil {
return nil, err
}
// add the user as an owner of the tenant
_, err = t.config.Repository.Tenant().CreateTenantMember(tenant.ID, &repository.CreateTenantMemberOpts{
_, err = t.config.APIRepository.Tenant().CreateTenantMember(tenant.ID, &repository.CreateTenantMemberOpts{
UserId: user.ID,
Role: "OWNER",
})
@@ -25,7 +25,7 @@ func (t *TenantService) TenantInviteCreate(ctx echo.Context, request gen.TenantI
}
// ensure that this user isn't already a member of the tenant
if _, err := t.config.Repository.Tenant().GetTenantMemberByEmail(tenant.ID, request.Body.Email); err == nil {
if _, err := t.config.APIRepository.Tenant().GetTenantMemberByEmail(tenant.ID, request.Body.Email); err == nil {
return gen.TenantInviteCreate400JSONResponse(
apierrors.NewAPIErrors("this user is already a member of this tenant"),
), nil
@@ -47,7 +47,7 @@ func (t *TenantService) TenantInviteCreate(ctx echo.Context, request gen.TenantI
}
// create the invite
invite, err := t.config.Repository.TenantInvite().CreateTenantInvite(tenant.ID, createOpts)
invite, err := t.config.APIRepository.TenantInvite().CreateTenantInvite(tenant.ID, createOpts)
if err != nil {
return nil, err
@@ -12,7 +12,7 @@ func (t *TenantService) TenantInviteDelete(ctx echo.Context, request gen.TenantI
invite := ctx.Get("tenant-invite").(*db.TenantInviteLinkModel)
// delete the invite
err := t.config.Repository.TenantInvite().DeleteTenantInvite(invite.ID)
err := t.config.APIRepository.TenantInvite().DeleteTenantInvite(invite.ID)
if err != nil {
return nil, err
@@ -12,7 +12,7 @@ import (
func (t *TenantService) TenantInviteList(ctx echo.Context, request gen.TenantInviteListRequestObject) (gen.TenantInviteListResponseObject, error) {
tenant := ctx.Get("tenant").(*db.TenantModel)
tenantInvites, err := t.config.Repository.TenantInvite().ListTenantInvitesByTenantId(tenant.ID, &repository.ListTenantInvitesOpts{
tenantInvites, err := t.config.APIRepository.TenantInvite().ListTenantInvitesByTenantId(tenant.ID, &repository.ListTenantInvitesOpts{
Expired: repository.BoolPtr(false),
Status: repository.StringPtr("PENDING"),
})
@@ -11,7 +11,7 @@ import (
func (t *TenantService) TenantMemberList(ctx echo.Context, request gen.TenantMemberListRequestObject) (gen.TenantMemberListResponseObject, error) {
tenant := ctx.Get("tenant").(*db.TenantModel)
members, err := t.config.Repository.Tenant().ListTenantMembers(tenant.ID)
members, err := t.config.APIRepository.Tenant().ListTenantMembers(tenant.ID)
if err != nil {
return nil, err
@@ -34,7 +34,7 @@ func (t *TenantService) TenantInviteUpdate(ctx echo.Context, request gen.TenantI
}
// update the invite
invite, err := t.config.Repository.TenantInvite().UpdateTenantInvite(invite.ID, updateOpts)
invite, err := t.config.APIRepository.TenantInvite().UpdateTenantInvite(invite.ID, updateOpts)
if err != nil {
return nil, err
@@ -29,7 +29,7 @@ func (u *UserService) TenantInviteAccept(ctx echo.Context, request gen.TenantInv
}
// get the invite
invite, err := u.config.Repository.TenantInvite().GetTenantInvite(inviteId)
invite, err := u.config.APIRepository.TenantInvite().GetTenantInvite(inviteId)
if err != nil {
return nil, err
@@ -51,7 +51,7 @@ func (u *UserService) TenantInviteAccept(ctx echo.Context, request gen.TenantInv
}
// ensure the user is not already a member of the tenant
_, err = u.config.Repository.Tenant().GetTenantMemberByEmail(invite.TenantID, user.Email)
_, err = u.config.APIRepository.Tenant().GetTenantMemberByEmail(invite.TenantID, user.Email)
if err != nil && !errors.Is(err, db.ErrNotFound) {
return nil, err
@@ -65,14 +65,14 @@ func (u *UserService) TenantInviteAccept(ctx echo.Context, request gen.TenantInv
}
// update the invite
invite, err = u.config.Repository.TenantInvite().UpdateTenantInvite(invite.ID, updateOpts)
invite, err = u.config.APIRepository.TenantInvite().UpdateTenantInvite(invite.ID, updateOpts)
if err != nil {
return nil, err
}
// add the user to the tenant
_, err = u.config.Repository.Tenant().CreateTenantMember(invite.TenantID, &repository.CreateTenantMemberOpts{
_, err = u.config.APIRepository.Tenant().CreateTenantMember(invite.TenantID, &repository.CreateTenantMemberOpts{
UserId: user.ID,
Role: string(invite.Role),
})
+2 -2
View File
@@ -37,7 +37,7 @@ func (u *UserService) UserCreate(ctx echo.Context, request gen.UserCreateRequest
}
// determine if the user exists before attempting to write the user
existingUser, err := u.config.Repository.User().GetUserByEmail(string(request.Body.Email))
existingUser, err := u.config.APIRepository.User().GetUserByEmail(string(request.Body.Email))
if err != nil && !errors.Is(err, db.ErrNotFound) {
return nil, err
@@ -68,7 +68,7 @@ func (u *UserService) UserCreate(ctx echo.Context, request gen.UserCreateRequest
}
// write the user to the db
user, err := u.config.Repository.User().CreateUser(createOpts)
user, err := u.config.APIRepository.User().CreateUser(createOpts)
if err != nil {
return nil, err
}
@@ -87,11 +87,11 @@ func (u *UserService) upsertGoogleUserFromToken(config *server.ServerConfig, tok
ExpiresAt: &expiresAt,
}
user, err := u.config.Repository.User().GetUserByEmail(gInfo.Email)
user, err := u.config.APIRepository.User().GetUserByEmail(gInfo.Email)
switch err {
case nil:
user, err = u.config.Repository.User().UpdateUser(user.ID, &repository.UpdateUserOpts{
user, err = u.config.APIRepository.User().UpdateUser(user.ID, &repository.UpdateUserOpts{
EmailVerified: repository.BoolPtr(gInfo.EmailVerified),
Name: repository.StringPtr(gInfo.Name),
OAuth: oauthOpts,
@@ -101,7 +101,7 @@ func (u *UserService) upsertGoogleUserFromToken(config *server.ServerConfig, tok
return nil, fmt.Errorf("failed to update user: %s", err.Error())
}
case db.ErrNotFound:
user, err = u.config.Repository.User().CreateUser(&repository.CreateUserOpts{
user, err = u.config.APIRepository.User().CreateUser(&repository.CreateUserOpts{
Email: gInfo.Email,
EmailVerified: repository.BoolPtr(gInfo.EmailVerified),
Name: repository.StringPtr(gInfo.Name),
@@ -11,7 +11,7 @@ import (
func (t *UserService) TenantMembershipsList(ctx echo.Context, request gen.TenantMembershipsListRequestObject) (gen.TenantMembershipsListResponseObject, error) {
user := ctx.Get("user").(*db.UserModel)
memberships, err := t.config.Repository.User().ListTenantMemberships(user.ID)
memberships, err := t.config.APIRepository.User().ListTenantMemberships(user.ID)
if err != nil {
return nil, err
@@ -11,7 +11,7 @@ import (
func (t *UserService) UserListTenantInvites(ctx echo.Context, request gen.UserListTenantInvitesRequestObject) (gen.UserListTenantInvitesResponseObject, error) {
user := ctx.Get("user").(*db.UserModel)
invites, err := t.config.Repository.TenantInvite().ListTenantInvitesByEmail(user.Email)
invites, err := t.config.APIRepository.TenantInvite().ListTenantInvitesByEmail(user.Email)
if err != nil {
return nil, err
@@ -29,7 +29,7 @@ func (u *UserService) TenantInviteReject(ctx echo.Context, request gen.TenantInv
}
// get the invite
invite, err := u.config.Repository.TenantInvite().GetTenantInvite(inviteId)
invite, err := u.config.APIRepository.TenantInvite().GetTenantInvite(inviteId)
if err != nil {
return nil, err
@@ -56,7 +56,7 @@ func (u *UserService) TenantInviteReject(ctx echo.Context, request gen.TenantInv
}
// update the invite
invite, err = u.config.Repository.TenantInvite().UpdateTenantInvite(invite.ID, updateOpts)
invite, err = u.config.APIRepository.TenantInvite().UpdateTenantInvite(invite.ID, updateOpts)
if err != nil {
return nil, err
+2 -2
View File
@@ -30,7 +30,7 @@ func (u *UserService) UserUpdateLogin(ctx echo.Context, request gen.UserUpdateLo
}
// determine if the user exists before attempting to write the user
existingUser, err := u.config.Repository.User().GetUserByEmail(string(request.Body.Email))
existingUser, err := u.config.APIRepository.User().GetUserByEmail(string(request.Body.Email))
if err != nil {
if errors.Is(err, db.ErrNotFound) {
return gen.UserUpdateLogin400JSONResponse(apierrors.NewAPIErrors("user not found")), nil
@@ -39,7 +39,7 @@ func (u *UserService) UserUpdateLogin(ctx echo.Context, request gen.UserUpdateLo
return nil, err
}
userPass, err := u.config.Repository.User().GetUserPassword(existingUser.ID)
userPass, err := u.config.APIRepository.User().GetUserPassword(existingUser.ID)
if err != nil {
return nil, fmt.Errorf("could not get user password: %w", err)
+1 -1
View File
@@ -11,7 +11,7 @@ import (
func (t *WorkerService) WorkerGet(ctx echo.Context, request gen.WorkerGetRequestObject) (gen.WorkerGetResponseObject, error) {
worker := ctx.Get("worker").(*db.WorkerModel)
stepRuns, err := t.config.Repository.Worker().ListRecentWorkerStepRuns(worker.TenantID, worker.ID)
stepRuns, err := t.config.APIRepository.Worker().ListRecentWorkerStepRuns(worker.TenantID, worker.ID)
if err != nil {
return nil, err
+1 -1
View File
@@ -16,7 +16,7 @@ func (t *WorkerService) WorkerList(ctx echo.Context, request gen.WorkerListReque
sixSecAgo := time.Now().Add(-6 * time.Second)
workers, err := t.config.Repository.Worker().ListWorkers(tenant.ID, &repository.ListWorkersOpts{
workers, err := t.config.APIRepository.Worker().ListWorkers(tenant.ID, &repository.ListWorkersOpts{
LastHeartbeatAfter: &sixSecAgo,
})
+1 -1
View File
@@ -11,7 +11,7 @@ func (t *WorkflowService) WorkflowDelete(ctx echo.Context, request gen.WorkflowD
tenant := ctx.Get("tenant").(*db.TenantModel)
workflow := ctx.Get("workflow").(*db.WorkflowModel)
workflow, err := t.config.Repository.Workflow().DeleteWorkflow(tenant.ID, workflow.ID)
workflow, err := t.config.APIRepository.Workflow().DeleteWorkflow(tenant.ID, workflow.ID)
if err != nil {
return nil, err
@@ -31,7 +31,7 @@ func (t *WorkflowService) WorkflowVersionGetDefinition(ctx echo.Context, request
workflowVersionId = versions[0].ID
}
workflowVersion, err := t.config.Repository.Workflow().GetWorkflowVersionById(tenant.ID, workflowVersionId)
workflowVersion, err := t.config.APIRepository.Workflow().GetWorkflowVersionById(tenant.ID, workflowVersionId)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
@@ -13,7 +13,7 @@ import (
func (t *WorkflowService) StepRunGetDiff(ctx echo.Context, request gen.StepRunGetDiffRequestObject) (gen.StepRunGetDiffResponseObject, error) {
stepRun := ctx.Get("step-run").(*db.StepRunModel)
diffs, originalValues, err := vcsutils.GetStepRunOverrideDiffs(t.config.Repository.StepRun(), stepRun)
diffs, originalValues, err := vcsutils.GetStepRunOverrideDiffs(t.config.APIRepository.StepRun(), stepRun)
if err != nil {
return nil, fmt.Errorf("could not get diffs: %s", err)
@@ -31,7 +31,7 @@ func (t *WorkflowService) WorkflowVersionGet(ctx echo.Context, request gen.Workf
workflowVersionId = versions[0].ID
}
workflowVersion, err := t.config.Repository.Workflow().GetWorkflowVersionById(tenant.ID, workflowVersionId)
workflowVersion, err := t.config.APIRepository.Workflow().GetWorkflowVersionById(tenant.ID, workflowVersionId)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
@@ -18,7 +18,7 @@ func (t *WorkflowService) WorkflowUpdateLinkGithub(ctx echo.Context, request gen
// check that the user has access to the installation id
installationId := request.Body.InstallationId
_, err := t.config.Repository.Github().ReadGithubAppInstallationByID(installationId)
_, err := t.config.APIRepository.Github().ReadGithubAppInstallationByID(installationId)
if err != nil {
return gen.WorkflowUpdateLinkGithub404JSONResponse(
@@ -26,13 +26,13 @@ func (t *WorkflowService) WorkflowUpdateLinkGithub(ctx echo.Context, request gen
), nil
}
if canAccess, err := t.config.Repository.Github().CanUserAccessInstallation(installationId, user.ID); err != nil || !canAccess {
if canAccess, err := t.config.APIRepository.Github().CanUserAccessInstallation(installationId, user.ID); err != nil || !canAccess {
return gen.WorkflowUpdateLinkGithub403JSONResponse(
apierrors.NewAPIErrors("User does not have access to the installation"),
), nil
}
_, err = t.config.Repository.Workflow().UpsertWorkflowDeploymentConfig(
_, err = t.config.APIRepository.Workflow().UpsertWorkflowDeploymentConfig(
workflow.ID,
&repository.UpsertWorkflowDeploymentConfigOpts{
GithubAppInstallationId: installationId,
@@ -46,7 +46,7 @@ func (t *WorkflowService) WorkflowUpdateLinkGithub(ctx echo.Context, request gen
return nil, err
}
workflow, err = t.config.Repository.Workflow().GetWorkflowById(workflow.ID)
workflow, err = t.config.APIRepository.Workflow().GetWorkflowById(workflow.ID)
if err != nil {
return nil, err
+1 -1
View File
@@ -22,7 +22,7 @@ func (t *WorkflowService) WorkflowList(ctx echo.Context, request gen.WorkflowLis
Offset: &offset,
}
listResp, err := t.config.Repository.Workflow().ListWorkflows(tenant.ID, listOpts)
listResp, err := t.config.APIRepository.Workflow().ListWorkflows(tenant.ID, listOpts)
if err != nil {
return nil, err
@@ -18,7 +18,7 @@ func (t *WorkflowService) WorkflowRunListPullRequests(ctx echo.Context, request
listOpts.State = repository.StringPtr(string(*request.Params.State))
}
prs, err := t.config.Repository.WorkflowRun().ListPullRequestsForWorkflowRun(
prs, err := t.config.APIRepository.WorkflowRun().ListPullRequestsForWorkflowRun(
workflowRun.TenantID,
workflowRun.ID,
listOpts,
@@ -42,7 +42,7 @@ func (t *WorkflowService) WorkflowRunList(ctx echo.Context, request gen.Workflow
listOpts.EventId = &eventIdStr
}
workflowRuns, err := t.config.Repository.WorkflowRun().ListWorkflowRuns(tenant.ID, listOpts)
workflowRuns, err := t.config.APIRepository.WorkflowRun().ListWorkflowRuns(tenant.ID, listOpts)
if err != nil {
return nil, err
+3 -3
View File
@@ -36,7 +36,7 @@ func (t *WorkflowService) WorkflowRunCreate(ctx echo.Context, request gen.Workfl
workflowVersionId = versions[0].ID
}
workflowVersion, err := t.config.Repository.Workflow().GetWorkflowVersionById(tenant.ID, workflowVersionId)
workflowVersion, err := t.config.EngineRepository.Workflow().GetWorkflowVersionById(tenant.ID, workflowVersionId)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
@@ -63,7 +63,7 @@ func (t *WorkflowService) WorkflowRunCreate(ctx echo.Context, request gen.Workfl
return nil, err
}
workflowRun, err := t.config.Repository.WorkflowRun().CreateNewWorkflowRun(ctx.Request().Context(), tenant.ID, createOpts)
workflowRun, err := t.config.APIRepository.WorkflowRun().CreateNewWorkflowRun(ctx.Request().Context(), tenant.ID, createOpts)
if err != nil {
return nil, fmt.Errorf("could not create workflow run: %w", err)
@@ -73,7 +73,7 @@ func (t *WorkflowService) WorkflowRunCreate(ctx echo.Context, request gen.Workfl
err = t.config.MessageQueue.AddMessage(
ctx.Request().Context(),
msgqueue.WORKFLOW_PROCESSING_QUEUE,
tasktypes.WorkflowRunQueuedToTask(workflowRun),
tasktypes.WorkflowRunQueuedToTask(workflowRun.TenantID, workflowRun.ID),
)
if err != nil {
+10 -10
View File
@@ -80,7 +80,7 @@ func (t *APIServer) Run() (func() error, error) {
populatorMW := populator.NewPopulator(t.config)
populatorMW.RegisterGetter("tenant", func(config *server.ServerConfig, parentId, id string) (result interface{}, uniqueParentId string, err error) {
tenant, err := config.Repository.Tenant().GetTenantByID(id)
tenant, err := config.APIRepository.Tenant().GetTenantByID(id)
if err != nil {
return nil, "", err
@@ -90,7 +90,7 @@ func (t *APIServer) Run() (func() error, error) {
})
populatorMW.RegisterGetter("api-token", func(config *server.ServerConfig, parentId, id string) (result interface{}, uniqueParentId string, err error) {
apiToken, err := config.Repository.APIToken().GetAPITokenById(id)
apiToken, err := config.APIRepository.APIToken().GetAPITokenById(id)
if err != nil {
return nil, "", err
@@ -109,7 +109,7 @@ func (t *APIServer) Run() (func() error, error) {
})
populatorMW.RegisterGetter("tenant-invite", func(config *server.ServerConfig, parentId, id string) (result interface{}, uniqueParentId string, err error) {
tenantInvite, err := config.Repository.TenantInvite().GetTenantInvite(id)
tenantInvite, err := config.APIRepository.TenantInvite().GetTenantInvite(id)
if err != nil {
return nil, "", err
@@ -119,7 +119,7 @@ func (t *APIServer) Run() (func() error, error) {
})
populatorMW.RegisterGetter("sns", func(config *server.ServerConfig, parentId, id string) (result interface{}, uniqueParentId string, err error) {
snsIntegration, err := config.Repository.SNS().GetSNSIntegrationById(id)
snsIntegration, err := config.APIRepository.SNS().GetSNSIntegrationById(id)
if err != nil {
return nil, "", err
@@ -129,7 +129,7 @@ func (t *APIServer) Run() (func() error, error) {
})
populatorMW.RegisterGetter("workflow", func(config *server.ServerConfig, parentId, id string) (result interface{}, uniqueParentId string, err error) {
workflow, err := config.Repository.Workflow().GetWorkflowById(id)
workflow, err := config.APIRepository.Workflow().GetWorkflowById(id)
if err != nil {
return nil, "", err
@@ -139,7 +139,7 @@ func (t *APIServer) Run() (func() error, error) {
})
populatorMW.RegisterGetter("workflow-run", func(config *server.ServerConfig, parentId, id string) (result interface{}, uniqueParentId string, err error) {
workflowRun, err := config.Repository.WorkflowRun().GetWorkflowRunById(parentId, id)
workflowRun, err := config.APIRepository.WorkflowRun().GetWorkflowRunById(parentId, id)
if err != nil {
return nil, "", err
@@ -149,7 +149,7 @@ func (t *APIServer) Run() (func() error, error) {
})
populatorMW.RegisterGetter("step-run", func(config *server.ServerConfig, parentId, id string) (result interface{}, uniqueParentId string, err error) {
stepRun, err := config.Repository.StepRun().GetStepRunById(parentId, id)
stepRun, err := config.APIRepository.StepRun().GetStepRunById(parentId, id)
if err != nil {
return nil, "", err
@@ -159,7 +159,7 @@ func (t *APIServer) Run() (func() error, error) {
})
populatorMW.RegisterGetter("event", func(config *server.ServerConfig, parentId, id string) (result interface{}, uniqueParentId string, err error) {
event, err := config.Repository.Event().GetEventById(id)
event, err := config.APIRepository.Event().GetEventById(id)
if err != nil {
return nil, "", err
@@ -169,7 +169,7 @@ func (t *APIServer) Run() (func() error, error) {
})
populatorMW.RegisterGetter("worker", func(config *server.ServerConfig, parentId, id string) (result interface{}, uniqueParentId string, err error) {
worker, err := config.Repository.Worker().GetWorkerById(id)
worker, err := config.APIRepository.Worker().GetWorkerById(id)
if err != nil {
return nil, "", err
@@ -179,7 +179,7 @@ func (t *APIServer) Run() (func() error, error) {
})
populatorMW.RegisterGetter("gh-installation", func(config *server.ServerConfig, parentId, id string) (result interface{}, uniqueParentId string, err error) {
ghInstallation, err := config.Repository.Github().ReadGithubAppInstallationByID(id)
ghInstallation, err := config.APIRepository.Github().ReadGithubAppInstallationByID(id)
if err != nil {
return nil, "", err
+13 -9
View File
@@ -6,11 +6,13 @@ import (
"log"
"os"
"github.com/jackc/pgx/v5"
"github.com/spf13/cobra"
"github.com/hatchet-dev/hatchet/internal/config/loader"
"github.com/hatchet-dev/hatchet/internal/repository"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/db"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/sqlchelpers"
)
// seedCmd seeds the database with initial data
@@ -55,11 +57,11 @@ func runSeed(cf *loader.ConfigLoader) error {
return err
}
user, err := dc.Repository.User().GetUserByEmail(dc.Seed.AdminEmail)
user, err := dc.APIRepository.User().GetUserByEmail(dc.Seed.AdminEmail)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
user, err = dc.Repository.User().CreateUser(&repository.CreateUserOpts{
user, err = dc.APIRepository.User().CreateUser(&repository.CreateUserOpts{
Email: dc.Seed.AdminEmail,
Name: repository.StringPtr(dc.Seed.AdminName),
EmailVerified: repository.BoolPtr(true),
@@ -77,13 +79,13 @@ func runSeed(cf *loader.ConfigLoader) error {
userId = user.ID
}
tenant, err := dc.Repository.Tenant().GetTenantBySlug("default")
tenant, err := dc.APIRepository.Tenant().GetTenantBySlug("default")
if err != nil {
if errors.Is(err, db.ErrNotFound) {
// seed an example tenant
// initialize a tenant
tenant, err = dc.Repository.Tenant().CreateTenant(&repository.CreateTenantOpts{
tenant, err = dc.APIRepository.Tenant().CreateTenant(&repository.CreateTenantOpts{
ID: &dc.Seed.DefaultTenantID,
Name: dc.Seed.DefaultTenantName,
Slug: dc.Seed.DefaultTenantSlug,
@@ -96,7 +98,7 @@ func runSeed(cf *loader.ConfigLoader) error {
fmt.Println("created tenant", tenant.ID)
// add the user to the tenant
_, err = dc.Repository.Tenant().CreateTenantMember(tenant.ID, &repository.CreateTenantMemberOpts{
_, err = dc.APIRepository.Tenant().CreateTenantMember(tenant.ID, &repository.CreateTenantMemberOpts{
Role: "OWNER",
UserId: userId,
})
@@ -110,7 +112,7 @@ func runSeed(cf *loader.ConfigLoader) error {
}
if dc.Seed.IsDevelopment {
err = seedDev(dc.Repository, tenant.ID)
err = seedDev(dc.EngineRepository, tenant.ID)
if err != nil {
return err
@@ -120,11 +122,11 @@ func runSeed(cf *loader.ConfigLoader) error {
return nil
}
func seedDev(repo repository.Repository, tenantId string) error {
func seedDev(repo repository.EngineRepository, tenantId string) error {
_, err := repo.Workflow().GetWorkflowByName(tenantId, "test-workflow")
if err != nil {
if !errors.Is(err, db.ErrNotFound) {
if !errors.Is(err, pgx.ErrNoRows) {
return err
}
@@ -173,7 +175,9 @@ func seedDev(repo repository.Repository, tenantId string) error {
return err
}
fmt.Println("created workflow", wf.ID, wf.Workflow().Name)
workflowVersionId := sqlchelpers.UUIDToStr(wf.WorkflowVersion.ID)
fmt.Println("created workflow version", workflowVersionId)
}
return nil
+1 -1
View File
@@ -19,7 +19,7 @@ func Start(cf *loader.ConfigLoader, interruptCh <-chan interface{}) error {
if sc.InternalClient != nil {
w, err := worker.NewWorker(
worker.WithRepository(sc.Repository),
worker.WithRepository(sc.APIRepository),
worker.WithClient(sc.InternalClient),
worker.WithVCSProviders(sc.VCSProviders),
)
+10 -10
View File
@@ -63,7 +63,7 @@ func Run(ctx context.Context, cf *loader.ConfigLoader) error {
var h *health.Health
healthProbes := sc.HasService("health")
if healthProbes {
h = health.New(sc.Repository, sc.MessageQueue)
h = health.New(sc.EngineRepository, sc.MessageQueue)
cleanup, err := h.Start()
if err != nil {
return fmt.Errorf("could not start health: %w", err)
@@ -78,7 +78,7 @@ func Run(ctx context.Context, cf *loader.ConfigLoader) error {
if sc.HasService("ticker") {
t, err := ticker.New(
ticker.WithMessageQueue(sc.MessageQueue),
ticker.WithRepository(sc.Repository),
ticker.WithRepository(sc.EngineRepository),
ticker.WithLogger(sc.Logger),
)
@@ -99,7 +99,7 @@ func Run(ctx context.Context, cf *loader.ConfigLoader) error {
if sc.HasService("eventscontroller") {
ec, err := events.New(
events.WithMessageQueue(sc.MessageQueue),
events.WithRepository(sc.Repository),
events.WithRepository(sc.EngineRepository),
events.WithLogger(sc.Logger),
)
if err != nil {
@@ -120,7 +120,7 @@ func Run(ctx context.Context, cf *loader.ConfigLoader) error {
jc, err := jobs.New(
jobs.WithAlerter(sc.Alerter),
jobs.WithMessageQueue(sc.MessageQueue),
jobs.WithRepository(sc.Repository),
jobs.WithRepository(sc.EngineRepository),
jobs.WithLogger(sc.Logger),
)
@@ -141,7 +141,7 @@ func Run(ctx context.Context, cf *loader.ConfigLoader) error {
if sc.HasService("workflowscontroller") {
wc, err := workflows.New(
workflows.WithMessageQueue(sc.MessageQueue),
workflows.WithRepository(sc.Repository),
workflows.WithRepository(sc.EngineRepository),
workflows.WithLogger(sc.Logger),
)
if err != nil {
@@ -161,7 +161,7 @@ func Run(ctx context.Context, cf *loader.ConfigLoader) error {
if sc.HasService("heartbeater") {
h, err := heartbeat.New(
heartbeat.WithMessageQueue(sc.MessageQueue),
heartbeat.WithRepository(sc.Repository),
heartbeat.WithRepository(sc.EngineRepository),
heartbeat.WithLogger(sc.Logger),
)
@@ -184,7 +184,7 @@ func Run(ctx context.Context, cf *loader.ConfigLoader) error {
d, err := dispatcher.New(
dispatcher.WithAlerter(sc.Alerter),
dispatcher.WithMessageQueue(sc.MessageQueue),
dispatcher.WithRepository(sc.Repository),
dispatcher.WithRepository(sc.EngineRepository),
dispatcher.WithLogger(sc.Logger),
)
if err != nil {
@@ -199,10 +199,10 @@ func Run(ctx context.Context, cf *loader.ConfigLoader) error {
// create the event ingestor
ei, err := ingestor.NewIngestor(
ingestor.WithEventRepository(
sc.Repository.Event(),
sc.EngineRepository.Event(),
),
ingestor.WithLogRepository(
sc.Repository.Log(),
sc.EngineRepository.Log(),
),
ingestor.WithMessageQueue(sc.MessageQueue),
)
@@ -211,7 +211,7 @@ func Run(ctx context.Context, cf *loader.ConfigLoader) error {
}
adminSvc, err := admin.NewAdminService(
admin.WithRepository(sc.Repository),
admin.WithRepository(sc.EngineRepository),
admin.WithMessageQueue(sc.MessageQueue),
)
if err != nil {
+3 -1
View File
@@ -87,10 +87,12 @@ func main() {
time.Sleep(5 * time.Second)
executeAt := time.Now().Add(time.Second * 10)
executeAt2 := time.Now().Add(time.Second * 20)
executeAt3 := time.Now().Add(time.Second * 30)
err = c.Admin().ScheduleWorkflow(
"scheduled-workflow",
client.WithSchedules(executeAt),
client.WithSchedules(executeAt, executeAt2, executeAt3),
client.WithInput(&scheduledInput{
ScheduledAt: time.Now(),
ExecuteAt: executeAt,
+3 -2
View File
@@ -7,11 +7,12 @@ import (
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/hatchet-dev/hatchet/internal/auth/cookie"
"github.com/hatchet-dev/hatchet/internal/config/database"
"github.com/hatchet-dev/hatchet/internal/encryption"
"github.com/hatchet-dev/hatchet/internal/testutils"
"github.com/stretchr/testify/assert"
)
func TestSessionStoreSave(t *testing.T) {
@@ -79,7 +80,7 @@ func newSessionStore(t *testing.T, conf *database.Config, cookieName string) *co
cookie.WithCookieDomain("hatchet.run"),
cookie.WithCookieName(cookieName),
cookie.WithCookieAllowInsecure(false),
cookie.WithSessionRepository(conf.Repository.UserSession()),
cookie.WithSessionRepository(conf.APIRepository.UserSession()),
)
if err != nil {
+3 -3
View File
@@ -26,11 +26,11 @@ type TokenOpts struct {
type jwtManagerImpl struct {
encryption encryption.EncryptionService
opts *TokenOpts
tokenRepo repository.APITokenRepository
tokenRepo repository.EngineTokenRepository
verifier jwt.Verifier
}
func NewJWTManager(encryptionSvc encryption.EncryptionService, tokenRepo repository.APITokenRepository, opts *TokenOpts) (JWTManager, error) {
func NewJWTManager(encryptionSvc encryption.EncryptionService, tokenRepo repository.EngineTokenRepository, opts *TokenOpts) (JWTManager, error) {
verifier, err := jwt.NewVerifier(encryptionSvc.GetPublicJWTHandle())
if err != nil {
@@ -150,7 +150,7 @@ func (j *jwtManagerImpl) ValidateTenantToken(token string) (tenantId string, err
return "", fmt.Errorf("token has been revoked")
}
if expiresAt, ok := dbToken.ExpiresAt(); ok && expiresAt.Before(time.Now()) {
if expiresAt := dbToken.ExpiresAt.Time; expiresAt.Before(time.Now()) {
return "", fmt.Errorf("token has expired")
}
+8 -8
View File
@@ -30,7 +30,7 @@ func TestCreateTenantToken(t *testing.T) { // make sure no cache is used for tes
t.Fatal(err.Error())
}
_, err = conf.Repository.Tenant().CreateTenant(&repository.CreateTenantOpts{
_, err = conf.APIRepository.Tenant().CreateTenant(&repository.CreateTenantOpts{
ID: &tenantId,
Name: "test-tenant",
Slug: fmt.Sprintf("test-tenant-%s", slugSuffix),
@@ -71,7 +71,7 @@ func TestRevokeTenantToken(t *testing.T) {
t.Fatal(err.Error())
}
_, err = conf.Repository.Tenant().CreateTenant(&repository.CreateTenantOpts{
_, err = conf.APIRepository.Tenant().CreateTenant(&repository.CreateTenantOpts{
ID: &tenantId,
Name: "test-tenant",
Slug: fmt.Sprintf("test-tenant-%s", slugSuffix),
@@ -93,14 +93,14 @@ func TestRevokeTenantToken(t *testing.T) {
assert.NoError(t, err)
// revoke the token
apiTokens, err := conf.Repository.APIToken().ListAPITokensByTenant(tenantId)
apiTokens, err := conf.APIRepository.APIToken().ListAPITokensByTenant(tenantId)
if err != nil {
t.Fatal(err.Error())
}
assert.Len(t, apiTokens, 1)
err = conf.Repository.APIToken().RevokeAPIToken(apiTokens[0].ID)
err = conf.APIRepository.APIToken().RevokeAPIToken(apiTokens[0].ID)
if err != nil {
t.Fatal(err.Error())
@@ -131,7 +131,7 @@ func TestRevokeTenantTokenCache(t *testing.T) {
t.Fatal(err.Error())
}
_, err = conf.Repository.Tenant().CreateTenant(&repository.CreateTenantOpts{
_, err = conf.APIRepository.Tenant().CreateTenant(&repository.CreateTenantOpts{
ID: &tenantId,
Name: "test-tenant",
Slug: fmt.Sprintf("test-tenant-%s", slugSuffix),
@@ -153,14 +153,14 @@ func TestRevokeTenantTokenCache(t *testing.T) {
assert.NoError(t, err)
// revoke the token
apiTokens, err := conf.Repository.APIToken().ListAPITokensByTenant(tenantId)
apiTokens, err := conf.APIRepository.APIToken().ListAPITokensByTenant(tenantId)
if err != nil {
t.Fatal(err.Error())
}
assert.Len(t, apiTokens, 1)
err = conf.Repository.APIToken().RevokeAPIToken(apiTokens[0].ID)
err = conf.APIRepository.APIToken().RevokeAPIToken(apiTokens[0].ID)
if err != nil {
t.Fatal(err.Error())
@@ -191,7 +191,7 @@ func getJWTManager(t *testing.T, conf *database.Config) token.JWTManager {
t.Fatal(err.Error())
}
tokenRepo := conf.Repository.APIToken()
tokenRepo := conf.EngineRepository.APIToken()
jwtManager, err := token.NewJWTManager(encryptionService, tokenRepo, &token.TokenOpts{
Issuer: "hatchet",
+6 -1
View File
@@ -17,6 +17,8 @@ type ConfigFile struct {
PostgresDbName string `mapstructure:"dbName" json:"dbName,omitempty" default:"hatchet"`
PostgresSSLMode string `mapstructure:"sslMode" json:"sslMode,omitempty" default:"disable"`
MaxConns int `mapstructure:"maxConns" json:"maxConns,omitempty" default:"5"`
Seed SeedConfigFile `mapstructure:"seed" json:"seed,omitempty"`
Logger shared.LoggerConfigFile `mapstructure:"logger" json:"logger,omitempty"`
@@ -41,7 +43,9 @@ type SeedConfigFile struct {
type Config struct {
Disconnect func() error
Repository repository.Repository
APIRepository repository.APIRepository
EngineRepository repository.EngineRepository
Seed SeedConfigFile
}
@@ -54,6 +58,7 @@ func BindAllEnv(v *viper.Viper) {
_ = v.BindEnv("dbName", "DATABASE_POSTGRES_DB_NAME")
_ = v.BindEnv("sslMode", "DATABASE_POSTGRES_SSL_MODE")
_ = v.BindEnv("logQueries", "DATABASE_LOG_QUERIES")
_ = v.BindEnv("maxConns", "DATABASE_MAX_CONNS")
_ = v.BindEnv("cacheDuration", "CACHE_DURATION")
+10 -9
View File
@@ -143,7 +143,7 @@ func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile) (res *database.Con
config.ConnConfig.Tracer = otelpgx.NewTracer()
config.MaxConns = 20
config.MaxConns = int32(cf.MaxConns)
pool, err := pgxpool.NewWithConfig(context.Background(), config)
if err != nil {
@@ -157,8 +157,9 @@ func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile) (res *database.Con
ch.Stop()
return c.Prisma.Disconnect()
},
Repository: prisma.NewPrismaRepository(c, pool, prisma.WithLogger(&l), prisma.WithCache(ch)),
Seed: cf.Seed,
APIRepository: prisma.NewAPIRepository(c, pool, prisma.WithLogger(&l), prisma.WithCache(ch)),
EngineRepository: prisma.NewEngineRepository(pool, prisma.WithLogger(&l), prisma.WithCache(ch)),
Seed: cf.Seed,
}, nil
}
@@ -172,7 +173,7 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
}
ss, err := cookie.NewUserSessionStore(
cookie.WithSessionRepository(dc.Repository.UserSession()),
cookie.WithSessionRepository(dc.APIRepository.UserSession()),
cookie.WithCookieAllowInsecure(cf.Auth.Cookie.Insecure),
cookie.WithCookieDomain(cf.Auth.Cookie.Domain),
cookie.WithCookieName(cf.Auth.Cookie.Name),
@@ -189,8 +190,8 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
)
ingestor, err := ingestor.NewIngestor(
ingestor.WithEventRepository(dc.Repository.Event()),
ingestor.WithLogRepository(dc.Repository.Log()),
ingestor.WithEventRepository(dc.EngineRepository.Event()),
ingestor.WithLogRepository(dc.EngineRepository.Log()),
ingestor.WithMessageQueue(mq),
)
@@ -243,7 +244,7 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
}
// create a new JWT manager
auth.JWTManager, err = token.NewJWTManager(encryptionSvc, dc.Repository.APIToken(), &token.TokenOpts{
auth.JWTManager, err = token.NewJWTManager(encryptionSvc, dc.EngineRepository.APIToken(), &token.TokenOpts{
Issuer: cf.Runtime.ServerURL,
Audience: cf.Runtime.ServerURL,
GRPCBroadcastAddress: cf.Runtime.GRPCBroadcastAddress,
@@ -277,7 +278,7 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
return nil, nil, err
}
githubProvider := github.NewGithubVCSProvider(githubAppConf, dc.Repository, cf.Runtime.ServerURL, encryptionSvc)
githubProvider := github.NewGithubVCSProvider(githubAppConf, dc.APIRepository, cf.Runtime.ServerURL, encryptionSvc)
vcsProviders[vcs.VCSRepositoryKindGithub] = githubProvider
}
@@ -286,7 +287,7 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
if cf.Runtime.WorkerEnabled {
// get the internal tenant or create if it doesn't exist
internalTenant, err := dc.Repository.Tenant().GetTenantBySlug("internal")
internalTenant, err := dc.APIRepository.Tenant().GetTenantBySlug("internal")
if err != nil {
return nil, nil, fmt.Errorf("could not get internal tenant: %w", err)
+3 -3
View File
@@ -16,13 +16,13 @@ import (
)
type GithubVCSProvider struct {
repo repository.Repository
repo repository.APIRepository
appConf *GithubAppConf
serverURL string
enc encryption.EncryptionService
}
func NewGithubVCSProvider(appConf *GithubAppConf, repo repository.Repository, serverURL string, enc encryption.EncryptionService) GithubVCSProvider {
func NewGithubVCSProvider(appConf *GithubAppConf, repo repository.APIRepository, serverURL string, enc encryption.EncryptionService) GithubVCSProvider {
return GithubVCSProvider{
appConf: appConf,
repo: repo,
@@ -82,7 +82,7 @@ func (g GithubVCSProvider) GetVCSRepositoryFromWorkflow(workflow *db.WorkflowMod
type GithubVCSRepository struct {
repoOwner, repoName string
client *githubsdk.Client
repo repository.Repository
repo repository.APIRepository
serverURL string
webhookURL string
enc encryption.EncryptionService
+1 -1
View File
@@ -15,7 +15,7 @@ import (
// GetStepRunOverrideDiffs returns a map of the override keys to the override values which have changed
// between the first step run and the latest step run.
func GetStepRunOverrideDiffs(repo repository.StepRunRepository, stepRun *db.StepRunModel) (diffs map[string]string, original map[string]string, err error) {
func GetStepRunOverrideDiffs(repo repository.StepRunAPIRepository, stepRun *db.StepRunModel) (diffs map[string]string, original map[string]string, err error) {
// get the first step run archived result, there will be at least one
var archivedResult inputtable
+6 -1
View File
@@ -4,6 +4,7 @@ import (
"time"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/db"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/dbsqlc"
)
type CreateAPITokenOpts struct {
@@ -22,7 +23,11 @@ type CreateAPITokenOpts struct {
type APITokenRepository interface {
GetAPITokenById(id string) (*db.APITokenModel, error)
CreateAPIToken(opts *CreateAPITokenOpts) (*db.APITokenModel, error)
RevokeAPIToken(id string) error
ListAPITokensByTenant(tenantId string) ([]db.APITokenModel, error)
}
type EngineTokenRepository interface {
CreateAPIToken(opts *CreateAPITokenOpts) (*dbsqlc.APIToken, error)
GetAPITokenById(id string) (*dbsqlc.APIToken, error)
}
+4 -10
View File
@@ -3,7 +3,7 @@ package repository
import (
"time"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/db"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/dbsqlc"
)
type CreateDispatcherOpts struct {
@@ -14,20 +14,14 @@ type UpdateDispatcherOpts struct {
LastHeartbeatAt *time.Time
}
type DispatcherRepository interface {
// GetDispatcherForWorker returns the dispatcher connected to a given worker.
GetDispatcherForWorker(workerId string) (*db.DispatcherModel, error)
type DispatcherEngineRepository interface {
// CreateNewDispatcher creates a new dispatcher for a given tenant.
CreateNewDispatcher(opts *CreateDispatcherOpts) (*db.DispatcherModel, error)
CreateNewDispatcher(opts *CreateDispatcherOpts) (*dbsqlc.Dispatcher, error)
// UpdateDispatcher updates a dispatcher for a given tenant.
UpdateDispatcher(dispatcherId string, opts *UpdateDispatcherOpts) (*db.DispatcherModel, error)
UpdateDispatcher(dispatcherId string, opts *UpdateDispatcherOpts) (*dbsqlc.Dispatcher, error)
Delete(dispatcherId string) error
// AddWorker adds a worker to a dispatcher.
AddWorker(dispatcherId, workerId string) (*db.DispatcherModel, error)
UpdateStaleDispatchers(onStale func(dispatcherId string, getValidDispatcherId func() string) error) error
}
+12 -8
View File
@@ -15,7 +15,7 @@ type CreateEventOpts struct {
Key string `validate:"required"`
// (optional) the event data
Data *db.JSON
Data []byte
// (optional) the event that this event is replaying
ReplayedEvent *string `validate:"omitempty,uuid"`
@@ -55,7 +55,7 @@ type ListEventResult struct {
Count int
}
type EventRepository interface {
type EventAPIRepository interface {
// ListEvents returns all events for a given tenant.
ListEvents(tenantId string, opts *ListEventOpts) (*ListEventResult, error)
@@ -65,12 +65,16 @@ type EventRepository interface {
// GetEventById returns an event by id.
GetEventById(id string) (*db.EventModel, error)
// GetEventForEngine returns an event for the engine by id.
GetEventForEngine(tenantId, id string) (*dbsqlc.GetEventForEngineRow, error)
// ListEventsById returns a list of events by id.
ListEventsById(tenantId string, ids []string) ([]db.EventModel, error)
// CreateEvent creates a new event for a given tenant.
CreateEvent(ctx context.Context, opts *CreateEventOpts) (*db.EventModel, error)
}
type EventEngineRepository interface {
// CreateEvent creates a new event for a given tenant.
CreateEvent(ctx context.Context, opts *CreateEventOpts) (*dbsqlc.Event, error)
// GetEventForEngine returns an event for the engine by id.
GetEventForEngine(tenantId, id string) (*dbsqlc.Event, error)
ListEventsByIds(tenantId string, ids []string) ([]*dbsqlc.Event, error)
}
+1 -5
View File
@@ -33,10 +33,7 @@ type UpdateGetGroupKeyRunOpts struct {
Output *string
}
type GetGroupKeyRunRepository interface {
// ListGetGroupKeyRuns returns a list of get group key runs for a tenant which match the given options.
ListGetGroupKeyRuns(tenantId string, opts *ListGetGroupKeyRunsOpts) ([]db.GetGroupKeyRunModel, error)
type GetGroupKeyRunEngineRepository interface {
// ListStepRunsToRequeue returns a list of step runs which are in a requeueable state.
ListGetGroupKeyRunsToRequeue(tenantId string) ([]*dbsqlc.GetGroupKeyRun, error)
@@ -47,6 +44,5 @@ type GetGroupKeyRunRepository interface {
UpdateGetGroupKeyRun(tenantId, getGroupKeyRunId string, opts *UpdateGetGroupKeyRunOpts) (*dbsqlc.GetGroupKeyRunForEngineRow, error)
GetGroupKeyRunById(tenantId, getGroupKeyRunId string) (*db.GetGroupKeyRunModel, error)
GetGroupKeyRunForEngine(tenantId, getGroupKeyRunId string) (*dbsqlc.GetGroupKeyRunForEngineRow, error)
}
+9 -7
View File
@@ -1,6 +1,8 @@
package repository
import (
"github.com/jackc/pgx/v5/pgtype"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/db"
)
@@ -21,16 +23,16 @@ func JobRunStatusPtr(status db.JobRunStatus) *db.JobRunStatus {
return &status
}
type JobRunRepository interface {
ListAllJobRuns(opts *ListAllJobRunsOpts) ([]db.JobRunModel, error)
GetJobRunById(tenantId, jobRunId string) (*db.JobRunModel, error)
type JobRunAPIRepository interface {
// SetJobRunStatusRunning resets the status of a job run to a RUNNING status. This is useful if a step
// run is being manually replayed, but shouldn't be used by most callers.
SetJobRunStatusRunning(tenantId, jobRunId string) error
}
type JobRunEngineRepository interface {
// SetJobRunStatusRunning resets the status of a job run to a RUNNING status. This is useful if a step
// run is being manually replayed, but shouldn't be used by most callers.
SetJobRunStatusRunning(tenantId, jobRunId string) error
GetJobRunLookupData(tenantId, jobRunId string) (*db.JobRunLookupDataModel, error)
UpdateJobRunLookupData(tenantId, jobRunId string, opts *UpdateJobRunLookupDataOpts) error
ListJobRunsForWorkflowRun(tenantId, workflowRunId string) ([]pgtype.UUID, error)
}
+6 -4
View File
@@ -51,10 +51,12 @@ type ListLogsResult struct {
Count int
}
type LogsRepository interface {
// PutLog creates a new log line.
PutLog(tenantId string, opts *CreateLogLineOpts) (*dbsqlc.LogLine, error)
type LogsAPIRepository interface {
// ListLogLines returns a list of log lines for a given step run.
ListLogLines(tenantId string, opts *ListLogsOpts) (*ListLogsResult, error)
}
type LogsEngineRepository interface {
// PutLog creates a new log line.
PutLog(tenantId string, opts *CreateLogLineOpts) (*dbsqlc.LogLine, error)
}
+52
View File
@@ -4,9 +4,14 @@ import (
"context"
"time"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/rs/zerolog"
"github.com/hatchet-dev/hatchet/internal/repository"
"github.com/hatchet-dev/hatchet/internal/repository/cache"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/db"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/dbsqlc"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/sqlchelpers"
"github.com/hatchet-dev/hatchet/internal/validator"
)
@@ -74,3 +79,50 @@ func (a *apiTokenRepository) ListAPITokensByTenant(tenantId string) ([]db.APITok
db.APIToken.Revoked.Equals(false),
).Exec(context.Background())
}
type engineTokenRepository struct {
cache cache.Cacheable
pool *pgxpool.Pool
v validator.Validator
queries *dbsqlc.Queries
l *zerolog.Logger
}
func NewEngineTokenRepository(pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger, cache cache.Cacheable) repository.EngineTokenRepository {
queries := dbsqlc.New()
return &engineTokenRepository{
cache: cache,
pool: pool,
v: v,
l: l,
queries: queries,
}
}
func (a *engineTokenRepository) CreateAPIToken(opts *repository.CreateAPITokenOpts) (*dbsqlc.APIToken, error) {
if err := a.v.Validate(opts); err != nil {
return nil, err
}
createParams := dbsqlc.CreateAPITokenParams{
ID: sqlchelpers.UUIDFromStr(opts.ID),
Expiresat: sqlchelpers.TimestampFromTime(opts.ExpiresAt),
}
if opts.TenantId != nil {
createParams.TenantId = sqlchelpers.UUIDFromStr(*opts.TenantId)
}
if opts.Name != nil {
createParams.Name = sqlchelpers.TextFromStr(*opts.Name)
}
return a.queries.CreateAPIToken(context.Background(), a.pool, createParams)
}
func (a *engineTokenRepository) GetAPITokenById(id string) (*dbsqlc.APIToken, error) {
return cache.MakeCacheable[dbsqlc.APIToken](a.cache, id, func() (*dbsqlc.APIToken, error) {
return a.queries.GetAPITokenById(context.Background(), a.pool, sqlchelpers.UUIDFromStr(id))
})
}
@@ -0,0 +1,24 @@
-- name: GetAPITokenById :one
SELECT
*
FROM
"APIToken"
WHERE
"id" = @id::uuid;
-- name: CreateAPIToken :one
INSERT INTO "APIToken" (
"id",
"createdAt",
"updatedAt",
"tenantId",
"name",
"expiresAt"
) VALUES (
coalesce(@id::uuid, gen_random_uuid()),
CURRENT_TIMESTAMP,
CURRENT_TIMESTAMP,
sqlc.narg('tenantId')::uuid,
sqlc.narg('name')::text,
@expiresAt::timestamp
) RETURNING *;
@@ -0,0 +1,81 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.24.0
// source: api_tokens.sql
package dbsqlc
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const createAPIToken = `-- name: CreateAPIToken :one
INSERT INTO "APIToken" (
"id",
"createdAt",
"updatedAt",
"tenantId",
"name",
"expiresAt"
) VALUES (
coalesce($1::uuid, gen_random_uuid()),
CURRENT_TIMESTAMP,
CURRENT_TIMESTAMP,
$2::uuid,
$3::text,
$4::timestamp
) RETURNING id, "createdAt", "updatedAt", "expiresAt", revoked, name, "tenantId"
`
type CreateAPITokenParams struct {
ID pgtype.UUID `json:"id"`
TenantId pgtype.UUID `json:"tenantId"`
Name pgtype.Text `json:"name"`
Expiresat pgtype.Timestamp `json:"expiresat"`
}
func (q *Queries) CreateAPIToken(ctx context.Context, db DBTX, arg CreateAPITokenParams) (*APIToken, error) {
row := db.QueryRow(ctx, createAPIToken,
arg.ID,
arg.TenantId,
arg.Name,
arg.Expiresat,
)
var i APIToken
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.ExpiresAt,
&i.Revoked,
&i.Name,
&i.TenantId,
)
return &i, err
}
const getAPITokenById = `-- name: GetAPITokenById :one
SELECT
id, "createdAt", "updatedAt", "expiresAt", revoked, name, "tenantId"
FROM
"APIToken"
WHERE
"id" = $1::uuid
`
func (q *Queries) GetAPITokenById(ctx context.Context, db DBTX, id pgtype.UUID) (*APIToken, error) {
row := db.QueryRow(ctx, getAPITokenById, id)
var i APIToken
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.ExpiresAt,
&i.Revoked,
&i.Name,
&i.TenantId,
)
return &i, err
}
@@ -32,4 +32,27 @@ RETURNING
SELECT
sqlc.embed(dispatchers)
FROM
"Dispatcher" as dispatchers;
"Dispatcher" as dispatchers;
-- name: DeleteDispatcher :one
DELETE FROM
"Dispatcher" as dispatchers
WHERE
"id" = sqlc.arg('id')::uuid
RETURNING *;
-- name: CreateDispatcher :one
INSERT INTO
"Dispatcher" ("id", "lastHeartbeatAt", "isActive")
VALUES
(sqlc.arg('id')::uuid, CURRENT_TIMESTAMP, 't')
RETURNING *;
-- name: UpdateDispatcher :one
UPDATE
"Dispatcher" as dispatchers
SET
"lastHeartbeatAt" = sqlc.arg('lastHeartbeatAt')::timestamp
WHERE
"id" = sqlc.arg('id')::uuid
RETURNING *;
@@ -11,6 +11,50 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)
const createDispatcher = `-- name: CreateDispatcher :one
INSERT INTO
"Dispatcher" ("id", "lastHeartbeatAt", "isActive")
VALUES
($1::uuid, CURRENT_TIMESTAMP, 't')
RETURNING id, "createdAt", "updatedAt", "deletedAt", "lastHeartbeatAt", "isActive"
`
func (q *Queries) CreateDispatcher(ctx context.Context, db DBTX, id pgtype.UUID) (*Dispatcher, error) {
row := db.QueryRow(ctx, createDispatcher, id)
var i Dispatcher
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.DeletedAt,
&i.LastHeartbeatAt,
&i.IsActive,
)
return &i, err
}
const deleteDispatcher = `-- name: DeleteDispatcher :one
DELETE FROM
"Dispatcher" as dispatchers
WHERE
"id" = $1::uuid
RETURNING id, "createdAt", "updatedAt", "deletedAt", "lastHeartbeatAt", "isActive"
`
func (q *Queries) DeleteDispatcher(ctx context.Context, db DBTX, id pgtype.UUID) (*Dispatcher, error) {
row := db.QueryRow(ctx, deleteDispatcher, id)
var i Dispatcher
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.DeletedAt,
&i.LastHeartbeatAt,
&i.IsActive,
)
return &i, err
}
const listActiveDispatchers = `-- name: ListActiveDispatchers :many
SELECT
dispatchers.id, dispatchers."createdAt", dispatchers."updatedAt", dispatchers."deletedAt", dispatchers."lastHeartbeatAt", dispatchers."isActive"
@@ -174,3 +218,32 @@ func (q *Queries) SetDispatchersInactive(ctx context.Context, db DBTX, ids []pgt
}
return items, nil
}
const updateDispatcher = `-- name: UpdateDispatcher :one
UPDATE
"Dispatcher" as dispatchers
SET
"lastHeartbeatAt" = $1::timestamp
WHERE
"id" = $2::uuid
RETURNING id, "createdAt", "updatedAt", "deletedAt", "lastHeartbeatAt", "isActive"
`
type UpdateDispatcherParams struct {
LastHeartbeatAt pgtype.Timestamp `json:"lastHeartbeatAt"`
ID pgtype.UUID `json:"id"`
}
func (q *Queries) UpdateDispatcher(ctx context.Context, db DBTX, arg UpdateDispatcherParams) (*Dispatcher, error) {
row := db.QueryRow(ctx, updateDispatcher, arg.LastHeartbeatAt, arg.ID)
var i Dispatcher
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.DeletedAt,
&i.LastHeartbeatAt,
&i.IsActive,
)
return &i, err
}
+10 -4
View File
@@ -1,9 +1,6 @@
-- name: GetEventForEngine :one
SELECT
"id",
"key",
"data",
"tenantId"
*
FROM
"Event"
WHERE
@@ -120,3 +117,12 @@ GROUP BY
event_hour
ORDER BY
event_hour;
-- name: ListEventsByIDs :many
SELECT
*
FROM
"Event" as events
WHERE
"tenantId" = @tenantId::uuid AND
"id" = ANY (sqlc.arg('ids')::uuid[]);
+52 -14
View File
@@ -125,31 +125,25 @@ func (q *Queries) CreateEvent(ctx context.Context, db DBTX, arg CreateEventParam
const getEventForEngine = `-- name: GetEventForEngine :one
SELECT
"id",
"key",
"data",
"tenantId"
id, "createdAt", "updatedAt", "deletedAt", key, "tenantId", "replayedFromId", data
FROM
"Event"
WHERE
"id" = $1::uuid
`
type GetEventForEngineRow struct {
ID pgtype.UUID `json:"id"`
Key string `json:"key"`
Data []byte `json:"data"`
TenantId pgtype.UUID `json:"tenantId"`
}
func (q *Queries) GetEventForEngine(ctx context.Context, db DBTX, id pgtype.UUID) (*GetEventForEngineRow, error) {
func (q *Queries) GetEventForEngine(ctx context.Context, db DBTX, id pgtype.UUID) (*Event, error) {
row := db.QueryRow(ctx, getEventForEngine, id)
var i GetEventForEngineRow
var i Event
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.DeletedAt,
&i.Key,
&i.Data,
&i.TenantId,
&i.ReplayedFromId,
&i.Data,
)
return &i, err
}
@@ -300,3 +294,47 @@ func (q *Queries) ListEvents(ctx context.Context, db DBTX, arg ListEventsParams)
}
return items, nil
}
const listEventsByIDs = `-- name: ListEventsByIDs :many
SELECT
id, "createdAt", "updatedAt", "deletedAt", key, "tenantId", "replayedFromId", data
FROM
"Event" as events
WHERE
"tenantId" = $1::uuid AND
"id" = ANY ($2::uuid[])
`
type ListEventsByIDsParams struct {
Tenantid pgtype.UUID `json:"tenantid"`
Ids []pgtype.UUID `json:"ids"`
}
func (q *Queries) ListEventsByIDs(ctx context.Context, db DBTX, arg ListEventsByIDsParams) ([]*Event, error) {
rows, err := db.Query(ctx, listEventsByIDs, arg.Tenantid, arg.Ids)
if err != nil {
return nil, err
}
defer rows.Close()
var items []*Event
for rows.Next() {
var i Event
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.DeletedAt,
&i.Key,
&i.TenantId,
&i.ReplayedFromId,
&i.Data,
); err != nil {
return nil, err
}
items = append(items, &i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
@@ -96,7 +96,6 @@ WITH get_group_key_run AS (
WHERE
ggr."id" = @getGroupKeyRunId::uuid AND
ggr."tenantId" = @tenantId::uuid
FOR UPDATE SKIP LOCKED
), valid_workers AS (
SELECT
w."id", w."dispatcherId"
@@ -116,7 +115,6 @@ WITH get_group_key_run AS (
SELECT "id", "dispatcherId"
FROM valid_workers
LIMIT 1
FOR UPDATE SKIP LOCKED
)
UPDATE
"GetGroupKeyRun"
@@ -127,7 +125,8 @@ SET
FROM selected_worker
LIMIT 1
),
"updatedAt" = CURRENT_TIMESTAMP
"updatedAt" = CURRENT_TIMESTAMP,
"timeoutAt" = CURRENT_TIMESTAMP + INTERVAL '5 minutes'
WHERE
"id" = @getGroupKeyRunId::uuid AND
"tenantId" = @tenantId::uuid AND
@@ -72,7 +72,6 @@ WITH get_group_key_run AS (
WHERE
ggr."id" = $1::uuid AND
ggr."tenantId" = $2::uuid
FOR UPDATE SKIP LOCKED
), valid_workers AS (
SELECT
w."id", w."dispatcherId"
@@ -92,7 +91,6 @@ WITH get_group_key_run AS (
SELECT "id", "dispatcherId"
FROM valid_workers
LIMIT 1
FOR UPDATE SKIP LOCKED
)
UPDATE
"GetGroupKeyRun"
@@ -103,7 +101,8 @@ SET
FROM selected_worker
LIMIT 1
),
"updatedAt" = CURRENT_TIMESTAMP
"updatedAt" = CURRENT_TIMESTAMP,
"timeoutAt" = CURRENT_TIMESTAMP + INTERVAL '5 minutes'
WHERE
"id" = $1::uuid AND
"tenantId" = $2::uuid AND
@@ -105,3 +105,11 @@ WHERE
WHERE "id" = @stepRunId::uuid
)
AND "tenantId" = @tenantId::uuid;
-- name: ListJobRunsForWorkflowRun :many
SELECT
"id"
FROM
"JobRun" jr
WHERE
jr."workflowRunId" = @workflowRunId::uuid;
@@ -11,6 +11,35 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)
const listJobRunsForWorkflowRun = `-- name: ListJobRunsForWorkflowRun :many
SELECT
"id"
FROM
"JobRun" jr
WHERE
jr."workflowRunId" = $1::uuid
`
func (q *Queries) ListJobRunsForWorkflowRun(ctx context.Context, db DBTX, workflowrunid pgtype.UUID) ([]pgtype.UUID, error) {
rows, err := db.Query(ctx, listJobRunsForWorkflowRun, workflowrunid)
if err != nil {
return nil, err
}
defer rows.Close()
var items []pgtype.UUID
for rows.Next() {
var id pgtype.UUID
if err := rows.Scan(&id); err != nil {
return nil, err
}
items = append(items, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const resolveJobRunStatus = `-- name: ResolveJobRunStatus :one
WITH stepRuns AS (
SELECT sum(case when runs."status" IN ('PENDING', 'PENDING_ASSIGNMENT') then 1 else 0 end) AS pendingRuns,
@@ -4,6 +4,7 @@ sql:
# database:
# uri: "postgres://hatchet:hatchet@localhost:5431/hatchet"
queries:
- api_tokens.sql
- health.sql
- events.sql
- workflow_runs.sql
@@ -15,6 +16,7 @@ sql:
- dispatchers.sql
- workers.sql
- logs.sql
- tenants.sql
schema:
- schema.sql
strict_order_by: false
+81 -38
View File
@@ -28,25 +28,88 @@ SELECT
FROM
"StepRun" sr
JOIN
"Step" s ON sr."stepId" = s."id" AND s."tenantId" = @tenantId::uuid
"Step" s ON sr."stepId" = s."id"
JOIN
"Action" a ON s."actionId" = a."actionId" AND a."tenantId" = @tenantId::uuid
"Action" a ON s."actionId" = a."actionId"
JOIN
"JobRun" jr ON sr."jobRunId" = jr."id" AND jr."tenantId" = @tenantId::uuid
"JobRun" jr ON sr."jobRunId" = jr."id"
JOIN
"JobRunLookupData" jrld ON jr."id" = jrld."jobRunId" AND jrld."tenantId" = @tenantId::uuid
"JobRunLookupData" jrld ON jr."id" = jrld."jobRunId"
JOIN
"Job" j ON jr."jobId" = j."id" AND j."tenantId" = @tenantId::uuid
"Job" j ON jr."jobId" = j."id"
JOIN
"WorkflowRun" wr ON jr."workflowRunId" = wr."id" AND wr."tenantId" = @tenantId::uuid
"WorkflowRun" wr ON jr."workflowRunId" = wr."id"
JOIN
"WorkflowVersion" wv ON wr."workflowVersionId" = wv."id"
JOIN
"Workflow" w ON wv."workflowId" = w."id" AND w."tenantId" = @tenantId::uuid
"Workflow" w ON wv."workflowId" = w."id"
WHERE
sr."id" = ANY(@ids::uuid[]) AND
sr."tenantId" = @tenantId::uuid;
(
sqlc.narg('tenantId')::uuid IS NULL OR
sr."tenantId" = sqlc.narg('tenantId')::uuid
);
-- name: ListStartableStepRuns :many
WITH job_run AS (
SELECT "status"
FROM "JobRun"
WHERE "id" = @jobRunId::uuid
)
SELECT
child_run."id" AS "id"
FROM
"StepRun" AS child_run
LEFT JOIN
"_StepRunOrder" AS step_run_order ON step_run_order."B" = child_run."id"
JOIN
job_run ON true
WHERE
child_run."jobRunId" = @jobRunId::uuid
AND child_run."status" = 'PENDING'
AND job_run."status" = 'RUNNING'
-- case on whether parentStepRunId is null
AND (
(sqlc.narg('parentStepRunId')::uuid IS NULL AND step_run_order."A" IS NULL) OR
(
step_run_order."A" = sqlc.narg('parentStepRunId')::uuid
AND NOT EXISTS (
SELECT 1
FROM "_StepRunOrder" AS parent_order
JOIN "StepRun" AS parent_run ON parent_order."A" = parent_run."id"
WHERE
parent_order."B" = child_run."id"
AND parent_run."status" != 'SUCCEEDED'
)
)
);
-- name: ListStepRuns :many
SELECT
"StepRun"."id"
FROM
"StepRun"
JOIN
"JobRun" ON "StepRun"."jobRunId" = "JobRun"."id"
WHERE
"StepRun"."tenantId" = @tenantId::uuid
AND (
sqlc.narg('status')::"StepRunStatus" IS NULL OR
"StepRun"."status" = sqlc.narg('status')::"StepRunStatus"
)
AND (
sqlc.narg('workflowRunId')::uuid IS NULL OR
"JobRun"."workflowRunId" = sqlc.narg('workflowRunId')::uuid
)
AND (
sqlc.narg('jobRunId')::uuid IS NULL OR
"StepRun"."jobRunId" = sqlc.narg('jobRunId')::uuid
)
AND (
sqlc.narg('tickerId')::uuid IS NULL OR
"StepRun"."tickerId" = sqlc.narg('tickerId')::uuid
);
-- name: UpdateStepRun :one
UPDATE
"StepRun"
@@ -268,7 +331,8 @@ WITH step_run AS (
SELECT
sr."id",
sr."status",
a."id" AS "actionId"
a."id" AS "actionId",
s."timeout" AS "stepTimeout"
FROM
"StepRun" sr
JOIN
@@ -278,7 +342,6 @@ WITH step_run AS (
WHERE
sr."id" = @stepRunId::uuid AND
sr."tenantId" = @tenantId::uuid
FOR UPDATE SKIP LOCKED
),
valid_workers AS (
SELECT
@@ -308,7 +371,6 @@ selected_worker AS (
SELECT "id", "dispatcherId"
FROM valid_workers
LIMIT 1
FOR UPDATE SKIP LOCKED
)
UPDATE
"StepRun"
@@ -319,33 +381,14 @@ SET
FROM selected_worker
LIMIT 1
),
"updatedAt" = CURRENT_TIMESTAMP
"updatedAt" = CURRENT_TIMESTAMP,
"timeoutAt" = CASE
WHEN (SELECT "stepTimeout" FROM step_run) IS NOT NULL THEN
CURRENT_TIMESTAMP + convert_duration_to_interval((SELECT "stepTimeout" FROM step_run))
ELSE CURRENT_TIMESTAMP + INTERVAL '5 minutes'
END
WHERE
"id" = @stepRunId::uuid AND
"tenantId" = @tenantId::uuid AND
EXISTS (SELECT 1 FROM selected_worker)
RETURNING "StepRun"."id", "StepRun"."workerId", (SELECT "dispatcherId" FROM selected_worker) AS "dispatcherId";
-- name: AssignStepRunToTicker :one
WITH selected_ticker AS (
SELECT
t."id"
FROM
"Ticker" t
WHERE
t."lastHeartbeatAt" > NOW() - INTERVAL '6 seconds'
ORDER BY random()
LIMIT 1
)
UPDATE
"StepRun"
SET
"tickerId" = (
SELECT "id"
FROM selected_ticker
)
WHERE
"id" = @stepRunId::uuid AND
"tenantId" = @tenantId::uuid AND
EXISTS (SELECT 1 FROM selected_ticker)
RETURNING "StepRun"."id", "StepRun"."tickerId";
RETURNING "StepRun"."id", "StepRun"."workerId", (SELECT "dispatcherId" FROM selected_worker) AS "dispatcherId";
@@ -95,54 +95,13 @@ func (q *Queries) ArchiveStepRunResultFromStepRun(ctx context.Context, db DBTX,
return &i, err
}
const assignStepRunToTicker = `-- name: AssignStepRunToTicker :one
WITH selected_ticker AS (
SELECT
t."id"
FROM
"Ticker" t
WHERE
t."lastHeartbeatAt" > NOW() - INTERVAL '6 seconds'
ORDER BY random()
LIMIT 1
)
UPDATE
"StepRun"
SET
"tickerId" = (
SELECT "id"
FROM selected_ticker
)
WHERE
"id" = $1::uuid AND
"tenantId" = $2::uuid AND
EXISTS (SELECT 1 FROM selected_ticker)
RETURNING "StepRun"."id", "StepRun"."tickerId"
`
type AssignStepRunToTickerParams struct {
Steprunid pgtype.UUID `json:"steprunid"`
Tenantid pgtype.UUID `json:"tenantid"`
}
type AssignStepRunToTickerRow struct {
ID pgtype.UUID `json:"id"`
TickerId pgtype.UUID `json:"tickerId"`
}
func (q *Queries) AssignStepRunToTicker(ctx context.Context, db DBTX, arg AssignStepRunToTickerParams) (*AssignStepRunToTickerRow, error) {
row := db.QueryRow(ctx, assignStepRunToTicker, arg.Steprunid, arg.Tenantid)
var i AssignStepRunToTickerRow
err := row.Scan(&i.ID, &i.TickerId)
return &i, err
}
const assignStepRunToWorker = `-- name: AssignStepRunToWorker :one
WITH step_run AS (
SELECT
sr."id",
sr."status",
a."id" AS "actionId"
a."id" AS "actionId",
s."timeout" AS "stepTimeout"
FROM
"StepRun" sr
JOIN
@@ -152,7 +111,6 @@ WITH step_run AS (
WHERE
sr."id" = $1::uuid AND
sr."tenantId" = $2::uuid
FOR UPDATE SKIP LOCKED
),
valid_workers AS (
SELECT
@@ -182,7 +140,6 @@ selected_worker AS (
SELECT "id", "dispatcherId"
FROM valid_workers
LIMIT 1
FOR UPDATE SKIP LOCKED
)
UPDATE
"StepRun"
@@ -193,7 +150,12 @@ SET
FROM selected_worker
LIMIT 1
),
"updatedAt" = CURRENT_TIMESTAMP
"updatedAt" = CURRENT_TIMESTAMP,
"timeoutAt" = CASE
WHEN (SELECT "stepTimeout" FROM step_run) IS NOT NULL THEN
CURRENT_TIMESTAMP + convert_duration_to_interval((SELECT "stepTimeout" FROM step_run))
ELSE CURRENT_TIMESTAMP + INTERVAL '5 minutes'
END
WHERE
"id" = $1::uuid AND
"tenantId" = $2::uuid AND
@@ -289,29 +251,32 @@ SELECT
FROM
"StepRun" sr
JOIN
"Step" s ON sr."stepId" = s."id" AND s."tenantId" = $1::uuid
"Step" s ON sr."stepId" = s."id"
JOIN
"Action" a ON s."actionId" = a."actionId" AND a."tenantId" = $1::uuid
"Action" a ON s."actionId" = a."actionId"
JOIN
"JobRun" jr ON sr."jobRunId" = jr."id" AND jr."tenantId" = $1::uuid
"JobRun" jr ON sr."jobRunId" = jr."id"
JOIN
"JobRunLookupData" jrld ON jr."id" = jrld."jobRunId" AND jrld."tenantId" = $1::uuid
"JobRunLookupData" jrld ON jr."id" = jrld."jobRunId"
JOIN
"Job" j ON jr."jobId" = j."id" AND j."tenantId" = $1::uuid
"Job" j ON jr."jobId" = j."id"
JOIN
"WorkflowRun" wr ON jr."workflowRunId" = wr."id" AND wr."tenantId" = $1::uuid
"WorkflowRun" wr ON jr."workflowRunId" = wr."id"
JOIN
"WorkflowVersion" wv ON wr."workflowVersionId" = wv."id"
JOIN
"Workflow" w ON wv."workflowId" = w."id" AND w."tenantId" = $1::uuid
"Workflow" w ON wv."workflowId" = w."id"
WHERE
sr."id" = ANY($2::uuid[]) AND
sr."tenantId" = $1::uuid
sr."id" = ANY($1::uuid[]) AND
(
$2::uuid IS NULL OR
sr."tenantId" = $2::uuid
)
`
type GetStepRunForEngineParams struct {
Tenantid pgtype.UUID `json:"tenantid"`
Ids []pgtype.UUID `json:"ids"`
TenantId pgtype.UUID `json:"tenantId"`
}
type GetStepRunForEngineRow struct {
@@ -333,7 +298,7 @@ type GetStepRunForEngineRow struct {
}
func (q *Queries) GetStepRunForEngine(ctx context.Context, db DBTX, arg GetStepRunForEngineParams) ([]*GetStepRunForEngineRow, error) {
rows, err := db.Query(ctx, getStepRunForEngine, arg.Tenantid, arg.Ids)
rows, err := db.Query(ctx, getStepRunForEngine, arg.Ids, arg.TenantId)
if err != nil {
return nil, err
}
@@ -393,6 +358,127 @@ func (q *Queries) GetStepRunForEngine(ctx context.Context, db DBTX, arg GetStepR
return items, nil
}
const listStartableStepRuns = `-- name: ListStartableStepRuns :many
WITH job_run AS (
SELECT "status"
FROM "JobRun"
WHERE "id" = $1::uuid
)
SELECT
child_run."id" AS "id"
FROM
"StepRun" AS child_run
LEFT JOIN
"_StepRunOrder" AS step_run_order ON step_run_order."B" = child_run."id"
JOIN
job_run ON true
WHERE
child_run."jobRunId" = $1::uuid
AND child_run."status" = 'PENDING'
AND job_run."status" = 'RUNNING'
-- case on whether parentStepRunId is null
AND (
($2::uuid IS NULL AND step_run_order."A" IS NULL) OR
(
step_run_order."A" = $2::uuid
AND NOT EXISTS (
SELECT 1
FROM "_StepRunOrder" AS parent_order
JOIN "StepRun" AS parent_run ON parent_order."A" = parent_run."id"
WHERE
parent_order."B" = child_run."id"
AND parent_run."status" != 'SUCCEEDED'
)
)
)
`
type ListStartableStepRunsParams struct {
Jobrunid pgtype.UUID `json:"jobrunid"`
ParentStepRunId pgtype.UUID `json:"parentStepRunId"`
}
func (q *Queries) ListStartableStepRuns(ctx context.Context, db DBTX, arg ListStartableStepRunsParams) ([]pgtype.UUID, error) {
rows, err := db.Query(ctx, listStartableStepRuns, arg.Jobrunid, arg.ParentStepRunId)
if err != nil {
return nil, err
}
defer rows.Close()
var items []pgtype.UUID
for rows.Next() {
var id pgtype.UUID
if err := rows.Scan(&id); err != nil {
return nil, err
}
items = append(items, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listStepRuns = `-- name: ListStepRuns :many
SELECT
"StepRun"."id"
FROM
"StepRun"
JOIN
"JobRun" ON "StepRun"."jobRunId" = "JobRun"."id"
WHERE
"StepRun"."tenantId" = $1::uuid
AND (
$2::"StepRunStatus" IS NULL OR
"StepRun"."status" = $2::"StepRunStatus"
)
AND (
$3::uuid IS NULL OR
"JobRun"."workflowRunId" = $3::uuid
)
AND (
$4::uuid IS NULL OR
"StepRun"."jobRunId" = $4::uuid
)
AND (
$5::uuid IS NULL OR
"StepRun"."tickerId" = $5::uuid
)
`
type ListStepRunsParams struct {
Tenantid pgtype.UUID `json:"tenantid"`
Status NullStepRunStatus `json:"status"`
WorkflowRunId pgtype.UUID `json:"workflowRunId"`
JobRunId pgtype.UUID `json:"jobRunId"`
TickerId pgtype.UUID `json:"tickerId"`
}
func (q *Queries) ListStepRuns(ctx context.Context, db DBTX, arg ListStepRunsParams) ([]pgtype.UUID, error) {
rows, err := db.Query(ctx, listStepRuns,
arg.Tenantid,
arg.Status,
arg.WorkflowRunId,
arg.JobRunId,
arg.TickerId,
)
if err != nil {
return nil, err
}
defer rows.Close()
var items []pgtype.UUID
for rows.Next() {
var id pgtype.UUID
if err := rows.Scan(&id); err != nil {
return nil, err
}
items = append(items, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listStepRunsToReassign = `-- name: ListStepRunsToReassign :many
SELECT
sr.id, sr."createdAt", sr."updatedAt", sr."deletedAt", sr."tenantId", sr."jobRunId", sr."stepId", sr."order", sr."workerId", sr."tickerId", sr.status, sr.input, sr.output, sr."requeueAfter", sr."scheduleTimeoutAt", sr.error, sr."startedAt", sr."finishedAt", sr."timeoutAt", sr."cancelledAt", sr."cancelledReason", sr."cancelledError", sr."inputSchema", sr."callerFiles", sr."gitRepoBranch", sr."retryCount"
@@ -0,0 +1,13 @@
-- name: ListTenants :many
SELECT
*
FROM
"Tenant" as tenants;
-- name: GetTenantByID :one
SELECT
*
FROM
"Tenant" as tenants
WHERE
"id" = sqlc.arg('id')::uuid;
@@ -0,0 +1,69 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.24.0
// source: tenants.sql
package dbsqlc
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const getTenantByID = `-- name: GetTenantByID :one
SELECT
id, "createdAt", "updatedAt", "deletedAt", name, slug
FROM
"Tenant" as tenants
WHERE
"id" = $1::uuid
`
func (q *Queries) GetTenantByID(ctx context.Context, db DBTX, id pgtype.UUID) (*Tenant, error) {
row := db.QueryRow(ctx, getTenantByID, id)
var i Tenant
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.DeletedAt,
&i.Name,
&i.Slug,
)
return &i, err
}
const listTenants = `-- name: ListTenants :many
SELECT
id, "createdAt", "updatedAt", "deletedAt", name, slug
FROM
"Tenant" as tenants
`
func (q *Queries) ListTenants(ctx context.Context, db DBTX) ([]*Tenant, error) {
rows, err := db.Query(ctx, listTenants)
if err != nil {
return nil, err
}
defer rows.Close()
var items []*Tenant
for rows.Next() {
var i Tenant
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.DeletedAt,
&i.Name,
&i.Slug,
); err != nil {
return nil, err
}
items = append(items, &i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
+183 -2
View File
@@ -1,3 +1,10 @@
-- name: CreateTicker :one
INSERT INTO
"Ticker" ("id", "lastHeartbeatAt", "isActive")
VALUES
(sqlc.arg('id')::uuid, CURRENT_TIMESTAMP, 't')
RETURNING *;
-- name: ListNewlyStaleTickers :many
SELECT
sqlc.embed(tickers)
@@ -30,6 +37,180 @@ RETURNING
-- name: ListTickers :many
SELECT
sqlc.embed(tickers)
*
FROM
"Ticker" as tickers;
"Ticker" as tickers
WHERE
(
sqlc.arg('isActive')::boolean IS NULL OR
"isActive" = sqlc.arg('isActive')::boolean
)
AND
(
sqlc.arg('lastHeartbeatAfter')::timestamp IS NULL OR
tickers."lastHeartbeatAt" > sqlc.narg('lastHeartbeatAfter')::timestamp
);
-- name: DeleteTicker :one
DELETE FROM
"Ticker" as tickers
WHERE
"id" = sqlc.arg('id')::uuid
RETURNING *;
-- name: UpdateTicker :one
UPDATE
"Ticker" as tickers
SET
"lastHeartbeatAt" = sqlc.arg('lastHeartbeatAt')::timestamp
WHERE
"id" = sqlc.arg('id')::uuid
RETURNING *;
-- name: PollStepRuns :many
WITH stepRunsToTimeout AS (
SELECT
stepRun."id"
FROM
"StepRun" as stepRun
WHERE
"status" = 'RUNNING'
AND "timeoutAt" < NOW()
AND (
NOT EXISTS (
SELECT 1 FROM "Ticker" WHERE "id" = stepRun."tickerId" AND "isActive" = true AND "lastHeartbeatAt" >= NOW() - INTERVAL '10 seconds'
)
OR "tickerId" IS NULL
)
FOR UPDATE SKIP LOCKED
)
UPDATE
"StepRun" as stepRuns
SET
"tickerId" = @tickerId::uuid
FROM
stepRunsToTimeout
WHERE
stepRuns."id" = stepRunsToTimeout."id"
RETURNING stepRuns.*;
-- name: PollGetGroupKeyRuns :many
WITH getGroupKeyRunsToTimeout AS (
SELECT
getGroupKeyRun."id"
FROM
"GetGroupKeyRun" as getGroupKeyRun
WHERE
"status" = 'RUNNING'
AND "timeoutAt" < NOW()
AND (
NOT EXISTS (
SELECT 1 FROM "Ticker" WHERE "id" = getGroupKeyRun."tickerId" AND "isActive" = true AND "lastHeartbeatAt" >= NOW() - INTERVAL '10 seconds'
)
OR "tickerId" IS NULL
)
FOR UPDATE SKIP LOCKED
)
UPDATE
"GetGroupKeyRun" as getGroupKeyRuns
SET
"tickerId" = @tickerId::uuid
FROM
getGroupKeyRunsToTimeout
WHERE
getGroupKeyRuns."id" = getGroupKeyRunsToTimeout."id"
RETURNING getGroupKeyRuns.*;
-- name: PollCronSchedules :many
WITH latest_workflow_versions AS (
SELECT
"workflowId",
MAX("order") as max_order
FROM
"WorkflowVersion"
GROUP BY "workflowId"
),
active_cron_schedules AS (
SELECT
cronSchedule."parentId",
versions."id" AS "workflowVersionId",
triggers."tenantId" AS "tenantId"
FROM
"WorkflowTriggerCronRef" as cronSchedule
JOIN
"WorkflowTriggers" as triggers ON triggers."id" = cronSchedule."parentId"
JOIN
"WorkflowVersion" as versions ON versions."id" = triggers."workflowVersionId"
JOIN
latest_workflow_versions l ON versions."workflowId" = l."workflowId" AND versions."order" = l.max_order
WHERE
"tickerId" IS NULL
OR NOT EXISTS (
SELECT 1 FROM "Ticker" WHERE "id" = cronSchedule."tickerId" AND "isActive" = true AND "lastHeartbeatAt" >= NOW() - INTERVAL '10 seconds'
)
OR "tickerId" = @tickerId::uuid
FOR UPDATE SKIP LOCKED
)
UPDATE
"WorkflowTriggerCronRef" as cronSchedules
SET
"tickerId" = @tickerId::uuid
FROM
active_cron_schedules
WHERE
cronSchedules."parentId" = active_cron_schedules."parentId"
RETURNING cronSchedules.*, active_cron_schedules."workflowVersionId", active_cron_schedules."tenantId";
-- name: PollScheduledWorkflows :many
-- Finds workflows that are either past their execution time or will be in the next 5 seconds and assigns them
-- to a ticker, or finds workflows that were assigned to a ticker that is no longer active
WITH latest_workflow_versions AS (
SELECT
"workflowId",
MAX("order") as max_order
FROM
"WorkflowVersion"
GROUP BY "workflowId"
),
not_run_scheduled_workflows AS (
SELECT
scheduledWorkflow."id",
versions."id" AS "workflowVersionId",
workflow."tenantId" AS "tenantId"
FROM
"WorkflowTriggerScheduledRef" as scheduledWorkflow
JOIN
"WorkflowVersion" as versions ON versions."id" = scheduledWorkflow."parentId"
JOIN
latest_workflow_versions l ON versions."workflowId" = l."workflowId" AND versions."order" = l.max_order
JOIN
"Workflow" as workflow ON workflow."id" = versions."workflowId"
LEFT JOIN
"WorkflowRunTriggeredBy" as runTriggeredBy ON runTriggeredBy."scheduledId" = scheduledWorkflow."id"
WHERE
"triggerAt" <= NOW() + INTERVAL '5 seconds'
AND runTriggeredBy IS NULL
AND (
"tickerId" IS NULL
OR NOT EXISTS (
SELECT 1 FROM "Ticker" WHERE "id" = scheduledWorkflow."tickerId" AND "isActive" = true AND "lastHeartbeatAt" >= NOW() - INTERVAL '10 seconds'
)
OR "tickerId" = @tickerId::uuid
)
),
active_scheduled_workflows AS (
SELECT
*
FROM
not_run_scheduled_workflows
FOR UPDATE SKIP LOCKED
)
UPDATE
"WorkflowTriggerScheduledRef" as scheduledWorkflows
SET
"tickerId" = @tickerId::uuid
FROM
active_scheduled_workflows
WHERE
scheduledWorkflows."id" = active_scheduled_workflows."id"
RETURNING scheduledWorkflows.*, active_scheduled_workflows."workflowVersionId", active_scheduled_workflows."tenantId";
+407 -12
View File
@@ -11,6 +11,48 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)
const createTicker = `-- name: CreateTicker :one
INSERT INTO
"Ticker" ("id", "lastHeartbeatAt", "isActive")
VALUES
($1::uuid, CURRENT_TIMESTAMP, 't')
RETURNING id, "createdAt", "updatedAt", "lastHeartbeatAt", "isActive"
`
func (q *Queries) CreateTicker(ctx context.Context, db DBTX, id pgtype.UUID) (*Ticker, error) {
row := db.QueryRow(ctx, createTicker, id)
var i Ticker
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.LastHeartbeatAt,
&i.IsActive,
)
return &i, err
}
const deleteTicker = `-- name: DeleteTicker :one
DELETE FROM
"Ticker" as tickers
WHERE
"id" = $1::uuid
RETURNING id, "createdAt", "updatedAt", "lastHeartbeatAt", "isActive"
`
func (q *Queries) DeleteTicker(ctx context.Context, db DBTX, id pgtype.UUID) (*Ticker, error) {
row := db.QueryRow(ctx, deleteTicker, id)
var i Ticker
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.LastHeartbeatAt,
&i.IsActive,
)
return &i, err
}
const listActiveTickers = `-- name: ListActiveTickers :many
SELECT
tickers.id, tickers."createdAt", tickers."updatedAt", tickers."lastHeartbeatAt", tickers."isActive"
@@ -95,30 +137,355 @@ func (q *Queries) ListNewlyStaleTickers(ctx context.Context, db DBTX) ([]*ListNe
const listTickers = `-- name: ListTickers :many
SELECT
tickers.id, tickers."createdAt", tickers."updatedAt", tickers."lastHeartbeatAt", tickers."isActive"
id, "createdAt", "updatedAt", "lastHeartbeatAt", "isActive"
FROM
"Ticker" as tickers
WHERE
(
$1::boolean IS NULL OR
"isActive" = $1::boolean
)
AND
(
$2::timestamp IS NULL OR
tickers."lastHeartbeatAt" > $2::timestamp
)
`
type ListTickersRow struct {
Ticker Ticker `json:"ticker"`
type ListTickersParams struct {
IsActive bool `json:"isActive"`
LastHeartbeatAfter pgtype.Timestamp `json:"lastHeartbeatAfter"`
}
func (q *Queries) ListTickers(ctx context.Context, db DBTX) ([]*ListTickersRow, error) {
rows, err := db.Query(ctx, listTickers)
func (q *Queries) ListTickers(ctx context.Context, db DBTX, arg ListTickersParams) ([]*Ticker, error) {
rows, err := db.Query(ctx, listTickers, arg.IsActive, arg.LastHeartbeatAfter)
if err != nil {
return nil, err
}
defer rows.Close()
var items []*ListTickersRow
var items []*Ticker
for rows.Next() {
var i ListTickersRow
var i Ticker
if err := rows.Scan(
&i.Ticker.ID,
&i.Ticker.CreatedAt,
&i.Ticker.UpdatedAt,
&i.Ticker.LastHeartbeatAt,
&i.Ticker.IsActive,
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.LastHeartbeatAt,
&i.IsActive,
); err != nil {
return nil, err
}
items = append(items, &i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const pollCronSchedules = `-- name: PollCronSchedules :many
WITH latest_workflow_versions AS (
SELECT
"workflowId",
MAX("order") as max_order
FROM
"WorkflowVersion"
GROUP BY "workflowId"
),
active_cron_schedules AS (
SELECT
cronSchedule."parentId",
versions."id" AS "workflowVersionId",
triggers."tenantId" AS "tenantId"
FROM
"WorkflowTriggerCronRef" as cronSchedule
JOIN
"WorkflowTriggers" as triggers ON triggers."id" = cronSchedule."parentId"
JOIN
"WorkflowVersion" as versions ON versions."id" = triggers."workflowVersionId"
JOIN
latest_workflow_versions l ON versions."workflowId" = l."workflowId" AND versions."order" = l.max_order
WHERE
"tickerId" IS NULL
OR NOT EXISTS (
SELECT 1 FROM "Ticker" WHERE "id" = cronSchedule."tickerId" AND "isActive" = true AND "lastHeartbeatAt" >= NOW() - INTERVAL '10 seconds'
)
OR "tickerId" = $1::uuid
FOR UPDATE SKIP LOCKED
)
UPDATE
"WorkflowTriggerCronRef" as cronSchedules
SET
"tickerId" = $1::uuid
FROM
active_cron_schedules
WHERE
cronSchedules."parentId" = active_cron_schedules."parentId"
RETURNING cronschedules."parentId", cronschedules.cron, cronschedules."tickerId", cronschedules.input, active_cron_schedules."workflowVersionId", active_cron_schedules."tenantId"
`
type PollCronSchedulesRow struct {
ParentId pgtype.UUID `json:"parentId"`
Cron string `json:"cron"`
TickerId pgtype.UUID `json:"tickerId"`
Input []byte `json:"input"`
WorkflowVersionId pgtype.UUID `json:"workflowVersionId"`
TenantId pgtype.UUID `json:"tenantId"`
}
func (q *Queries) PollCronSchedules(ctx context.Context, db DBTX, tickerid pgtype.UUID) ([]*PollCronSchedulesRow, error) {
rows, err := db.Query(ctx, pollCronSchedules, tickerid)
if err != nil {
return nil, err
}
defer rows.Close()
var items []*PollCronSchedulesRow
for rows.Next() {
var i PollCronSchedulesRow
if err := rows.Scan(
&i.ParentId,
&i.Cron,
&i.TickerId,
&i.Input,
&i.WorkflowVersionId,
&i.TenantId,
); err != nil {
return nil, err
}
items = append(items, &i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const pollGetGroupKeyRuns = `-- name: PollGetGroupKeyRuns :many
WITH getGroupKeyRunsToTimeout AS (
SELECT
getGroupKeyRun."id"
FROM
"GetGroupKeyRun" as getGroupKeyRun
WHERE
"status" = 'RUNNING'
AND "timeoutAt" < NOW()
AND (
NOT EXISTS (
SELECT 1 FROM "Ticker" WHERE "id" = getGroupKeyRun."tickerId" AND "isActive" = true AND "lastHeartbeatAt" >= NOW() - INTERVAL '10 seconds'
)
OR "tickerId" IS NULL
)
FOR UPDATE SKIP LOCKED
)
UPDATE
"GetGroupKeyRun" as getGroupKeyRuns
SET
"tickerId" = $1::uuid
FROM
getGroupKeyRunsToTimeout
WHERE
getGroupKeyRuns."id" = getGroupKeyRunsToTimeout."id"
RETURNING getgroupkeyruns.id, getgroupkeyruns."createdAt", getgroupkeyruns."updatedAt", getgroupkeyruns."deletedAt", getgroupkeyruns."tenantId", getgroupkeyruns."workerId", getgroupkeyruns."tickerId", getgroupkeyruns.status, getgroupkeyruns.input, getgroupkeyruns.output, getgroupkeyruns."requeueAfter", getgroupkeyruns.error, getgroupkeyruns."startedAt", getgroupkeyruns."finishedAt", getgroupkeyruns."timeoutAt", getgroupkeyruns."cancelledAt", getgroupkeyruns."cancelledReason", getgroupkeyruns."cancelledError", getgroupkeyruns."workflowRunId", getgroupkeyruns."scheduleTimeoutAt"
`
func (q *Queries) PollGetGroupKeyRuns(ctx context.Context, db DBTX, tickerid pgtype.UUID) ([]*GetGroupKeyRun, error) {
rows, err := db.Query(ctx, pollGetGroupKeyRuns, tickerid)
if err != nil {
return nil, err
}
defer rows.Close()
var items []*GetGroupKeyRun
for rows.Next() {
var i GetGroupKeyRun
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.DeletedAt,
&i.TenantId,
&i.WorkerId,
&i.TickerId,
&i.Status,
&i.Input,
&i.Output,
&i.RequeueAfter,
&i.Error,
&i.StartedAt,
&i.FinishedAt,
&i.TimeoutAt,
&i.CancelledAt,
&i.CancelledReason,
&i.CancelledError,
&i.WorkflowRunId,
&i.ScheduleTimeoutAt,
); err != nil {
return nil, err
}
items = append(items, &i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const pollScheduledWorkflows = `-- name: PollScheduledWorkflows :many
WITH latest_workflow_versions AS (
SELECT
"workflowId",
MAX("order") as max_order
FROM
"WorkflowVersion"
GROUP BY "workflowId"
),
not_run_scheduled_workflows AS (
SELECT
scheduledWorkflow."id",
versions."id" AS "workflowVersionId",
workflow."tenantId" AS "tenantId"
FROM
"WorkflowTriggerScheduledRef" as scheduledWorkflow
JOIN
"WorkflowVersion" as versions ON versions."id" = scheduledWorkflow."parentId"
JOIN
latest_workflow_versions l ON versions."workflowId" = l."workflowId" AND versions."order" = l.max_order
JOIN
"Workflow" as workflow ON workflow."id" = versions."workflowId"
LEFT JOIN
"WorkflowRunTriggeredBy" as runTriggeredBy ON runTriggeredBy."scheduledId" = scheduledWorkflow."id"
WHERE
"triggerAt" <= NOW() + INTERVAL '5 seconds'
AND runTriggeredBy IS NULL
AND (
"tickerId" IS NULL
OR NOT EXISTS (
SELECT 1 FROM "Ticker" WHERE "id" = scheduledWorkflow."tickerId" AND "isActive" = true AND "lastHeartbeatAt" >= NOW() - INTERVAL '10 seconds'
)
OR "tickerId" = $1::uuid
)
),
active_scheduled_workflows AS (
SELECT
id, "workflowVersionId", "tenantId"
FROM
not_run_scheduled_workflows
FOR UPDATE SKIP LOCKED
)
UPDATE
"WorkflowTriggerScheduledRef" as scheduledWorkflows
SET
"tickerId" = $1::uuid
FROM
active_scheduled_workflows
WHERE
scheduledWorkflows."id" = active_scheduled_workflows."id"
RETURNING scheduledworkflows.id, scheduledworkflows."parentId", scheduledworkflows."triggerAt", scheduledworkflows."tickerId", scheduledworkflows.input, active_scheduled_workflows."workflowVersionId", active_scheduled_workflows."tenantId"
`
type PollScheduledWorkflowsRow struct {
ID pgtype.UUID `json:"id"`
ParentId pgtype.UUID `json:"parentId"`
TriggerAt pgtype.Timestamp `json:"triggerAt"`
TickerId pgtype.UUID `json:"tickerId"`
Input []byte `json:"input"`
WorkflowVersionId pgtype.UUID `json:"workflowVersionId"`
TenantId pgtype.UUID `json:"tenantId"`
}
// Finds workflows that are either past their execution time or will be in the next 5 seconds and assigns them
// to a ticker, or finds workflows that were assigned to a ticker that is no longer active
func (q *Queries) PollScheduledWorkflows(ctx context.Context, db DBTX, tickerid pgtype.UUID) ([]*PollScheduledWorkflowsRow, error) {
rows, err := db.Query(ctx, pollScheduledWorkflows, tickerid)
if err != nil {
return nil, err
}
defer rows.Close()
var items []*PollScheduledWorkflowsRow
for rows.Next() {
var i PollScheduledWorkflowsRow
if err := rows.Scan(
&i.ID,
&i.ParentId,
&i.TriggerAt,
&i.TickerId,
&i.Input,
&i.WorkflowVersionId,
&i.TenantId,
); err != nil {
return nil, err
}
items = append(items, &i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const pollStepRuns = `-- name: PollStepRuns :many
WITH stepRunsToTimeout AS (
SELECT
stepRun."id"
FROM
"StepRun" as stepRun
WHERE
"status" = 'RUNNING'
AND "timeoutAt" < NOW()
AND (
NOT EXISTS (
SELECT 1 FROM "Ticker" WHERE "id" = stepRun."tickerId" AND "isActive" = true AND "lastHeartbeatAt" >= NOW() - INTERVAL '10 seconds'
)
OR "tickerId" IS NULL
)
FOR UPDATE SKIP LOCKED
)
UPDATE
"StepRun" as stepRuns
SET
"tickerId" = $1::uuid
FROM
stepRunsToTimeout
WHERE
stepRuns."id" = stepRunsToTimeout."id"
RETURNING stepruns.id, stepruns."createdAt", stepruns."updatedAt", stepruns."deletedAt", stepruns."tenantId", stepruns."jobRunId", stepruns."stepId", stepruns."order", stepruns."workerId", stepruns."tickerId", stepruns.status, stepruns.input, stepruns.output, stepruns."requeueAfter", stepruns."scheduleTimeoutAt", stepruns.error, stepruns."startedAt", stepruns."finishedAt", stepruns."timeoutAt", stepruns."cancelledAt", stepruns."cancelledReason", stepruns."cancelledError", stepruns."inputSchema", stepruns."callerFiles", stepruns."gitRepoBranch", stepruns."retryCount"
`
func (q *Queries) PollStepRuns(ctx context.Context, db DBTX, tickerid pgtype.UUID) ([]*StepRun, error) {
rows, err := db.Query(ctx, pollStepRuns, tickerid)
if err != nil {
return nil, err
}
defer rows.Close()
var items []*StepRun
for rows.Next() {
var i StepRun
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.DeletedAt,
&i.TenantId,
&i.JobRunId,
&i.StepId,
&i.Order,
&i.WorkerId,
&i.TickerId,
&i.Status,
&i.Input,
&i.Output,
&i.RequeueAfter,
&i.ScheduleTimeoutAt,
&i.Error,
&i.StartedAt,
&i.FinishedAt,
&i.TimeoutAt,
&i.CancelledAt,
&i.CancelledReason,
&i.CancelledError,
&i.InputSchema,
&i.CallerFiles,
&i.GitRepoBranch,
&i.RetryCount,
); err != nil {
return nil, err
}
@@ -170,3 +537,31 @@ func (q *Queries) SetTickersInactive(ctx context.Context, db DBTX, ids []pgtype.
}
return items, nil
}
const updateTicker = `-- name: UpdateTicker :one
UPDATE
"Ticker" as tickers
SET
"lastHeartbeatAt" = $1::timestamp
WHERE
"id" = $2::uuid
RETURNING id, "createdAt", "updatedAt", "lastHeartbeatAt", "isActive"
`
type UpdateTickerParams struct {
LastHeartbeatAt pgtype.Timestamp `json:"lastHeartbeatAt"`
ID pgtype.UUID `json:"id"`
}
func (q *Queries) UpdateTicker(ctx context.Context, db DBTX, arg UpdateTickerParams) (*Ticker, error) {
row := db.QueryRow(ctx, updateTicker, arg.LastHeartbeatAt, arg.ID)
var i Ticker
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.LastHeartbeatAt,
&i.IsActive,
)
return &i, err
}
+84 -1
View File
@@ -42,4 +42,87 @@ FROM
"Worker" w
WHERE
w."tenantId" = @tenantId
AND w."id" = @id;
AND w."id" = @id;
-- name: CreateWorker :one
INSERT INTO "Worker" (
"id",
"createdAt",
"updatedAt",
"tenantId",
"name",
"status",
"dispatcherId",
"maxRuns"
) VALUES (
gen_random_uuid(),
CURRENT_TIMESTAMP,
CURRENT_TIMESTAMP,
@tenantId::uuid,
@name::text,
'ACTIVE',
@dispatcherId::uuid,
sqlc.narg('maxRuns')::int
) RETURNING *;
-- name: UpdateWorker :one
UPDATE
"Worker"
SET
"updatedAt" = CURRENT_TIMESTAMP,
"status" = coalesce(sqlc.narg('status')::"WorkerStatus", "status"),
"dispatcherId" = coalesce(sqlc.narg('dispatcherId')::uuid, "dispatcherId"),
"maxRuns" = coalesce(sqlc.narg('maxRuns')::int, "maxRuns"),
"lastHeartbeatAt" = coalesce(sqlc.narg('lastHeartbeatAt')::timestamp, "lastHeartbeatAt")
WHERE
"id" = @id::uuid
RETURNING *;
-- name: LinkActionsToWorker :exec
INSERT INTO "_ActionToWorker" (
"A",
"B"
) SELECT
unnest(@actionIds::uuid[]),
@workerId::uuid
ON CONFLICT DO NOTHING;
-- name: UpsertService :one
INSERT INTO "Service" (
"id",
"createdAt",
"updatedAt",
"name",
"tenantId"
)
VALUES (
gen_random_uuid(),
CURRENT_TIMESTAMP,
CURRENT_TIMESTAMP,
@name::text,
@tenantId::uuid
)
ON CONFLICT ("tenantId", "name") DO UPDATE
SET
"updatedAt" = CURRENT_TIMESTAMP
WHERE
"Service"."tenantId" = @tenantId AND "Service"."name" = @name::text
RETURNING *;
-- name: LinkServicesToWorker :exec
INSERT INTO "_ServiceToWorker" (
"A",
"B"
)
VALUES (
unnest(@services::uuid[]),
@workerId::uuid
)
ON CONFLICT DO NOTHING;
-- name: DeleteWorker :one
DELETE FROM
"Worker"
WHERE
"id" = @id::uuid
RETURNING *;
@@ -11,6 +11,84 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)
const createWorker = `-- name: CreateWorker :one
INSERT INTO "Worker" (
"id",
"createdAt",
"updatedAt",
"tenantId",
"name",
"status",
"dispatcherId",
"maxRuns"
) VALUES (
gen_random_uuid(),
CURRENT_TIMESTAMP,
CURRENT_TIMESTAMP,
$1::uuid,
$2::text,
'ACTIVE',
$3::uuid,
$4::int
) RETURNING id, "createdAt", "updatedAt", "deletedAt", "tenantId", "lastHeartbeatAt", name, status, "dispatcherId", "maxRuns"
`
type CreateWorkerParams struct {
Tenantid pgtype.UUID `json:"tenantid"`
Name string `json:"name"`
Dispatcherid pgtype.UUID `json:"dispatcherid"`
MaxRuns pgtype.Int4 `json:"maxRuns"`
}
func (q *Queries) CreateWorker(ctx context.Context, db DBTX, arg CreateWorkerParams) (*Worker, error) {
row := db.QueryRow(ctx, createWorker,
arg.Tenantid,
arg.Name,
arg.Dispatcherid,
arg.MaxRuns,
)
var i Worker
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.DeletedAt,
&i.TenantId,
&i.LastHeartbeatAt,
&i.Name,
&i.Status,
&i.DispatcherId,
&i.MaxRuns,
)
return &i, err
}
const deleteWorker = `-- name: DeleteWorker :one
DELETE FROM
"Worker"
WHERE
"id" = $1::uuid
RETURNING id, "createdAt", "updatedAt", "deletedAt", "tenantId", "lastHeartbeatAt", name, status, "dispatcherId", "maxRuns"
`
func (q *Queries) DeleteWorker(ctx context.Context, db DBTX, id pgtype.UUID) (*Worker, error) {
row := db.QueryRow(ctx, deleteWorker, id)
var i Worker
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.DeletedAt,
&i.TenantId,
&i.LastHeartbeatAt,
&i.Name,
&i.Status,
&i.DispatcherId,
&i.MaxRuns,
)
return &i, err
}
const getWorkerForEngine = `-- name: GetWorkerForEngine :one
SELECT
w."id" AS "id",
@@ -41,6 +119,48 @@ func (q *Queries) GetWorkerForEngine(ctx context.Context, db DBTX, arg GetWorker
return &i, err
}
const linkActionsToWorker = `-- name: LinkActionsToWorker :exec
INSERT INTO "_ActionToWorker" (
"A",
"B"
) SELECT
unnest($1::uuid[]),
$2::uuid
ON CONFLICT DO NOTHING
`
type LinkActionsToWorkerParams struct {
Actionids []pgtype.UUID `json:"actionids"`
Workerid pgtype.UUID `json:"workerid"`
}
func (q *Queries) LinkActionsToWorker(ctx context.Context, db DBTX, arg LinkActionsToWorkerParams) error {
_, err := db.Exec(ctx, linkActionsToWorker, arg.Actionids, arg.Workerid)
return err
}
const linkServicesToWorker = `-- name: LinkServicesToWorker :exec
INSERT INTO "_ServiceToWorker" (
"A",
"B"
)
VALUES (
unnest($1::uuid[]),
$2::uuid
)
ON CONFLICT DO NOTHING
`
type LinkServicesToWorkerParams struct {
Services []pgtype.UUID `json:"services"`
Workerid pgtype.UUID `json:"workerid"`
}
func (q *Queries) LinkServicesToWorker(ctx context.Context, db DBTX, arg LinkServicesToWorkerParams) error {
_, err := db.Exec(ctx, linkServicesToWorker, arg.Services, arg.Workerid)
return err
}
const listWorkersWithStepCount = `-- name: ListWorkersWithStepCount :many
SELECT
workers.id, workers."createdAt", workers."updatedAt", workers."deletedAt", workers."tenantId", workers."lastHeartbeatAt", workers.name, workers.status, workers."dispatcherId", workers."maxRuns",
@@ -125,3 +245,92 @@ func (q *Queries) ListWorkersWithStepCount(ctx context.Context, db DBTX, arg Lis
}
return items, nil
}
const updateWorker = `-- name: UpdateWorker :one
UPDATE
"Worker"
SET
"updatedAt" = CURRENT_TIMESTAMP,
"status" = coalesce($1::"WorkerStatus", "status"),
"dispatcherId" = coalesce($2::uuid, "dispatcherId"),
"maxRuns" = coalesce($3::int, "maxRuns"),
"lastHeartbeatAt" = coalesce($4::timestamp, "lastHeartbeatAt")
WHERE
"id" = $5::uuid
RETURNING id, "createdAt", "updatedAt", "deletedAt", "tenantId", "lastHeartbeatAt", name, status, "dispatcherId", "maxRuns"
`
type UpdateWorkerParams struct {
Status NullWorkerStatus `json:"status"`
DispatcherId pgtype.UUID `json:"dispatcherId"`
MaxRuns pgtype.Int4 `json:"maxRuns"`
LastHeartbeatAt pgtype.Timestamp `json:"lastHeartbeatAt"`
ID pgtype.UUID `json:"id"`
}
func (q *Queries) UpdateWorker(ctx context.Context, db DBTX, arg UpdateWorkerParams) (*Worker, error) {
row := db.QueryRow(ctx, updateWorker,
arg.Status,
arg.DispatcherId,
arg.MaxRuns,
arg.LastHeartbeatAt,
arg.ID,
)
var i Worker
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.DeletedAt,
&i.TenantId,
&i.LastHeartbeatAt,
&i.Name,
&i.Status,
&i.DispatcherId,
&i.MaxRuns,
)
return &i, err
}
const upsertService = `-- name: UpsertService :one
INSERT INTO "Service" (
"id",
"createdAt",
"updatedAt",
"name",
"tenantId"
)
VALUES (
gen_random_uuid(),
CURRENT_TIMESTAMP,
CURRENT_TIMESTAMP,
$1::text,
$2::uuid
)
ON CONFLICT ("tenantId", "name") DO UPDATE
SET
"updatedAt" = CURRENT_TIMESTAMP
WHERE
"Service"."tenantId" = $2 AND "Service"."name" = $1::text
RETURNING id, "createdAt", "updatedAt", "deletedAt", name, description, "tenantId"
`
type UpsertServiceParams struct {
Name string `json:"name"`
Tenantid pgtype.UUID `json:"tenantid"`
}
func (q *Queries) UpsertService(ctx context.Context, db DBTX, arg UpsertServiceParams) (*Service, error) {
row := db.QueryRow(ctx, upsertService, arg.Name, arg.Tenantid)
var i Service
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.DeletedAt,
&i.Name,
&i.Description,
&i.TenantId,
)
return &i, err
}
@@ -98,7 +98,7 @@ WITH running_count AS (
SELECT
r2.id,
row_number() OVER (PARTITION BY r2."concurrencyGroupId" ORDER BY r2."createdAt") AS rn,
row_number() over (order by r2."id" desc) as seqnum
row_number() over (order by r2."createdAt" ASC) as seqnum
FROM
"WorkflowRun" r2
LEFT JOIN
@@ -108,14 +108,28 @@ WITH running_count AS (
r2."status" = 'QUEUED' AND
workflowVersion."id" = $2
ORDER BY
rn ASC
rn, seqnum ASC
), min_rn AS (
SELECT
MIN(rn) as min_rn
FROM
queued_row_numbers
), first_partition_count AS (
SELECT
COUNT(*) as count
FROM
queued_row_numbers
WHERE
rn = (SELECT min_rn FROM min_rn)
), eligible_runs AS (
SELECT
id
FROM
queued_row_numbers
WHERE
queued_row_numbers."seqnum" <= (@maxRuns::int) - (SELECT "count" FROM running_count)
-- We can run up to maxRuns per group, so we multiple max runs by the number of groups, then subtract the
-- total number of running workflows.
queued_row_numbers."seqnum" <= (@maxRuns::int) * (SELECT count FROM first_partition_count) - (SELECT "count" FROM running_count)
FOR UPDATE SKIP LOCKED
)
UPDATE "WorkflowRun"
@@ -424,36 +438,28 @@ JOIN
JOIN
"StepRun" AS child_run ON child_run."stepId" = step_order."B" AND child_run."jobRunId" = @jobRunId::uuid;
-- name: ListStartableStepRuns :many
WITH job_run AS (
SELECT "status"
FROM "JobRun"
WHERE "id" = @jobRunId::uuid
)
SELECT
child_run."id" AS "id"
FROM
"StepRun" AS child_run
LEFT JOIN
"_StepRunOrder" AS step_run_order ON step_run_order."B" = child_run."id"
JOIN
job_run ON true
WHERE
child_run."jobRunId" = @jobRunId::uuid
AND child_run."status" = 'PENDING'
AND job_run."status" = 'RUNNING'
-- case on whether parentStepRunId is null
AND (
(sqlc.narg('parentStepRunId')::uuid IS NULL AND step_run_order."A" IS NULL) OR
(
step_run_order."A" = sqlc.narg('parentStepRunId')::uuid
AND NOT EXISTS (
SELECT 1
FROM "_StepRunOrder" AS parent_order
JOIN "StepRun" AS parent_run ON parent_order."A" = parent_run."id"
WHERE
parent_order."B" = child_run."id"
AND parent_run."status" != 'SUCCEEDED'
)
)
);
-- name: GetWorkflowRun :many
SELECT
sqlc.embed(runs),
sqlc.embed(runTriggers),
sqlc.embed(workflowVersion),
workflow."name" as "workflowName",
-- waiting on https://github.com/sqlc-dev/sqlc/pull/2858 for nullable fields
wc."limitStrategy" as "concurrencyLimitStrategy",
wc."maxRuns" as "concurrencyMaxRuns",
groupKeyRun."id" as "getGroupKeyRunId"
FROM
"WorkflowRun" as runs
LEFT JOIN
"WorkflowRunTriggeredBy" as runTriggers ON runTriggers."parentId" = runs."id"
LEFT JOIN
"WorkflowVersion" as workflowVersion ON runs."workflowVersionId" = workflowVersion."id"
LEFT JOIN
"Workflow" as workflow ON workflowVersion."workflowId" = workflow."id"
LEFT JOIN
"WorkflowConcurrency" as wc ON wc."workflowVersionId" = workflowVersion."id"
LEFT JOIN
"GetGroupKeyRun" as groupKeyRun ON groupKeyRun."workflowRunId" = runs."id"
WHERE
runs."id" = ANY(@ids::uuid[]) AND
runs."tenantId" = @tenantId::uuid;
@@ -430,6 +430,106 @@ func (q *Queries) CreateWorkflowRunTriggeredBy(ctx context.Context, db DBTX, arg
return &i, err
}
const getWorkflowRun = `-- name: GetWorkflowRun :many
SELECT
runs."createdAt", runs."updatedAt", runs."deletedAt", runs."tenantId", runs."workflowVersionId", runs.status, runs.error, runs."startedAt", runs."finishedAt", runs."concurrencyGroupId", runs."displayName", runs.id, runs."gitRepoBranch",
runtriggers.id, runtriggers."createdAt", runtriggers."updatedAt", runtriggers."deletedAt", runtriggers."tenantId", runtriggers."eventId", runtriggers."cronParentId", runtriggers."cronSchedule", runtriggers."scheduledId", runtriggers.input, runtriggers."parentId",
workflowversion.id, workflowversion."createdAt", workflowversion."updatedAt", workflowversion."deletedAt", workflowversion.version, workflowversion."order", workflowversion."workflowId", workflowversion.checksum, workflowversion."scheduleTimeout",
workflow."name" as "workflowName",
-- waiting on https://github.com/sqlc-dev/sqlc/pull/2858 for nullable fields
wc."limitStrategy" as "concurrencyLimitStrategy",
wc."maxRuns" as "concurrencyMaxRuns",
groupKeyRun."id" as "getGroupKeyRunId"
FROM
"WorkflowRun" as runs
LEFT JOIN
"WorkflowRunTriggeredBy" as runTriggers ON runTriggers."parentId" = runs."id"
LEFT JOIN
"WorkflowVersion" as workflowVersion ON runs."workflowVersionId" = workflowVersion."id"
LEFT JOIN
"Workflow" as workflow ON workflowVersion."workflowId" = workflow."id"
LEFT JOIN
"WorkflowConcurrency" as wc ON wc."workflowVersionId" = workflowVersion."id"
LEFT JOIN
"GetGroupKeyRun" as groupKeyRun ON groupKeyRun."workflowRunId" = runs."id"
WHERE
runs."id" = ANY($1::uuid[]) AND
runs."tenantId" = $2::uuid
`
type GetWorkflowRunParams struct {
Ids []pgtype.UUID `json:"ids"`
Tenantid pgtype.UUID `json:"tenantid"`
}
type GetWorkflowRunRow struct {
WorkflowRun WorkflowRun `json:"workflow_run"`
WorkflowRunTriggeredBy WorkflowRunTriggeredBy `json:"workflow_run_triggered_by"`
WorkflowVersion WorkflowVersion `json:"workflow_version"`
WorkflowName pgtype.Text `json:"workflowName"`
ConcurrencyLimitStrategy NullConcurrencyLimitStrategy `json:"concurrencyLimitStrategy"`
ConcurrencyMaxRuns pgtype.Int4 `json:"concurrencyMaxRuns"`
GetGroupKeyRunId pgtype.UUID `json:"getGroupKeyRunId"`
}
func (q *Queries) GetWorkflowRun(ctx context.Context, db DBTX, arg GetWorkflowRunParams) ([]*GetWorkflowRunRow, error) {
rows, err := db.Query(ctx, getWorkflowRun, arg.Ids, arg.Tenantid)
if err != nil {
return nil, err
}
defer rows.Close()
var items []*GetWorkflowRunRow
for rows.Next() {
var i GetWorkflowRunRow
if err := rows.Scan(
&i.WorkflowRun.CreatedAt,
&i.WorkflowRun.UpdatedAt,
&i.WorkflowRun.DeletedAt,
&i.WorkflowRun.TenantId,
&i.WorkflowRun.WorkflowVersionId,
&i.WorkflowRun.Status,
&i.WorkflowRun.Error,
&i.WorkflowRun.StartedAt,
&i.WorkflowRun.FinishedAt,
&i.WorkflowRun.ConcurrencyGroupId,
&i.WorkflowRun.DisplayName,
&i.WorkflowRun.ID,
&i.WorkflowRun.GitRepoBranch,
&i.WorkflowRunTriggeredBy.ID,
&i.WorkflowRunTriggeredBy.CreatedAt,
&i.WorkflowRunTriggeredBy.UpdatedAt,
&i.WorkflowRunTriggeredBy.DeletedAt,
&i.WorkflowRunTriggeredBy.TenantId,
&i.WorkflowRunTriggeredBy.EventId,
&i.WorkflowRunTriggeredBy.CronParentId,
&i.WorkflowRunTriggeredBy.CronSchedule,
&i.WorkflowRunTriggeredBy.ScheduledId,
&i.WorkflowRunTriggeredBy.Input,
&i.WorkflowRunTriggeredBy.ParentId,
&i.WorkflowVersion.ID,
&i.WorkflowVersion.CreatedAt,
&i.WorkflowVersion.UpdatedAt,
&i.WorkflowVersion.DeletedAt,
&i.WorkflowVersion.Version,
&i.WorkflowVersion.Order,
&i.WorkflowVersion.WorkflowId,
&i.WorkflowVersion.Checksum,
&i.WorkflowVersion.ScheduleTimeout,
&i.WorkflowName,
&i.ConcurrencyLimitStrategy,
&i.ConcurrencyMaxRuns,
&i.GetGroupKeyRunId,
); err != nil {
return nil, err
}
items = append(items, &i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const linkStepRunParents = `-- name: LinkStepRunParents :exec
INSERT INTO "_StepRunOrder" ("A", "B")
SELECT
@@ -448,66 +548,6 @@ func (q *Queries) LinkStepRunParents(ctx context.Context, db DBTX, jobrunid pgty
return err
}
const listStartableStepRuns = `-- name: ListStartableStepRuns :many
WITH job_run AS (
SELECT "status"
FROM "JobRun"
WHERE "id" = $1::uuid
)
SELECT
child_run."id" AS "id"
FROM
"StepRun" AS child_run
LEFT JOIN
"_StepRunOrder" AS step_run_order ON step_run_order."B" = child_run."id"
JOIN
job_run ON true
WHERE
child_run."jobRunId" = $1::uuid
AND child_run."status" = 'PENDING'
AND job_run."status" = 'RUNNING'
-- case on whether parentStepRunId is null
AND (
($2::uuid IS NULL AND step_run_order."A" IS NULL) OR
(
step_run_order."A" = $2::uuid
AND NOT EXISTS (
SELECT 1
FROM "_StepRunOrder" AS parent_order
JOIN "StepRun" AS parent_run ON parent_order."A" = parent_run."id"
WHERE
parent_order."B" = child_run."id"
AND parent_run."status" != 'SUCCEEDED'
)
)
)
`
type ListStartableStepRunsParams struct {
Jobrunid pgtype.UUID `json:"jobrunid"`
ParentStepRunId pgtype.UUID `json:"parentStepRunId"`
}
func (q *Queries) ListStartableStepRuns(ctx context.Context, db DBTX, arg ListStartableStepRunsParams) ([]pgtype.UUID, error) {
rows, err := db.Query(ctx, listStartableStepRuns, arg.Jobrunid, arg.ParentStepRunId)
if err != nil {
return nil, err
}
defer rows.Close()
var items []pgtype.UUID
for rows.Next() {
var id pgtype.UUID
if err := rows.Scan(&id); err != nil {
return nil, err
}
items = append(items, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listWorkflowRuns = `-- name: ListWorkflowRuns :many
SELECT
runs."createdAt", runs."updatedAt", runs."deletedAt", runs."tenantId", runs."workflowVersionId", runs.status, runs.error, runs."startedAt", runs."finishedAt", runs."concurrencyGroupId", runs."displayName", runs.id, runs."gitRepoBranch",
@@ -671,7 +711,7 @@ WITH running_count AS (
SELECT
r2.id,
row_number() OVER (PARTITION BY r2."concurrencyGroupId" ORDER BY r2."createdAt") AS rn,
row_number() over (order by r2."id" desc) as seqnum
row_number() over (order by r2."createdAt" ASC) as seqnum
FROM
"WorkflowRun" r2
LEFT JOIN
@@ -681,14 +721,28 @@ WITH running_count AS (
r2."status" = 'QUEUED' AND
workflowVersion."id" = $2
ORDER BY
rn ASC
rn, seqnum ASC
), min_rn AS (
SELECT
MIN(rn) as min_rn
FROM
queued_row_numbers
), first_partition_count AS (
SELECT
COUNT(*) as count
FROM
queued_row_numbers
WHERE
rn = (SELECT min_rn FROM min_rn)
), eligible_runs AS (
SELECT
id
FROM
queued_row_numbers
WHERE
queued_row_numbers."seqnum" <= ($3::int) - (SELECT "count" FROM running_count)
-- We can run up to maxRuns per group, so we multiple max runs by the number of groups, then subtract the
-- total number of running workflows.
queued_row_numbers."seqnum" <= ($3::int) * (SELECT count FROM first_partition_count) - (SELECT "count" FROM running_count)
FOR UPDATE SKIP LOCKED
)
UPDATE "WorkflowRun"
@@ -359,16 +359,47 @@ ORDER BY "WorkflowVersion"."workflowId", "WorkflowVersion"."order" DESC;
SELECT
sqlc.embed(workflowVersions),
w."name" as "workflowName",
-- return "hasWorkflowConcurrency" if the workflow has concurrency
EXISTS (
SELECT 1
FROM "WorkflowConcurrency" as wc
WHERE wc."workflowVersionId" = workflowVersions."id"
) as "hasWorkflowConcurrency"
wc."limitStrategy" as "concurrencyLimitStrategy",
wc."maxRuns" as "concurrencyMaxRuns"
FROM
"WorkflowVersion" as workflowVersions
JOIN
"Workflow" as w ON w."id" = workflowVersions."workflowId"
LEFT JOIN
"WorkflowConcurrency" as wc ON wc."workflowVersionId" = workflowVersions."id"
WHERE
workflowVersions."id" = ANY(@ids::uuid[]) AND
w."tenantId" = @tenantId::uuid;
w."tenantId" = @tenantId::uuid;
-- name: GetWorkflowByName :one
SELECT
*
FROM
"Workflow" as workflows
WHERE
workflows."tenantId" = @tenantId::uuid AND
workflows."name" = @name::text;
-- name: CreateSchedules :many
INSERT INTO "WorkflowTriggerScheduledRef" (
"id",
"parentId",
"triggerAt",
"input"
) VALUES (
gen_random_uuid(),
@workflowRunId::uuid,
unnest(@triggerTimes::timestamp[]),
@input::jsonb
) RETURNING *;
-- name: GetWorkflowLatestVersion :one
SELECT
"id"
FROM
"WorkflowVersion" as workflowVersions
WHERE
workflowVersions."workflowId" = @workflowId::uuid
ORDER BY
workflowVersions."order" DESC
LIMIT 1;
@@ -159,6 +159,52 @@ func (q *Queries) CreateJob(ctx context.Context, db DBTX, arg CreateJobParams) (
return &i, err
}
const createSchedules = `-- name: CreateSchedules :many
INSERT INTO "WorkflowTriggerScheduledRef" (
"id",
"parentId",
"triggerAt",
"input"
) VALUES (
gen_random_uuid(),
$1::uuid,
unnest($2::timestamp[]),
$3::jsonb
) RETURNING id, "parentId", "triggerAt", "tickerId", input
`
type CreateSchedulesParams struct {
Workflowrunid pgtype.UUID `json:"workflowrunid"`
Triggertimes []pgtype.Timestamp `json:"triggertimes"`
Input []byte `json:"input"`
}
func (q *Queries) CreateSchedules(ctx context.Context, db DBTX, arg CreateSchedulesParams) ([]*WorkflowTriggerScheduledRef, error) {
rows, err := db.Query(ctx, createSchedules, arg.Workflowrunid, arg.Triggertimes, arg.Input)
if err != nil {
return nil, err
}
defer rows.Close()
var items []*WorkflowTriggerScheduledRef
for rows.Next() {
var i WorkflowTriggerScheduledRef
if err := rows.Scan(
&i.ID,
&i.ParentId,
&i.TriggerAt,
&i.TickerId,
&i.Input,
); err != nil {
return nil, err
}
items = append(items, &i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const createStep = `-- name: CreateStep :one
INSERT INTO "Step" (
"id",
@@ -523,20 +569,67 @@ func (q *Queries) CreateWorkflowVersion(ctx context.Context, db DBTX, arg Create
return &i, err
}
const getWorkflowByName = `-- name: GetWorkflowByName :one
SELECT
id, "createdAt", "updatedAt", "deletedAt", "tenantId", name, description
FROM
"Workflow" as workflows
WHERE
workflows."tenantId" = $1::uuid AND
workflows."name" = $2::text
`
type GetWorkflowByNameParams struct {
Tenantid pgtype.UUID `json:"tenantid"`
Name string `json:"name"`
}
func (q *Queries) GetWorkflowByName(ctx context.Context, db DBTX, arg GetWorkflowByNameParams) (*Workflow, error) {
row := db.QueryRow(ctx, getWorkflowByName, arg.Tenantid, arg.Name)
var i Workflow
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.DeletedAt,
&i.TenantId,
&i.Name,
&i.Description,
)
return &i, err
}
const getWorkflowLatestVersion = `-- name: GetWorkflowLatestVersion :one
SELECT
"id"
FROM
"WorkflowVersion" as workflowVersions
WHERE
workflowVersions."workflowId" = $1::uuid
ORDER BY
workflowVersions."order" DESC
LIMIT 1
`
func (q *Queries) GetWorkflowLatestVersion(ctx context.Context, db DBTX, workflowid pgtype.UUID) (pgtype.UUID, error) {
row := db.QueryRow(ctx, getWorkflowLatestVersion, workflowid)
var id pgtype.UUID
err := row.Scan(&id)
return id, err
}
const getWorkflowVersionForEngine = `-- name: GetWorkflowVersionForEngine :many
SELECT
workflowversions.id, workflowversions."createdAt", workflowversions."updatedAt", workflowversions."deletedAt", workflowversions.version, workflowversions."order", workflowversions."workflowId", workflowversions.checksum, workflowversions."scheduleTimeout",
w."name" as "workflowName",
-- return "hasWorkflowConcurrency" if the workflow has concurrency
EXISTS (
SELECT 1
FROM "WorkflowConcurrency" as wc
WHERE wc."workflowVersionId" = workflowVersions."id"
) as "hasWorkflowConcurrency"
wc."limitStrategy" as "concurrencyLimitStrategy",
wc."maxRuns" as "concurrencyMaxRuns"
FROM
"WorkflowVersion" as workflowVersions
JOIN
"Workflow" as w ON w."id" = workflowVersions."workflowId"
LEFT JOIN
"WorkflowConcurrency" as wc ON wc."workflowVersionId" = workflowVersions."id"
WHERE
workflowVersions."id" = ANY($1::uuid[]) AND
w."tenantId" = $2::uuid
@@ -548,9 +641,10 @@ type GetWorkflowVersionForEngineParams struct {
}
type GetWorkflowVersionForEngineRow struct {
WorkflowVersion WorkflowVersion `json:"workflow_version"`
WorkflowName string `json:"workflowName"`
HasWorkflowConcurrency bool `json:"hasWorkflowConcurrency"`
WorkflowVersion WorkflowVersion `json:"workflow_version"`
WorkflowName string `json:"workflowName"`
ConcurrencyLimitStrategy NullConcurrencyLimitStrategy `json:"concurrencyLimitStrategy"`
ConcurrencyMaxRuns pgtype.Int4 `json:"concurrencyMaxRuns"`
}
func (q *Queries) GetWorkflowVersionForEngine(ctx context.Context, db DBTX, arg GetWorkflowVersionForEngineParams) ([]*GetWorkflowVersionForEngineRow, error) {
@@ -573,7 +667,8 @@ func (q *Queries) GetWorkflowVersionForEngine(ctx context.Context, db DBTX, arg
&i.WorkflowVersion.Checksum,
&i.WorkflowVersion.ScheduleTimeout,
&i.WorkflowName,
&i.HasWorkflowConcurrency,
&i.ConcurrencyLimitStrategy,
&i.ConcurrencyMaxRuns,
); err != nil {
return nil, err
}
+13 -35
View File
@@ -8,25 +8,22 @@ import (
"github.com/rs/zerolog"
"github.com/hatchet-dev/hatchet/internal/repository"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/db"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/dbsqlc"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/sqlchelpers"
"github.com/hatchet-dev/hatchet/internal/validator"
)
type dispatcherRepository struct {
client *db.PrismaClient
pool *pgxpool.Pool
v validator.Validator
queries *dbsqlc.Queries
l *zerolog.Logger
}
func NewDispatcherRepository(client *db.PrismaClient, pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.DispatcherRepository {
func NewDispatcherRepository(pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.DispatcherEngineRepository {
queries := dbsqlc.New()
return &dispatcherRepository{
client: client,
pool: pool,
queries: queries,
v: v,
@@ -34,46 +31,27 @@ func NewDispatcherRepository(client *db.PrismaClient, pool *pgxpool.Pool, v vali
}
}
func (d *dispatcherRepository) GetDispatcherForWorker(workerId string) (*db.DispatcherModel, error) {
return d.client.Dispatcher.FindFirst(
db.Dispatcher.Workers.Some(
db.Worker.ID.Equals(workerId),
),
).Exec(context.Background())
}
func (d *dispatcherRepository) CreateNewDispatcher(opts *repository.CreateDispatcherOpts) (*db.DispatcherModel, error) {
return d.client.Dispatcher.CreateOne(
db.Dispatcher.ID.Set(opts.ID),
).Exec(context.Background())
}
func (d *dispatcherRepository) UpdateDispatcher(dispatcherId string, opts *repository.UpdateDispatcherOpts) (*db.DispatcherModel, error) {
func (d *dispatcherRepository) CreateNewDispatcher(opts *repository.CreateDispatcherOpts) (*dbsqlc.Dispatcher, error) {
if err := d.v.Validate(opts); err != nil {
return nil, err
}
return d.client.Dispatcher.FindUnique(
db.Dispatcher.ID.Equals(dispatcherId),
).Update(
db.Dispatcher.LastHeartbeatAt.SetIfPresent(opts.LastHeartbeatAt),
).Exec(context.Background())
return d.queries.CreateDispatcher(context.Background(), d.pool, sqlchelpers.UUIDFromStr(opts.ID))
}
func (d *dispatcherRepository) AddWorker(dispatcherId, workerId string) (*db.DispatcherModel, error) {
return d.client.Dispatcher.FindUnique(
db.Dispatcher.ID.Equals(dispatcherId),
).Update(
db.Dispatcher.Workers.Link(
db.Worker.ID.Equals(workerId),
),
).Exec(context.Background())
func (d *dispatcherRepository) UpdateDispatcher(dispatcherId string, opts *repository.UpdateDispatcherOpts) (*dbsqlc.Dispatcher, error) {
if err := d.v.Validate(opts); err != nil {
return nil, err
}
return d.queries.UpdateDispatcher(context.Background(), d.pool, dbsqlc.UpdateDispatcherParams{
ID: sqlchelpers.UUIDFromStr(dispatcherId),
LastHeartbeatAt: sqlchelpers.TimestampFromTime(opts.LastHeartbeatAt.UTC()),
})
}
func (d *dispatcherRepository) Delete(dispatcherId string) error {
_, err := d.client.Dispatcher.FindUnique(
db.Dispatcher.ID.Equals(dispatcherId),
).Delete().Exec(context.Background())
_, err := d.queries.DeleteDispatcher(context.Background(), d.pool, sqlchelpers.UUIDFromStr(dispatcherId))
return err
}
+49 -16
View File
@@ -2,7 +2,6 @@ package prisma
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -16,12 +15,11 @@ import (
"github.com/hatchet-dev/hatchet/internal/repository/prisma/db"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/dbsqlc"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/sqlchelpers"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/sqlctoprisma"
"github.com/hatchet-dev/hatchet/internal/telemetry"
"github.com/hatchet-dev/hatchet/internal/validator"
)
type eventRepository struct {
type eventAPIRepository struct {
client *db.PrismaClient
pool *pgxpool.Pool
v validator.Validator
@@ -29,10 +27,10 @@ type eventRepository struct {
l *zerolog.Logger
}
func NewEventRepository(client *db.PrismaClient, pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.EventRepository {
func NewEventAPIRepository(client *db.PrismaClient, pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.EventAPIRepository {
queries := dbsqlc.New()
return &eventRepository{
return &eventAPIRepository{
client: client,
pool: pool,
v: v,
@@ -41,7 +39,7 @@ func NewEventRepository(client *db.PrismaClient, pool *pgxpool.Pool, v validator
}
}
func (r *eventRepository) ListEvents(tenantId string, opts *repository.ListEventOpts) (*repository.ListEventResult, error) {
func (r *eventAPIRepository) ListEvents(tenantId string, opts *repository.ListEventOpts) (*repository.ListEventResult, error) {
if err := r.v.Validate(opts); err != nil {
return nil, err
}
@@ -149,7 +147,7 @@ func (r *eventRepository) ListEvents(tenantId string, opts *repository.ListEvent
return res, nil
}
func (r *eventRepository) ListEventKeys(tenantId string) ([]string, error) {
func (r *eventAPIRepository) ListEventKeys(tenantId string) ([]string, error) {
var rows []struct {
Key string `json:"key"`
}
@@ -177,24 +175,42 @@ func (r *eventRepository) ListEventKeys(tenantId string) ([]string, error) {
return keys, nil
}
func (r *eventRepository) GetEventById(id string) (*db.EventModel, error) {
func (r *eventAPIRepository) GetEventById(id string) (*db.EventModel, error) {
return r.client.Event.FindUnique(
db.Event.ID.Equals(id),
).Exec(context.Background())
}
func (r *eventRepository) GetEventForEngine(tenantId, id string) (*dbsqlc.GetEventForEngineRow, error) {
return r.queries.GetEventForEngine(context.Background(), r.pool, sqlchelpers.UUIDFromStr(id))
}
func (r *eventRepository) ListEventsById(tenantId string, ids []string) ([]db.EventModel, error) {
func (r *eventAPIRepository) ListEventsById(tenantId string, ids []string) ([]db.EventModel, error) {
return r.client.Event.FindMany(
db.Event.ID.In(ids),
db.Event.TenantID.Equals(tenantId),
).Exec(context.Background())
}
func (r *eventRepository) CreateEvent(ctx context.Context, opts *repository.CreateEventOpts) (*db.EventModel, error) {
type eventEngineRepository struct {
pool *pgxpool.Pool
v validator.Validator
queries *dbsqlc.Queries
l *zerolog.Logger
}
func NewEventEngineRepository(pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.EventEngineRepository {
queries := dbsqlc.New()
return &eventEngineRepository{
pool: pool,
v: v,
queries: queries,
l: l,
}
}
func (r *eventEngineRepository) GetEventForEngine(tenantId, id string) (*dbsqlc.Event, error) {
return r.queries.GetEventForEngine(context.Background(), r.pool, sqlchelpers.UUIDFromStr(id))
}
func (r *eventEngineRepository) CreateEvent(ctx context.Context, opts *repository.CreateEventOpts) (*dbsqlc.Event, error) {
ctx, span := telemetry.NewSpan(ctx, "db-create-event")
defer span.End()
@@ -206,7 +222,7 @@ func (r *eventRepository) CreateEvent(ctx context.Context, opts *repository.Crea
ID: sqlchelpers.UUIDFromStr(uuid.New().String()),
Key: opts.Key,
Tenantid: sqlchelpers.UUIDFromStr(opts.TenantId),
Data: []byte(json.RawMessage(*opts.Data)),
Data: opts.Data,
}
if opts.ReplayedEvent != nil {
@@ -223,5 +239,22 @@ func (r *eventRepository) CreateEvent(ctx context.Context, opts *repository.Crea
return nil, fmt.Errorf("could not create event: %w", err)
}
return sqlctoprisma.NewConverter[dbsqlc.Event, db.EventModel]().ToPrisma(e), nil
return e, nil
}
func (r *eventEngineRepository) ListEventsByIds(tenantId string, ids []string) ([]*dbsqlc.Event, error) {
pgIds := make([]pgtype.UUID, len(ids))
for i, id := range ids {
if err := pgIds[i].Scan(id); err != nil {
return nil, err
}
}
pgTenantId := sqlchelpers.UUIDFromStr(tenantId)
return r.queries.ListEventsByIDs(context.Background(), r.pool, dbsqlc.ListEventsByIDsParams{
Tenantid: pgTenantId,
Ids: pgIds,
})
}
@@ -11,25 +11,22 @@ import (
"github.com/rs/zerolog"
"github.com/hatchet-dev/hatchet/internal/repository"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/db"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/dbsqlc"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/sqlchelpers"
"github.com/hatchet-dev/hatchet/internal/validator"
)
type getGroupKeyRunRepository struct {
client *db.PrismaClient
pool *pgxpool.Pool
v validator.Validator
l *zerolog.Logger
queries *dbsqlc.Queries
}
func NewGetGroupKeyRunRepository(client *db.PrismaClient, pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.GetGroupKeyRunRepository {
func NewGetGroupKeyRunRepository(pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.GetGroupKeyRunEngineRepository {
queries := dbsqlc.New()
return &getGroupKeyRunRepository{
client: client,
pool: pool,
v: v,
l: l,
@@ -37,33 +34,6 @@ func NewGetGroupKeyRunRepository(client *db.PrismaClient, pool *pgxpool.Pool, v
}
}
func (s *getGroupKeyRunRepository) ListGetGroupKeyRuns(tenantId string, opts *repository.ListGetGroupKeyRunsOpts) ([]db.GetGroupKeyRunModel, error) {
if err := s.v.Validate(opts); err != nil {
return nil, err
}
params := []db.GetGroupKeyRunWhereParam{
db.GetGroupKeyRun.TenantID.Equals(tenantId),
}
if opts.Status != nil {
params = append(params, db.GetGroupKeyRun.Status.Equals(*opts.Status))
}
return s.client.GetGroupKeyRun.FindMany(
params...,
).With(
db.GetGroupKeyRun.Ticker.Fetch(),
db.GetGroupKeyRun.WorkflowRun.Fetch().With(
db.WorkflowRun.WorkflowVersion.Fetch().With(
db.WorkflowVersion.Concurrency.Fetch().With(
db.WorkflowConcurrency.GetConcurrencyGroup.Fetch(),
),
),
),
).Exec(context.Background())
}
func (s *getGroupKeyRunRepository) ListGetGroupKeyRunsToRequeue(tenantId string) ([]*dbsqlc.GetGroupKeyRun, error) {
return s.queries.ListGetGroupKeyRunsToRequeue(context.Background(), s.pool, sqlchelpers.UUIDFromStr(tenantId))
}
@@ -233,21 +203,6 @@ func (s *getGroupKeyRunRepository) UpdateGetGroupKeyRun(tenantId, getGroupKeyRun
return getGroupKeyRuns[0], nil
}
func (s *getGroupKeyRunRepository) GetGroupKeyRunById(tenantId, getGroupKeyRunId string) (*db.GetGroupKeyRunModel, error) {
return s.client.GetGroupKeyRun.FindUnique(
db.GetGroupKeyRun.ID.Equals(getGroupKeyRunId),
).With(
db.GetGroupKeyRun.Ticker.Fetch(),
db.GetGroupKeyRun.WorkflowRun.Fetch().With(
db.WorkflowRun.WorkflowVersion.Fetch().With(
db.WorkflowVersion.Concurrency.Fetch().With(
db.WorkflowConcurrency.GetConcurrencyGroup.Fetch(),
),
),
),
).Exec(context.Background())
}
func (s *getGroupKeyRunRepository) GetGroupKeyRunForEngine(tenantId, getGroupKeyRunId string) (*dbsqlc.GetGroupKeyRunForEngineRow, error) {
res, err := s.queries.GetGroupKeyRunForEngine(context.Background(), s.pool, dbsqlc.GetGroupKeyRunForEngineParams{
Ids: []pgtype.UUID{sqlchelpers.UUIDFromStr(getGroupKeyRunId)},
+28 -4
View File
@@ -10,23 +10,23 @@ import (
"github.com/hatchet-dev/hatchet/internal/repository/prisma/dbsqlc"
)
type healthRepository struct {
type healthAPIRepository struct {
client *db.PrismaClient
queries *dbsqlc.Queries
pool *pgxpool.Pool
}
func NewHealthRepository(client *db.PrismaClient, pool *pgxpool.Pool) repository.HealthRepository {
func NewHealthAPIRepository(client *db.PrismaClient, pool *pgxpool.Pool) repository.HealthRepository {
queries := dbsqlc.New()
return &healthRepository{
return &healthAPIRepository{
client: client,
queries: queries,
pool: pool,
}
}
func (a *healthRepository) IsHealthy() bool {
func (a *healthAPIRepository) IsHealthy() bool {
_, err := a.client.User.FindMany().Take(1).Exec(context.Background())
if err != nil {
return false
@@ -39,3 +39,27 @@ func (a *healthRepository) IsHealthy() bool {
return true
}
type healthEngineRepository struct {
queries *dbsqlc.Queries
pool *pgxpool.Pool
}
func NewHealthEngineRepository(pool *pgxpool.Pool) repository.HealthRepository {
queries := dbsqlc.New()
return &healthEngineRepository{
queries: queries,
pool: pool,
}
}
func (a *healthEngineRepository) IsHealthy() bool {
_, err := a.queries.Health(context.Background(), a.pool)
if err != nil { //nolint:gosimple
return false
}
return true
}
+35 -100
View File
@@ -3,6 +3,7 @@ package prisma
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/rs/zerolog"
@@ -13,7 +14,7 @@ import (
"github.com/hatchet-dev/hatchet/internal/validator"
)
type jobRunRepository struct {
type jobRunAPIRepository struct {
client *db.PrismaClient
pool *pgxpool.Pool
v validator.Validator
@@ -21,10 +22,10 @@ type jobRunRepository struct {
l *zerolog.Logger
}
func NewJobRunRepository(client *db.PrismaClient, pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.JobRunRepository {
func NewJobRunAPIRepository(client *db.PrismaClient, pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.JobRunAPIRepository {
queries := dbsqlc.New()
return &jobRunRepository{
return &jobRunAPIRepository{
client: client,
v: v,
pool: pool,
@@ -33,73 +34,46 @@ func NewJobRunRepository(client *db.PrismaClient, pool *pgxpool.Pool, v validato
}
}
func (j *jobRunRepository) ListAllJobRuns(opts *repository.ListAllJobRunsOpts) ([]db.JobRunModel, error) {
if err := j.v.Validate(opts); err != nil {
return nil, err
}
params := []db.JobRunWhereParam{}
if opts.TickerId != nil {
params = append(params, db.JobRun.TickerID.Equals(*opts.TickerId))
}
if opts.Status != nil {
params = append(params, db.JobRun.Status.Equals(*opts.Status))
}
if opts.NoTickerId != nil && *opts.NoTickerId {
params = append(params, db.JobRun.TickerID.IsNull())
}
return j.client.JobRun.FindMany(
params...,
).With(
db.JobRun.LookupData.Fetch(),
db.JobRun.StepRuns.Fetch().With(
db.StepRun.Step.Fetch().With(
db.Step.Children.Fetch(),
db.Step.Parents.Fetch(),
db.Step.Action.Fetch(),
),
),
db.JobRun.Job.Fetch().With(
db.Job.Workflow.Fetch(),
),
).Exec(context.Background())
func (j *jobRunAPIRepository) SetJobRunStatusRunning(tenantId, jobRunId string) error {
return setJobRunStatusRunning(context.Background(), j.pool, j.queries, j.l, tenantId, jobRunId)
}
func (j *jobRunRepository) GetJobRunById(tenantId, jobRunId string) (*db.JobRunModel, error) {
return j.client.JobRun.FindUnique(
db.JobRun.ID.Equals(jobRunId),
).With(
db.JobRun.LookupData.Fetch(),
db.JobRun.StepRuns.Fetch().With(
db.StepRun.Parents.Fetch(),
db.StepRun.Children.Fetch(),
db.StepRun.Step.Fetch().With(
db.Step.Children.Fetch(),
db.Step.Parents.Fetch(),
db.Step.Action.Fetch(),
),
),
db.JobRun.Job.Fetch().With(
db.Job.Workflow.Fetch(),
),
db.JobRun.Ticker.Fetch(),
).Exec(context.Background())
type jobRunEngineRepository struct {
pool *pgxpool.Pool
v validator.Validator
queries *dbsqlc.Queries
l *zerolog.Logger
}
func (j *jobRunRepository) SetJobRunStatusRunning(tenantId, jobRunId string) error {
tx, err := j.pool.Begin(context.Background())
func NewJobRunEngineRepository(pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.JobRunEngineRepository {
queries := dbsqlc.New()
return &jobRunEngineRepository{
v: v,
pool: pool,
queries: queries,
l: l,
}
}
func (j *jobRunEngineRepository) SetJobRunStatusRunning(tenantId, jobRunId string) error {
return setJobRunStatusRunning(context.Background(), j.pool, j.queries, j.l, tenantId, jobRunId)
}
func (j *jobRunEngineRepository) ListJobRunsForWorkflowRun(tenantId, workflowRunId string) ([]pgtype.UUID, error) {
return j.queries.ListJobRunsForWorkflowRun(context.Background(), j.pool, sqlchelpers.UUIDFromStr(workflowRunId))
}
func setJobRunStatusRunning(ctx context.Context, pool *pgxpool.Pool, queries *dbsqlc.Queries, l *zerolog.Logger, tenantId, jobRunId string) error {
tx, err := pool.Begin(context.Background())
if err != nil {
return err
}
defer deferRollback(context.Background(), j.l, tx.Rollback)
defer deferRollback(context.Background(), l, tx.Rollback)
jobRun, err := j.queries.UpdateJobRunStatus(context.Background(), tx, dbsqlc.UpdateJobRunStatusParams{
jobRun, err := queries.UpdateJobRunStatus(context.Background(), tx, dbsqlc.UpdateJobRunStatusParams{
ID: sqlchelpers.UUIDFromStr(jobRunId),
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Status: dbsqlc.JobRunStatusRUNNING,
@@ -109,7 +83,7 @@ func (j *jobRunRepository) SetJobRunStatusRunning(tenantId, jobRunId string) err
return err
}
_, err = j.queries.UpdateWorkflowRun(
_, err = queries.UpdateWorkflowRun(
context.Background(),
tx,
dbsqlc.UpdateWorkflowRunParams{
@@ -128,42 +102,3 @@ func (j *jobRunRepository) SetJobRunStatusRunning(tenantId, jobRunId string) err
return tx.Commit(context.Background())
}
func (j *jobRunRepository) GetJobRunLookupData(tenantId, jobRunId string) (*db.JobRunLookupDataModel, error) {
return j.client.JobRunLookupData.FindUnique(
db.JobRunLookupData.JobRunIDTenantID(
db.JobRunLookupData.JobRunID.Equals(jobRunId),
db.JobRunLookupData.TenantID.Equals(tenantId),
),
).Exec(context.Background())
}
func (j *jobRunRepository) UpdateJobRunLookupData(tenantId, jobRunId string, opts *repository.UpdateJobRunLookupDataOpts) error {
pgTenantId := sqlchelpers.UUIDFromStr(tenantId)
pgJobRunId := sqlchelpers.UUIDFromStr(jobRunId)
tx, err := j.pool.Begin(context.Background())
if err != nil {
return err
}
defer deferRollback(context.Background(), j.l, tx.Rollback)
err = j.queries.UpsertJobRunLookupData(
context.Background(),
tx,
dbsqlc.UpsertJobRunLookupDataParams{
Jobrunid: pgJobRunId,
Tenantid: pgTenantId,
Fieldpath: opts.FieldPath,
Jsondata: opts.Data,
},
)
if err != nil {
return err
}
return tx.Commit(context.Background())
}
+76 -61
View File
@@ -10,25 +10,22 @@ import (
"github.com/rs/zerolog"
"github.com/hatchet-dev/hatchet/internal/repository"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/db"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/dbsqlc"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/sqlchelpers"
"github.com/hatchet-dev/hatchet/internal/validator"
)
type logRepository struct {
client *db.PrismaClient
type logAPIRepository struct {
pool *pgxpool.Pool
v validator.Validator
queries *dbsqlc.Queries
l *zerolog.Logger
}
func NewLogRepository(client *db.PrismaClient, pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.LogsRepository {
func NewLogAPIRepository(pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.LogsAPIRepository {
queries := dbsqlc.New()
return &logRepository{
client: client,
return &logAPIRepository{
pool: pool,
v: v,
queries: queries,
@@ -36,61 +33,7 @@ func NewLogRepository(client *db.PrismaClient, pool *pgxpool.Pool, v validator.V
}
}
func (r *logRepository) PutLog(tenantId string, opts *repository.CreateLogLineOpts) (*dbsqlc.LogLine, error) {
if err := r.v.Validate(opts); err != nil {
return nil, err
}
createParams := dbsqlc.CreateLogLineParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Message: opts.Message,
Steprunid: sqlchelpers.UUIDFromStr(opts.StepRunId),
}
if opts.CreatedAt != nil {
utcTime := opts.CreatedAt.UTC()
createParams.CreatedAt = sqlchelpers.TimestampFromTime(utcTime)
}
if opts.Level != nil {
createParams.Level = dbsqlc.NullLogLineLevel{
LogLineLevel: dbsqlc.LogLineLevel(*opts.Level),
Valid: true,
}
}
if opts.Metadata != nil {
createParams.Metadata = opts.Metadata
}
tx, err := r.pool.Begin(context.Background())
if err != nil {
return nil, err
}
defer deferRollback(context.Background(), r.l, tx.Rollback)
logLine, err := r.queries.CreateLogLine(
context.Background(),
tx,
createParams,
)
if err != nil {
return nil, fmt.Errorf("could not create log line: %w", err)
}
err = tx.Commit(context.Background())
if err != nil {
return nil, fmt.Errorf("could not commit transaction: %w", err)
}
return logLine, nil
}
func (r *logRepository) ListLogLines(tenantId string, opts *repository.ListLogsOpts) (*repository.ListLogsResult, error) {
func (r *logAPIRepository) ListLogLines(tenantId string, opts *repository.ListLogsOpts) (*repository.ListLogsResult, error) {
if err := r.v.Validate(opts); err != nil {
return nil, err
}
@@ -188,3 +131,75 @@ func (r *logRepository) ListLogLines(tenantId string, opts *repository.ListLogsO
return res, nil
}
type logEngineRepository struct {
pool *pgxpool.Pool
v validator.Validator
queries *dbsqlc.Queries
l *zerolog.Logger
}
func NewLogEngineRepository(pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.LogsEngineRepository {
queries := dbsqlc.New()
return &logEngineRepository{
pool: pool,
v: v,
queries: queries,
l: l,
}
}
func (r *logEngineRepository) PutLog(tenantId string, opts *repository.CreateLogLineOpts) (*dbsqlc.LogLine, error) {
if err := r.v.Validate(opts); err != nil {
return nil, err
}
createParams := dbsqlc.CreateLogLineParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Message: opts.Message,
Steprunid: sqlchelpers.UUIDFromStr(opts.StepRunId),
}
if opts.CreatedAt != nil {
utcTime := opts.CreatedAt.UTC()
createParams.CreatedAt = sqlchelpers.TimestampFromTime(utcTime)
}
if opts.Level != nil {
createParams.Level = dbsqlc.NullLogLineLevel{
LogLineLevel: dbsqlc.LogLineLevel(*opts.Level),
Valid: true,
}
}
if opts.Metadata != nil {
createParams.Metadata = opts.Metadata
}
tx, err := r.pool.Begin(context.Background())
if err != nil {
return nil, err
}
defer deferRollback(context.Background(), r.l, tx.Rollback)
logLine, err := r.queries.CreateLogLine(
context.Background(),
tx,
createParams,
)
if err != nil {
return nil, fmt.Errorf("could not create log line: %w", err)
}
err = tx.Commit(context.Background())
if err != nil {
return nil, fmt.Errorf("could not commit transaction: %w", err)
}
return logLine, nil
}
+150 -69
View File
@@ -12,26 +12,23 @@ import (
"github.com/hatchet-dev/hatchet/internal/validator"
)
type prismaRepository struct {
apiToken repository.APITokenRepository
event repository.EventRepository
log repository.LogsRepository
tenant repository.TenantRepository
tenantInvite repository.TenantInviteRepository
workflow repository.WorkflowRepository
workflowRun repository.WorkflowRunRepository
jobRun repository.JobRunRepository
stepRun repository.StepRunRepository
getGroupKeyRun repository.GetGroupKeyRunRepository
github repository.GithubRepository
step repository.StepRepository
sns repository.SNSRepository
dispatcher repository.DispatcherRepository
worker repository.WorkerRepository
ticker repository.TickerRepository
userSession repository.UserSessionRepository
user repository.UserRepository
health repository.HealthRepository
type apiRepository struct {
apiToken repository.APITokenRepository
event repository.EventAPIRepository
log repository.LogsAPIRepository
tenant repository.TenantAPIRepository
tenantInvite repository.TenantInviteRepository
workflow repository.WorkflowAPIRepository
workflowRun repository.WorkflowRunAPIRepository
jobRun repository.JobRunAPIRepository
stepRun repository.StepRunAPIRepository
github repository.GithubRepository
step repository.StepRepository
sns repository.SNSRepository
worker repository.WorkerAPIRepository
userSession repository.UserSessionRepository
user repository.UserRepository
health repository.HealthRepository
}
type PrismaRepositoryOpt func(*PrismaRepositoryOpts)
@@ -66,7 +63,7 @@ func WithCache(cache cache.Cacheable) PrismaRepositoryOpt {
}
}
func NewPrismaRepository(client *db.PrismaClient, pool *pgxpool.Pool, fs ...PrismaRepositoryOpt) repository.Repository {
func NewAPIRepository(client *db.PrismaClient, pool *pgxpool.Pool, fs ...PrismaRepositoryOpt) repository.APIRepository {
opts := defaultPrismaRepositoryOpts()
for _, f := range fs {
@@ -80,101 +77,185 @@ func NewPrismaRepository(client *db.PrismaClient, pool *pgxpool.Pool, fs ...Pris
opts.cache = cache.New(1 * time.Millisecond)
}
return &prismaRepository{
apiToken: NewAPITokenRepository(client, opts.v, opts.cache),
event: NewEventRepository(client, pool, opts.v, opts.l),
log: NewLogRepository(client, pool, opts.v, opts.l),
tenant: NewTenantRepository(client, opts.v, opts.cache),
tenantInvite: NewTenantInviteRepository(client, opts.v),
workflow: NewWorkflowRepository(client, pool, opts.v, opts.l),
workflowRun: NewWorkflowRunRepository(client, pool, opts.v, opts.l),
jobRun: NewJobRunRepository(client, pool, opts.v, opts.l),
stepRun: NewStepRunRepository(client, pool, opts.v, opts.l),
getGroupKeyRun: NewGetGroupKeyRunRepository(client, pool, opts.v, opts.l),
github: NewGithubRepository(client, opts.v),
step: NewStepRepository(client, opts.v),
sns: NewSNSRepository(client, opts.v),
dispatcher: NewDispatcherRepository(client, pool, opts.v, opts.l),
worker: NewWorkerRepository(client, pool, opts.v, opts.l),
ticker: NewTickerRepository(client, pool, opts.v, opts.l),
userSession: NewUserSessionRepository(client, opts.v),
user: NewUserRepository(client, opts.v),
health: NewHealthRepository(client, pool),
return &apiRepository{
apiToken: NewAPITokenRepository(client, opts.v, opts.cache),
event: NewEventAPIRepository(client, pool, opts.v, opts.l),
log: NewLogAPIRepository(pool, opts.v, opts.l),
tenant: NewTenantAPIRepository(client, opts.v, opts.cache),
tenantInvite: NewTenantInviteRepository(client, opts.v),
workflow: NewWorkflowRepository(client, pool, opts.v, opts.l),
workflowRun: NewWorkflowRunRepository(client, pool, opts.v, opts.l),
jobRun: NewJobRunAPIRepository(client, pool, opts.v, opts.l),
stepRun: NewStepRunAPIRepository(client, pool, opts.v, opts.l),
github: NewGithubRepository(client, opts.v),
step: NewStepRepository(client, opts.v),
sns: NewSNSRepository(client, opts.v),
worker: NewWorkerAPIRepository(client, pool, opts.v, opts.l),
userSession: NewUserSessionRepository(client, opts.v),
user: NewUserRepository(client, opts.v),
health: NewHealthAPIRepository(client, pool),
}
}
func (r *prismaRepository) Health() repository.HealthRepository {
func (r *apiRepository) Health() repository.HealthRepository {
return r.health
}
func (r *prismaRepository) APIToken() repository.APITokenRepository {
func (r *apiRepository) APIToken() repository.APITokenRepository {
return r.apiToken
}
func (r *prismaRepository) Event() repository.EventRepository {
func (r *apiRepository) Event() repository.EventAPIRepository {
return r.event
}
func (r *prismaRepository) Log() repository.LogsRepository {
func (r *apiRepository) Log() repository.LogsAPIRepository {
return r.log
}
func (r *prismaRepository) Tenant() repository.TenantRepository {
func (r *apiRepository) Tenant() repository.TenantAPIRepository {
return r.tenant
}
func (r *prismaRepository) TenantInvite() repository.TenantInviteRepository {
func (r *apiRepository) TenantInvite() repository.TenantInviteRepository {
return r.tenantInvite
}
func (r *prismaRepository) Workflow() repository.WorkflowRepository {
func (r *apiRepository) Workflow() repository.WorkflowAPIRepository {
return r.workflow
}
func (r *prismaRepository) WorkflowRun() repository.WorkflowRunRepository {
func (r *apiRepository) WorkflowRun() repository.WorkflowRunAPIRepository {
return r.workflowRun
}
func (r *prismaRepository) JobRun() repository.JobRunRepository {
func (r *apiRepository) JobRun() repository.JobRunAPIRepository {
return r.jobRun
}
func (r *prismaRepository) StepRun() repository.StepRunRepository {
func (r *apiRepository) StepRun() repository.StepRunAPIRepository {
return r.stepRun
}
func (r *prismaRepository) SNS() repository.SNSRepository {
func (r *apiRepository) SNS() repository.SNSRepository {
return r.sns
}
func (r *prismaRepository) GetGroupKeyRun() repository.GetGroupKeyRunRepository {
return r.getGroupKeyRun
}
func (r *prismaRepository) Github() repository.GithubRepository {
func (r *apiRepository) Github() repository.GithubRepository {
return r.github
}
func (r *prismaRepository) Step() repository.StepRepository {
func (r *apiRepository) Step() repository.StepRepository {
return r.step
}
func (r *prismaRepository) Dispatcher() repository.DispatcherRepository {
return r.dispatcher
}
func (r *prismaRepository) Worker() repository.WorkerRepository {
func (r *apiRepository) Worker() repository.WorkerAPIRepository {
return r.worker
}
func (r *prismaRepository) Ticker() repository.TickerRepository {
return r.ticker
}
func (r *prismaRepository) UserSession() repository.UserSessionRepository {
func (r *apiRepository) UserSession() repository.UserSessionRepository {
return r.userSession
}
func (r *prismaRepository) User() repository.UserRepository {
func (r *apiRepository) User() repository.UserRepository {
return r.user
}
type engineRepository struct {
health repository.HealthRepository
apiToken repository.EngineTokenRepository
dispatcher repository.DispatcherEngineRepository
event repository.EventEngineRepository
getGroupKeyRun repository.GetGroupKeyRunEngineRepository
jobRun repository.JobRunEngineRepository
stepRun repository.StepRunEngineRepository
tenant repository.TenantEngineRepository
ticker repository.TickerEngineRepository
worker repository.WorkerEngineRepository
workflow repository.WorkflowEngineRepository
workflowRun repository.WorkflowRunEngineRepository
log repository.LogsEngineRepository
}
func (r *engineRepository) Health() repository.HealthRepository {
return r.health
}
func (r *engineRepository) APIToken() repository.EngineTokenRepository {
return r.apiToken
}
func (r *engineRepository) Dispatcher() repository.DispatcherEngineRepository {
return r.dispatcher
}
func (r *engineRepository) Event() repository.EventEngineRepository {
return r.event
}
func (r *engineRepository) GetGroupKeyRun() repository.GetGroupKeyRunEngineRepository {
return r.getGroupKeyRun
}
func (r *engineRepository) JobRun() repository.JobRunEngineRepository {
return r.jobRun
}
func (r *engineRepository) StepRun() repository.StepRunEngineRepository {
return r.stepRun
}
func (r *engineRepository) Tenant() repository.TenantEngineRepository {
return r.tenant
}
func (r *engineRepository) Ticker() repository.TickerEngineRepository {
return r.ticker
}
func (r *engineRepository) Worker() repository.WorkerEngineRepository {
return r.worker
}
func (r *engineRepository) Workflow() repository.WorkflowEngineRepository {
return r.workflow
}
func (r *engineRepository) WorkflowRun() repository.WorkflowRunEngineRepository {
return r.workflowRun
}
func (r *engineRepository) Log() repository.LogsEngineRepository {
return r.log
}
func NewEngineRepository(pool *pgxpool.Pool, fs ...PrismaRepositoryOpt) repository.EngineRepository {
opts := defaultPrismaRepositoryOpts()
for _, f := range fs {
f(opts)
}
newLogger := opts.l.With().Str("service", "database").Logger()
opts.l = &newLogger
if opts.cache == nil {
opts.cache = cache.New(1 * time.Millisecond)
}
return &engineRepository{
health: NewHealthEngineRepository(pool),
apiToken: NewEngineTokenRepository(pool, opts.v, opts.l, opts.cache),
dispatcher: NewDispatcherRepository(pool, opts.v, opts.l),
event: NewEventEngineRepository(pool, opts.v, opts.l),
getGroupKeyRun: NewGetGroupKeyRunRepository(pool, opts.v, opts.l),
jobRun: NewJobRunEngineRepository(pool, opts.v, opts.l),
stepRun: NewStepRunEngineRepository(pool, opts.v, opts.l),
tenant: NewTenantEngineRepository(pool, opts.v, opts.l, opts.cache),
ticker: NewTickerRepository(pool, opts.v, opts.l),
worker: NewWorkerEngineRepository(pool, opts.v, opts.l),
workflow: NewWorkflowEngineRepository(pool, opts.v, opts.l),
workflowRun: NewWorkflowRunEngineRepository(pool, opts.v, opts.l),
log: NewLogEngineRepository(pool, opts.v, opts.l),
}
}
+143 -162
View File
@@ -21,7 +21,7 @@ import (
"github.com/hatchet-dev/hatchet/internal/validator"
)
type stepRunRepository struct {
type stepRunAPIRepository struct {
client *db.PrismaClient
pool *pgxpool.Pool
v validator.Validator
@@ -29,10 +29,10 @@ type stepRunRepository struct {
queries *dbsqlc.Queries
}
func NewStepRunRepository(client *db.PrismaClient, pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.StepRunRepository {
func NewStepRunAPIRepository(client *db.PrismaClient, pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.StepRunAPIRepository {
queries := dbsqlc.New()
return &stepRunRepository{
return &stepRunAPIRepository{
client: client,
pool: pool,
v: v,
@@ -41,41 +41,122 @@ func NewStepRunRepository(client *db.PrismaClient, pool *pgxpool.Pool, v validat
}
}
func (s *stepRunRepository) ListAllStepRuns(opts *repository.ListAllStepRunsOpts) ([]db.StepRunModel, error) {
if err := s.v.Validate(opts); err != nil {
return nil, err
}
params := []db.StepRunWhereParam{}
if opts.TickerId != nil {
params = append(params, db.StepRun.TickerID.Equals(*opts.TickerId))
}
if opts.Status != nil {
params = append(params, db.StepRun.Status.Equals(*opts.Status))
}
if opts.NoTickerId != nil && *opts.NoTickerId {
params = append(params, db.StepRun.TickerID.IsNull())
}
return s.client.StepRun.FindMany(
params...,
func (s *stepRunAPIRepository) GetStepRunById(tenantId, stepRunId string) (*db.StepRunModel, error) {
return s.client.StepRun.FindUnique(
db.StepRun.ID.Equals(stepRunId),
).With(
db.StepRun.Children.Fetch(),
db.StepRun.Parents.Fetch().With(
db.StepRun.Step.Fetch(),
),
db.StepRun.Step.Fetch().With(
db.Step.Job.Fetch(),
db.Step.Action.Fetch(),
),
db.StepRun.Children.Fetch(),
db.StepRun.Parents.Fetch(),
db.StepRun.JobRun.Fetch().With(
db.JobRun.Job.Fetch(),
db.JobRun.LookupData.Fetch(),
db.JobRun.WorkflowRun.Fetch(),
),
db.StepRun.Ticker.Fetch(),
).Exec(context.Background())
}
func (s *stepRunRepository) ListStepRunsToRequeue(tenantId string) ([]*dbsqlc.StepRun, error) {
func (s *stepRunAPIRepository) GetFirstArchivedStepRunResult(tenantId, stepRunId string) (*db.StepRunResultArchiveModel, error) {
return s.client.StepRunResultArchive.FindFirst(
db.StepRunResultArchive.StepRunID.Equals(stepRunId),
db.StepRunResultArchive.StepRun.Where(
db.StepRun.TenantID.Equals(tenantId),
),
).OrderBy(
db.StepRunResultArchive.Order.Order(db.ASC),
).Exec(context.Background())
}
type stepRunEngineRepository struct {
pool *pgxpool.Pool
v validator.Validator
l *zerolog.Logger
queries *dbsqlc.Queries
}
func NewStepRunEngineRepository(pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.StepRunEngineRepository {
queries := dbsqlc.New()
return &stepRunEngineRepository{
pool: pool,
v: v,
l: l,
queries: queries,
}
}
func (s *stepRunEngineRepository) ListRunningStepRunsForTicker(tickerId string) ([]*dbsqlc.GetStepRunForEngineRow, error) {
tx, err := s.pool.Begin(context.Background())
if err != nil {
return nil, err
}
defer deferRollback(context.Background(), s.l, tx.Rollback)
srs, err := s.queries.ListStepRuns(context.Background(), s.pool, dbsqlc.ListStepRunsParams{
Status: dbsqlc.NullStepRunStatus{
StepRunStatus: dbsqlc.StepRunStatusRUNNING,
},
TickerId: sqlchelpers.UUIDFromStr(tickerId),
})
if err != nil {
return nil, err
}
res, err := s.queries.GetStepRunForEngine(context.Background(), tx, dbsqlc.GetStepRunForEngineParams{
Ids: srs,
})
if err != nil {
return nil, err
}
err = tx.Commit(context.Background())
return res, err
}
func (s *stepRunEngineRepository) ListRunningStepRunsForWorkflowRun(tenantId, workflowRunId string) ([]*dbsqlc.GetStepRunForEngineRow, error) {
tx, err := s.pool.Begin(context.Background())
if err != nil {
return nil, err
}
defer deferRollback(context.Background(), s.l, tx.Rollback)
srs, err := s.queries.ListStepRuns(context.Background(), s.pool, dbsqlc.ListStepRunsParams{
Status: dbsqlc.NullStepRunStatus{
StepRunStatus: dbsqlc.StepRunStatusRUNNING,
},
WorkflowRunId: sqlchelpers.UUIDFromStr(workflowRunId),
})
if err != nil {
return nil, err
}
res, err := s.queries.GetStepRunForEngine(context.Background(), tx, dbsqlc.GetStepRunForEngineParams{
Ids: srs,
})
if err != nil {
return nil, err
}
err = tx.Commit(context.Background())
return res, err
}
func (s *stepRunEngineRepository) ListStepRunsToRequeue(tenantId string) ([]*dbsqlc.StepRun, error) {
pgTenantId := sqlchelpers.UUIDFromStr(tenantId)
tx, err := s.pool.Begin(context.Background())
@@ -102,7 +183,7 @@ func (s *stepRunRepository) ListStepRunsToRequeue(tenantId string) ([]*dbsqlc.St
return stepRuns, nil
}
func (s *stepRunRepository) ListStepRunsToReassign(tenantId string) ([]*dbsqlc.StepRun, error) {
func (s *stepRunEngineRepository) ListStepRunsToReassign(tenantId string) ([]*dbsqlc.StepRun, error) {
pgTenantId := sqlchelpers.UUIDFromStr(tenantId)
tx, err := s.pool.Begin(context.Background())
@@ -129,44 +210,6 @@ func (s *stepRunRepository) ListStepRunsToReassign(tenantId string) ([]*dbsqlc.S
return stepRuns, nil
}
func (s *stepRunRepository) ListStepRuns(tenantId string, opts *repository.ListStepRunsOpts) ([]db.StepRunModel, error) {
if err := s.v.Validate(opts); err != nil {
return nil, err
}
params := []db.StepRunWhereParam{
db.StepRun.TenantID.Equals(tenantId),
}
if opts.Status != nil {
params = append(params, db.StepRun.Status.Equals(*opts.Status))
}
if opts.JobRunId != nil {
params = append(params, db.StepRun.JobRunID.Equals(*opts.JobRunId))
}
if opts.WorkflowRunId != nil {
params = append(params, db.StepRun.JobRun.Where(
db.JobRun.WorkflowRunID.Equals(*opts.WorkflowRunId),
))
}
return s.client.StepRun.FindMany(
params...,
).With(
db.StepRun.Step.Fetch().With(
db.Step.Action.Fetch(),
),
db.StepRun.Children.Fetch(),
db.StepRun.Parents.Fetch(),
db.StepRun.JobRun.Fetch().With(
db.JobRun.Job.Fetch(),
),
db.StepRun.Ticker.Fetch(),
).Exec(context.Background())
}
var retrier = func(l *zerolog.Logger, f func() error) error {
retries := 0
@@ -203,26 +246,29 @@ var retrier = func(l *zerolog.Logger, f func() error) error {
return nil
}
func (s *stepRunRepository) AssignStepRunToWorker(tenantId, stepRunId string) (string, string, error) {
// var assigned
var assigned *dbsqlc.AssignStepRunToWorkerRow
func (s *stepRunEngineRepository) AssignStepRunToWorker(tenantId, stepRunId string) (string, string, error) {
tx, err := s.pool.Begin(context.Background())
err := retrier(s.l, func() (err error) {
assigned, err = s.queries.AssignStepRunToWorker(context.Background(), s.pool, dbsqlc.AssignStepRunToWorkerParams{
Steprunid: sqlchelpers.UUIDFromStr(stepRunId),
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
})
if err != nil {
return "", "", err
}
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return repository.ErrNoWorkerAvailable
}
defer deferRollback(context.Background(), s.l, tx.Rollback)
return err
assigned, err := s.queries.AssignStepRunToWorker(context.Background(), tx, dbsqlc.AssignStepRunToWorkerParams{
Steprunid: sqlchelpers.UUIDFromStr(stepRunId),
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
})
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return "", "", repository.ErrNoWorkerAvailable
}
return nil
})
return "", "", err
}
err = tx.Commit(context.Background())
if err != nil {
return "", "", err
@@ -231,26 +277,7 @@ func (s *stepRunRepository) AssignStepRunToWorker(tenantId, stepRunId string) (s
return sqlchelpers.UUIDToStr(assigned.WorkerId), sqlchelpers.UUIDToStr(assigned.DispatcherId), nil
}
func (s *stepRunRepository) AssignStepRunToTicker(tenantId, stepRunId string) (tickerId string, err error) {
err = retrier(s.l, func() error {
assigned, err := s.queries.AssignStepRunToTicker(context.Background(), s.pool, dbsqlc.AssignStepRunToTickerParams{
Steprunid: sqlchelpers.UUIDFromStr(stepRunId),
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
})
if err != nil {
return err
}
tickerId = sqlchelpers.UUIDToStr(assigned.TickerId)
return nil
})
return tickerId, err
}
func (s *stepRunRepository) UpdateStepRun(ctx context.Context, tenantId, stepRunId string, opts *repository.UpdateStepRunOpts) (*dbsqlc.GetStepRunForEngineRow, *repository.StepRunUpdateInfo, error) {
func (s *stepRunEngineRepository) UpdateStepRun(ctx context.Context, tenantId, stepRunId string, opts *repository.UpdateStepRunOpts) (*dbsqlc.GetStepRunForEngineRow, *repository.StepRunUpdateInfo, error) {
ctx, span := telemetry.NewSpan(ctx, "update-step-run")
defer span.End()
@@ -313,15 +340,13 @@ func (s *stepRunRepository) UpdateStepRun(ctx context.Context, tenantId, stepRun
})
if err != nil {
// non-fatal error, log and continue
s.l.Err(err).Msg("could not update step run extra")
return nil, nil, nil
return nil, nil, fmt.Errorf("could not update step run extra: %w", err)
}
return stepRun, updateInfo, nil
}
func (s *stepRunRepository) UpdateStepRunOverridesData(tenantId, stepRunId string, opts *repository.UpdateStepRunOverridesDataOpts) ([]byte, error) {
func (s *stepRunEngineRepository) UpdateStepRunOverridesData(tenantId, stepRunId string, opts *repository.UpdateStepRunOverridesDataOpts) ([]byte, error) {
if err := s.v.Validate(opts); err != nil {
return nil, err
}
@@ -378,7 +403,7 @@ func (s *stepRunRepository) UpdateStepRunOverridesData(tenantId, stepRunId strin
return input, nil
}
func (s *stepRunRepository) UpdateStepRunInputSchema(tenantId, stepRunId string, schema []byte) ([]byte, error) {
func (s *stepRunEngineRepository) UpdateStepRunInputSchema(tenantId, stepRunId string, schema []byte) ([]byte, error) {
tx, err := s.pool.Begin(context.Background())
if err != nil {
@@ -413,7 +438,7 @@ func (s *stepRunRepository) UpdateStepRunInputSchema(tenantId, stepRunId string,
return inputSchema, nil
}
func (s *stepRunRepository) QueueStepRun(ctx context.Context, tenantId, stepRunId string, opts *repository.UpdateStepRunOpts) (*dbsqlc.GetStepRunForEngineRow, error) {
func (s *stepRunEngineRepository) QueueStepRun(ctx context.Context, tenantId, stepRunId string, opts *repository.UpdateStepRunOpts) (*dbsqlc.GetStepRunForEngineRow, error) {
ctx, span := telemetry.NewSpan(ctx, "queue-step-run-database")
defer span.End()
@@ -471,7 +496,7 @@ func (s *stepRunRepository) QueueStepRun(ctx context.Context, tenantId, stepRunI
})
if retrierErr != nil {
return nil, fmt.Errorf("could not queue step run: %w", err)
return nil, fmt.Errorf("could not queue step run: %w", retrierErr)
}
retrierExtraErr := retrier(s.l, func() error {
@@ -495,9 +520,7 @@ func (s *stepRunRepository) QueueStepRun(ctx context.Context, tenantId, stepRunI
})
if retrierExtraErr != nil {
// non-fatal error, log and continue
s.l.Err(err).Msg("could not update step run extra")
return nil, nil
return nil, fmt.Errorf("could not update step run extra: %w", retrierExtraErr)
}
return stepRun, nil
@@ -600,7 +623,7 @@ func getUpdateParams(
return updateParams, updateJobRunLookupDataParams, resolveJobRunParams, resolveLaterStepRunsParams, nil
}
func (s *stepRunRepository) updateStepRunCore(
func (s *stepRunEngineRepository) updateStepRunCore(
ctx context.Context,
tx pgx.Tx,
tenantId string,
@@ -618,7 +641,7 @@ func (s *stepRunRepository) updateStepRunCore(
stepRuns, err := s.queries.GetStepRunForEngine(ctx, tx, dbsqlc.GetStepRunForEngineParams{
Ids: []pgtype.UUID{updateStepRun.ID},
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
TenantId: sqlchelpers.UUIDFromStr(tenantId),
})
if err != nil {
@@ -641,7 +664,7 @@ func (s *stepRunRepository) updateStepRunCore(
return stepRuns[0], nil
}
func (s *stepRunRepository) updateStepRunExtra(
func (s *stepRunEngineRepository) updateStepRunExtra(
ctx context.Context,
tx pgx.Tx,
tenantId string,
@@ -690,31 +713,11 @@ func isFinalWorkflowRunStatus(status dbsqlc.WorkflowRunStatus) bool {
return status != dbsqlc.WorkflowRunStatusPENDING && status != dbsqlc.WorkflowRunStatusRUNNING && status != dbsqlc.WorkflowRunStatusQUEUED
}
func (s *stepRunRepository) GetStepRunById(tenantId, stepRunId string) (*db.StepRunModel, error) {
return s.client.StepRun.FindUnique(
db.StepRun.ID.Equals(stepRunId),
).With(
db.StepRun.Children.Fetch(),
db.StepRun.Parents.Fetch().With(
db.StepRun.Step.Fetch(),
),
db.StepRun.Step.Fetch().With(
db.Step.Job.Fetch(),
db.Step.Action.Fetch(),
),
db.StepRun.JobRun.Fetch().With(
db.JobRun.LookupData.Fetch(),
db.JobRun.WorkflowRun.Fetch(),
),
db.StepRun.Ticker.Fetch(),
).Exec(context.Background())
}
// performant query for step run id, only returns what the engine needs
func (s *stepRunRepository) GetStepRunForEngine(tenantId, stepRunId string) (*dbsqlc.GetStepRunForEngineRow, error) {
func (s *stepRunEngineRepository) GetStepRunForEngine(tenantId, stepRunId string) (*dbsqlc.GetStepRunForEngineRow, error) {
res, err := s.queries.GetStepRunForEngine(context.Background(), s.pool, dbsqlc.GetStepRunForEngineParams{
Ids: []pgtype.UUID{sqlchelpers.UUIDFromStr(stepRunId)},
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
TenantId: sqlchelpers.UUIDFromStr(tenantId),
})
if err != nil {
@@ -728,7 +731,7 @@ func (s *stepRunRepository) GetStepRunForEngine(tenantId, stepRunId string) (*db
return res[0], nil
}
func (s *stepRunRepository) ListStartableStepRuns(tenantId, jobRunId string, parentStepRunId *string) ([]*dbsqlc.GetStepRunForEngineRow, error) {
func (s *stepRunEngineRepository) ListStartableStepRuns(tenantId, jobRunId string, parentStepRunId *string) ([]*dbsqlc.GetStepRunForEngineRow, error) {
tx, err := s.pool.Begin(context.Background())
if err != nil {
@@ -753,7 +756,7 @@ func (s *stepRunRepository) ListStartableStepRuns(tenantId, jobRunId string, par
res, err := s.queries.GetStepRunForEngine(context.Background(), tx, dbsqlc.GetStepRunForEngineParams{
Ids: srs,
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
TenantId: sqlchelpers.UUIDFromStr(tenantId),
})
if err != nil {
@@ -765,7 +768,7 @@ func (s *stepRunRepository) ListStartableStepRuns(tenantId, jobRunId string, par
return res, err
}
func (s *stepRunRepository) ArchiveStepRunResult(tenantId, stepRunId string) error {
func (s *stepRunEngineRepository) ArchiveStepRunResult(tenantId, stepRunId string) error {
_, err := s.queries.ArchiveStepRunResultFromStepRun(context.Background(), s.pool, dbsqlc.ArchiveStepRunResultFromStepRunParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Steprunid: sqlchelpers.UUIDFromStr(stepRunId),
@@ -774,28 +777,6 @@ func (s *stepRunRepository) ArchiveStepRunResult(tenantId, stepRunId string) err
return err
}
func (s *stepRunRepository) ListArchivedStepRunResults(tenantId, stepRunId string) ([]db.StepRunResultArchiveModel, error) {
return s.client.StepRunResultArchive.FindMany(
db.StepRunResultArchive.StepRunID.Equals(stepRunId),
db.StepRunResultArchive.StepRun.Where(
db.StepRun.TenantID.Equals(tenantId),
),
).OrderBy(
db.StepRunResultArchive.Order.Order(db.DESC),
).Exec(context.Background())
}
func (s *stepRunRepository) GetFirstArchivedStepRunResult(tenantId, stepRunId string) (*db.StepRunResultArchiveModel, error) {
return s.client.StepRunResultArchive.FindFirst(
db.StepRunResultArchive.StepRunID.Equals(stepRunId),
db.StepRunResultArchive.StepRun.Where(
db.StepRun.TenantID.Equals(tenantId),
),
).OrderBy(
db.StepRunResultArchive.Order.Order(db.ASC),
).Exec(context.Background())
}
// sleepWithJitter sleeps for a random duration between min and max duration.
// min and max are time.Duration values, specifying the minimum and maximum sleep times.
func sleepWithJitter(min, max time.Duration) {
+48 -17
View File
@@ -3,27 +3,32 @@ package prisma
import (
"context"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/rs/zerolog"
"github.com/hatchet-dev/hatchet/internal/repository"
"github.com/hatchet-dev/hatchet/internal/repository/cache"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/db"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/dbsqlc"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/sqlchelpers"
"github.com/hatchet-dev/hatchet/internal/validator"
)
type tenantRepository struct {
type tenantAPIRepository struct {
client *db.PrismaClient
v validator.Validator
cache cache.Cacheable
}
func NewTenantRepository(client *db.PrismaClient, v validator.Validator, cache cache.Cacheable) repository.TenantRepository {
return &tenantRepository{
func NewTenantAPIRepository(client *db.PrismaClient, v validator.Validator, cache cache.Cacheable) repository.TenantAPIRepository {
return &tenantAPIRepository{
client: client,
v: v,
cache: cache,
}
}
func (r *tenantRepository) CreateTenant(opts *repository.CreateTenantOpts) (*db.TenantModel, error) {
func (r *tenantAPIRepository) CreateTenant(opts *repository.CreateTenantOpts) (*db.TenantModel, error) {
if err := r.v.Validate(opts); err != nil {
return nil, err
}
@@ -35,11 +40,7 @@ func (r *tenantRepository) CreateTenant(opts *repository.CreateTenantOpts) (*db.
).Exec(context.Background())
}
func (r *tenantRepository) ListTenants() ([]db.TenantModel, error) {
return r.client.Tenant.FindMany().Exec(context.Background())
}
func (r *tenantRepository) GetTenantByID(id string) (*db.TenantModel, error) {
func (r *tenantAPIRepository) GetTenantByID(id string) (*db.TenantModel, error) {
return cache.MakeCacheable[db.TenantModel](r.cache, id, func() (*db.TenantModel, error) {
return r.client.Tenant.FindUnique(
db.Tenant.ID.Equals(id),
@@ -47,13 +48,13 @@ func (r *tenantRepository) GetTenantByID(id string) (*db.TenantModel, error) {
})
}
func (r *tenantRepository) GetTenantBySlug(slug string) (*db.TenantModel, error) {
func (r *tenantAPIRepository) GetTenantBySlug(slug string) (*db.TenantModel, error) {
return r.client.Tenant.FindUnique(
db.Tenant.Slug.Equals(slug),
).Exec(context.Background())
}
func (r *tenantRepository) CreateTenantMember(tenantId string, opts *repository.CreateTenantMemberOpts) (*db.TenantMemberModel, error) {
func (r *tenantAPIRepository) CreateTenantMember(tenantId string, opts *repository.CreateTenantMemberOpts) (*db.TenantMemberModel, error) {
if err := r.v.Validate(opts); err != nil {
return nil, err
}
@@ -65,13 +66,13 @@ func (r *tenantRepository) CreateTenantMember(tenantId string, opts *repository.
).Exec(context.Background())
}
func (r *tenantRepository) GetTenantMemberByID(memberId string) (*db.TenantMemberModel, error) {
func (r *tenantAPIRepository) GetTenantMemberByID(memberId string) (*db.TenantMemberModel, error) {
return r.client.TenantMember.FindUnique(
db.TenantMember.ID.Equals(memberId),
).Exec(context.Background())
}
func (r *tenantRepository) GetTenantMemberByUserID(tenantId string, userId string) (*db.TenantMemberModel, error) {
func (r *tenantAPIRepository) GetTenantMemberByUserID(tenantId string, userId string) (*db.TenantMemberModel, error) {
return r.client.TenantMember.FindUnique(
db.TenantMember.TenantIDUserID(
db.TenantMember.TenantID.Equals(tenantId),
@@ -80,7 +81,7 @@ func (r *tenantRepository) GetTenantMemberByUserID(tenantId string, userId strin
).Exec(context.Background())
}
func (r *tenantRepository) ListTenantMembers(tenantId string) ([]db.TenantMemberModel, error) {
func (r *tenantAPIRepository) ListTenantMembers(tenantId string) ([]db.TenantMemberModel, error) {
return r.client.TenantMember.FindMany(
db.TenantMember.TenantID.Equals(tenantId),
).With(
@@ -89,7 +90,7 @@ func (r *tenantRepository) ListTenantMembers(tenantId string) ([]db.TenantMember
).Exec(context.Background())
}
func (r *tenantRepository) GetTenantMemberByEmail(tenantId string, email string) (*db.TenantMemberModel, error) {
func (r *tenantAPIRepository) GetTenantMemberByEmail(tenantId string, email string) (*db.TenantMemberModel, error) {
user, err := r.client.User.FindUnique(
db.User.Email.Equals(email),
).Exec(context.Background())
@@ -106,7 +107,7 @@ func (r *tenantRepository) GetTenantMemberByEmail(tenantId string, email string)
).Exec(context.Background())
}
func (r *tenantRepository) UpdateTenantMember(memberId string, opts *repository.UpdateTenantMemberOpts) (*db.TenantMemberModel, error) {
func (r *tenantAPIRepository) UpdateTenantMember(memberId string, opts *repository.UpdateTenantMemberOpts) (*db.TenantMemberModel, error) {
if err := r.v.Validate(opts); err != nil {
return nil, err
}
@@ -124,8 +125,38 @@ func (r *tenantRepository) UpdateTenantMember(memberId string, opts *repository.
).Exec(context.Background())
}
func (r *tenantRepository) DeleteTenantMember(memberId string) (*db.TenantMemberModel, error) {
func (r *tenantAPIRepository) DeleteTenantMember(memberId string) (*db.TenantMemberModel, error) {
return r.client.TenantMember.FindUnique(
db.TenantMember.ID.Equals(memberId),
).Delete().Exec(context.Background())
}
type tenantEngineRepository struct {
cache cache.Cacheable
pool *pgxpool.Pool
v validator.Validator
l *zerolog.Logger
queries *dbsqlc.Queries
}
func NewTenantEngineRepository(pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger, cache cache.Cacheable) repository.TenantEngineRepository {
queries := dbsqlc.New()
return &tenantEngineRepository{
cache: cache,
pool: pool,
v: v,
l: l,
queries: queries,
}
}
func (r *tenantEngineRepository) ListTenants() ([]*dbsqlc.Tenant, error) {
return r.queries.ListTenants(context.Background(), r.pool)
}
func (r *tenantEngineRepository) GetTenantByID(tenantId string) (*dbsqlc.Tenant, error) {
return cache.MakeCacheable[dbsqlc.Tenant](r.cache, tenantId, func() (*dbsqlc.Tenant, error) {
return r.queries.GetTenantByID(context.Background(), r.pool, sqlchelpers.UUIDFromStr(tenantId))
})
}
+167 -148
View File
@@ -2,32 +2,27 @@ package prisma
import (
"context"
"time"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/rs/zerolog"
"github.com/hatchet-dev/hatchet/internal/repository"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/db"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/dbsqlc"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/sqlchelpers"
"github.com/hatchet-dev/hatchet/internal/validator"
)
type tickerRepository struct {
client *db.PrismaClient
pool *pgxpool.Pool
v validator.Validator
queries *dbsqlc.Queries
l *zerolog.Logger
}
func NewTickerRepository(client *db.PrismaClient, pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.TickerRepository {
func NewTickerRepository(pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.TickerEngineRepository {
queries := dbsqlc.New()
return &tickerRepository{
client: client,
pool: pool,
v: v,
queries: queries,
@@ -35,194 +30,218 @@ func NewTickerRepository(client *db.PrismaClient, pool *pgxpool.Pool, v validato
}
}
func (t *tickerRepository) CreateNewTicker(opts *repository.CreateTickerOpts) (*db.TickerModel, error) {
return t.client.Ticker.CreateOne(
db.Ticker.ID.Set(opts.ID),
db.Ticker.LastHeartbeatAt.Set(time.Now().UTC()),
).Exec(context.Background())
}
func (t *tickerRepository) UpdateTicker(tickerId string, opts *repository.UpdateTickerOpts) (*db.TickerModel, error) {
func (t *tickerRepository) CreateNewTicker(opts *repository.CreateTickerOpts) (*dbsqlc.Ticker, error) {
if err := t.v.Validate(opts); err != nil {
return nil, err
}
return t.client.Ticker.FindUnique(
db.Ticker.ID.Equals(tickerId),
).Update(
db.Ticker.LastHeartbeatAt.SetIfPresent(opts.LastHeartbeatAt),
).Exec(context.Background())
return t.queries.CreateTicker(context.Background(), t.pool, sqlchelpers.UUIDFromStr(opts.ID))
}
func (t *tickerRepository) ListTickers(opts *repository.ListTickerOpts) ([]db.TickerModel, error) {
func (t *tickerRepository) UpdateTicker(tickerId string, opts *repository.UpdateTickerOpts) (*dbsqlc.Ticker, error) {
if err := t.v.Validate(opts); err != nil {
return nil, err
}
params := []db.TickerWhereParam{}
return t.queries.UpdateTicker(
context.Background(),
t.pool,
dbsqlc.UpdateTickerParams{
ID: sqlchelpers.UUIDFromStr(tickerId),
LastHeartbeatAt: sqlchelpers.TimestampFromTime(opts.LastHeartbeatAt.UTC()),
},
)
}
if opts.LatestHeartbeatAt != nil {
params = append(params, db.Ticker.LastHeartbeatAt.Gt(*opts.LatestHeartbeatAt))
func (t *tickerRepository) ListTickers(opts *repository.ListTickerOpts) ([]*dbsqlc.Ticker, error) {
if err := t.v.Validate(opts); err != nil {
return nil, err
}
params := dbsqlc.ListTickersParams{}
if opts.LatestHeartbeatAfter != nil {
params.LastHeartbeatAfter = sqlchelpers.TimestampFromTime(opts.LatestHeartbeatAfter.UTC())
}
if opts.Active != nil {
params = append(params, db.Ticker.IsActive.Equals(*opts.Active))
params.IsActive = *opts.Active
}
return t.client.Ticker.FindMany(
params...,
).Exec(context.Background())
return t.queries.ListTickers(
context.Background(),
t.pool,
params,
)
}
func (t *tickerRepository) Delete(tickerId string) error {
_, err := t.client.Ticker.FindUnique(
db.Ticker.ID.Equals(tickerId),
).Delete().Exec(context.Background())
_, err := t.queries.DeleteTicker(
context.Background(),
t.pool,
sqlchelpers.UUIDFromStr(tickerId),
)
return err
}
func (t *tickerRepository) AddJobRun(tickerId string, jobRun *db.JobRunModel) (*db.TickerModel, error) {
return t.client.Ticker.FindUnique(
db.Ticker.ID.Equals(tickerId),
).Update(
db.Ticker.JobRuns.Link(
db.JobRun.ID.Equals(jobRun.ID),
),
).Exec(context.Background())
func (t *tickerRepository) PollStepRuns(tickerId string) ([]*dbsqlc.StepRun, error) {
return t.queries.PollStepRuns(context.Background(), t.pool, sqlchelpers.UUIDFromStr(tickerId))
}
func (t *tickerRepository) AddStepRun(tickerId, stepRunId string) (*db.TickerModel, error) {
return t.client.Ticker.FindUnique(
db.Ticker.ID.Equals(tickerId),
).Update(
db.Ticker.StepRuns.Link(
db.StepRun.ID.Equals(stepRunId),
),
).Exec(context.Background())
func (t *tickerRepository) PollGetGroupKeyRuns(tickerId string) ([]*dbsqlc.GetGroupKeyRun, error) {
return t.queries.PollGetGroupKeyRuns(context.Background(), t.pool, sqlchelpers.UUIDFromStr(tickerId))
}
func (t *tickerRepository) AddGetGroupKeyRun(tickerId, getGroupKeyRunId string) (*db.TickerModel, error) {
return t.client.Ticker.FindUnique(
db.Ticker.ID.Equals(tickerId),
).Update(
db.Ticker.GroupKeyRuns.Link(
db.GetGroupKeyRun.ID.Equals(getGroupKeyRunId),
),
).Exec(context.Background())
func (t *tickerRepository) PollCronSchedules(tickerId string) ([]*dbsqlc.PollCronSchedulesRow, error) {
return t.queries.PollCronSchedules(context.Background(), t.pool, sqlchelpers.UUIDFromStr(tickerId))
}
func (t *tickerRepository) AddCron(tickerId string, cron *db.WorkflowTriggerCronRefModel) (*db.TickerModel, error) {
return t.client.Ticker.FindUnique(
db.Ticker.ID.Equals(tickerId),
).Update(
db.Ticker.Crons.Link(
db.WorkflowTriggerCronRef.ParentIDCron(
db.WorkflowTriggerCronRef.ParentID.Equals(cron.ParentID),
db.WorkflowTriggerCronRef.Cron.Equals(cron.Cron),
),
),
).Exec(context.Background())
func (t *tickerRepository) PollScheduledWorkflows(tickerId string) ([]*dbsqlc.PollScheduledWorkflowsRow, error) {
return t.queries.PollScheduledWorkflows(context.Background(), t.pool, sqlchelpers.UUIDFromStr(tickerId))
}
func (t *tickerRepository) RemoveCron(tickerId string, cron *db.WorkflowTriggerCronRefModel) (*db.TickerModel, error) {
return t.client.Ticker.FindUnique(
db.Ticker.ID.Equals(tickerId),
).Update(
db.Ticker.Crons.Unlink(
db.WorkflowTriggerCronRef.ParentIDCron(
db.WorkflowTriggerCronRef.ParentID.Equals(cron.ParentID),
db.WorkflowTriggerCronRef.Cron.Equals(cron.Cron),
),
),
).Exec(context.Background())
}
// func (t *tickerRepository) AddJobRun(tickerId string, jobRun *db.JobRunModel) (*db.TickerModel, error) {
// return t.client.Ticker.FindUnique(
// db.Ticker.ID.Equals(tickerId),
// ).Update(
// db.Ticker.JobRuns.Link(
// db.JobRun.ID.Equals(jobRun.ID),
// ),
// ).Exec(context.Background())
// }
func (t *tickerRepository) AddScheduledWorkflow(tickerId string, schedule *db.WorkflowTriggerScheduledRefModel) (*db.TickerModel, error) {
return t.client.Ticker.FindUnique(
db.Ticker.ID.Equals(tickerId),
).Update(
db.Ticker.Scheduled.Link(
db.WorkflowTriggerScheduledRef.ID.Equals(schedule.ID),
),
).Exec(context.Background())
}
// func (t *tickerRepository) AddStepRun(tickerId, stepRunId string) (*db.TickerModel, error) {
// return t.client.Ticker.FindUnique(
// db.Ticker.ID.Equals(tickerId),
// ).Update(
// db.Ticker.StepRuns.Link(
// db.StepRun.ID.Equals(stepRunId),
// ),
// ).Exec(context.Background())
// }
func (t *tickerRepository) RemoveScheduledWorkflow(tickerId string, schedule *db.WorkflowTriggerScheduledRefModel) (*db.TickerModel, error) {
return t.client.Ticker.FindUnique(
db.Ticker.ID.Equals(tickerId),
).Update(
db.Ticker.Scheduled.Unlink(
db.WorkflowTriggerScheduledRef.ID.Equals(schedule.ID),
),
).Exec(context.Background())
}
// func (t *tickerRepository) AddGetGroupKeyRun(tickerId, getGroupKeyRunId string) (*db.TickerModel, error) {
// return t.client.Ticker.FindUnique(
// db.Ticker.ID.Equals(tickerId),
// ).Update(
// db.Ticker.GroupKeyRuns.Link(
// db.GetGroupKeyRun.ID.Equals(getGroupKeyRunId),
// ),
// ).Exec(context.Background())
// }
func (t *tickerRepository) GetTickerById(tickerId string) (*db.TickerModel, error) {
return t.client.Ticker.FindUnique(
db.Ticker.ID.Equals(tickerId),
).With(
db.Ticker.Crons.Fetch().With(
db.WorkflowTriggerCronRef.Parent.Fetch().With(
db.WorkflowTriggers.Workflow.Fetch().With(
db.WorkflowVersion.Workflow.Fetch(),
),
),
),
db.Ticker.Scheduled.Fetch().With(
db.WorkflowTriggerScheduledRef.Parent.Fetch().With(
db.WorkflowVersion.Workflow.Fetch(),
),
),
).Exec(context.Background())
}
// func (t *tickerRepository) AddCron(tickerId string, cron *db.WorkflowTriggerCronRefModel) (*db.TickerModel, error) {
// return t.client.Ticker.FindUnique(
// db.Ticker.ID.Equals(tickerId),
// ).Update(
// db.Ticker.Crons.Link(
// db.WorkflowTriggerCronRef.ParentIDCron(
// db.WorkflowTriggerCronRef.ParentID.Equals(cron.ParentID),
// db.WorkflowTriggerCronRef.Cron.Equals(cron.Cron),
// ),
// ),
// ).Exec(context.Background())
// }
func (t *tickerRepository) UpdateStaleTickers(onStale func(tickerId string, getValidTickerId func() string) error) error {
tx, err := t.pool.Begin(context.Background())
// func (t *tickerRepository) RemoveCron(tickerId string, cron *db.WorkflowTriggerCronRefModel) (*db.TickerModel, error) {
// return t.client.Ticker.FindUnique(
// db.Ticker.ID.Equals(tickerId),
// ).Update(
// db.Ticker.Crons.Unlink(
// db.WorkflowTriggerCronRef.ParentIDCron(
// db.WorkflowTriggerCronRef.ParentID.Equals(cron.ParentID),
// db.WorkflowTriggerCronRef.Cron.Equals(cron.Cron),
// ),
// ),
// ).Exec(context.Background())
// }
if err != nil {
return err
}
// func (t *tickerRepository) AddScheduledWorkflow(tickerId string, schedule *db.WorkflowTriggerScheduledRefModel) (*db.TickerModel, error) {
// return t.client.Ticker.FindUnique(
// db.Ticker.ID.Equals(tickerId),
// ).Update(
// db.Ticker.Scheduled.Link(
// db.WorkflowTriggerScheduledRef.ID.Equals(schedule.ID),
// ),
// ).Exec(context.Background())
// }
defer deferRollback(context.Background(), t.l, tx.Rollback)
// func (t *tickerRepository) RemoveScheduledWorkflow(tickerId string, schedule *db.WorkflowTriggerScheduledRefModel) (*db.TickerModel, error) {
// return t.client.Ticker.FindUnique(
// db.Ticker.ID.Equals(tickerId),
// ).Update(
// db.Ticker.Scheduled.Unlink(
// db.WorkflowTriggerScheduledRef.ID.Equals(schedule.ID),
// ),
// ).Exec(context.Background())
// }
staleTickers, err := t.queries.ListNewlyStaleTickers(context.Background(), tx)
// func (t *tickerRepository) GetTickerById(tickerId string) (*db.TickerModel, error) {
// return t.client.Ticker.FindUnique(
// db.Ticker.ID.Equals(tickerId),
// ).With(
// db.Ticker.Crons.Fetch().With(
// db.WorkflowTriggerCronRef.Parent.Fetch().With(
// db.WorkflowTriggers.Workflow.Fetch().With(
// db.WorkflowVersion.Workflow.Fetch(),
// ),
// ),
// ),
// db.Ticker.Scheduled.Fetch().With(
// db.WorkflowTriggerScheduledRef.Parent.Fetch().With(
// db.WorkflowVersion.Workflow.Fetch(),
// ),
// ),
// ).Exec(context.Background())
// }
if err != nil {
return err
}
// func (t *tickerRepository) UpdateStaleTickers(onStale func(tickerId string, getValidTickerId func() string) error) error {
// tx, err := t.pool.Begin(context.Background())
activeTickers, err := t.queries.ListActiveTickers(context.Background(), tx)
// if err != nil {
// return err
// }
if err != nil {
return err
}
// defer deferRollback(context.Background(), t.l, tx.Rollback)
// if there are no active tickers, we can't reassign the stale tickers
if len(activeTickers) == 0 {
return nil
}
// staleTickers, err := t.queries.ListNewlyStaleTickers(context.Background(), tx)
tickersToDelete := make([]pgtype.UUID, 0)
// if err != nil {
// return err
// }
for i, ticker := range staleTickers {
err := onStale(sqlchelpers.UUIDToStr(ticker.Ticker.ID), func() string {
// assign tickers in round-robin fashion
return sqlchelpers.UUIDToStr(activeTickers[i%len(activeTickers)].Ticker.ID)
})
// activeTickers, err := t.queries.ListActiveTickers(context.Background(), tx)
if err != nil {
return err
}
// if err != nil {
// return err
// }
tickersToDelete = append(tickersToDelete, ticker.Ticker.ID)
}
// // if there are no active tickers, we can't reassign the stale tickers
// if len(activeTickers) == 0 {
// return nil
// }
_, err = t.queries.SetTickersInactive(context.Background(), tx, tickersToDelete)
// tickersToDelete := make([]pgtype.UUID, 0)
if err != nil {
return err
}
// for i, ticker := range staleTickers {
// err := onStale(sqlchelpers.UUIDToStr(ticker.Ticker.ID), func() string {
// // assign tickers in round-robin fashion
// return sqlchelpers.UUIDToStr(activeTickers[i%len(activeTickers)].Ticker.ID)
// })
return tx.Commit(context.Background())
}
// if err != nil {
// return err
// }
// tickersToDelete = append(tickersToDelete, ticker.Ticker.ID)
// }
// _, err = t.queries.SetTickersInactive(context.Background(), tx, tickersToDelete)
// if err != nil {
// return err
// }
// return tx.Commit(context.Background())
// }
+162 -176
View File
@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
@@ -18,7 +17,7 @@ import (
"github.com/hatchet-dev/hatchet/internal/validator"
)
type workerRepository struct {
type workerAPIRepository struct {
client *db.PrismaClient
pool *pgxpool.Pool
v validator.Validator
@@ -26,10 +25,10 @@ type workerRepository struct {
l *zerolog.Logger
}
func NewWorkerRepository(client *db.PrismaClient, pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.WorkerRepository {
func NewWorkerAPIRepository(client *db.PrismaClient, pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.WorkerAPIRepository {
queries := dbsqlc.New()
return &workerRepository{
return &workerAPIRepository{
client: client,
pool: pool,
v: v,
@@ -38,7 +37,7 @@ func NewWorkerRepository(client *db.PrismaClient, pool *pgxpool.Pool, v validato
}
}
func (w *workerRepository) GetWorkerById(workerId string) (*db.WorkerModel, error) {
func (w *workerAPIRepository) GetWorkerById(workerId string) (*db.WorkerModel, error) {
return w.client.Worker.FindUnique(
db.Worker.ID.Equals(workerId),
).With(
@@ -47,14 +46,7 @@ func (w *workerRepository) GetWorkerById(workerId string) (*db.WorkerModel, erro
).Exec(context.Background())
}
func (w *workerRepository) GetWorkerForEngine(tenantId, workerId string) (*dbsqlc.GetWorkerForEngineRow, error) {
return w.queries.GetWorkerForEngine(context.Background(), w.pool, dbsqlc.GetWorkerForEngineParams{
ID: sqlchelpers.UUIDFromStr(workerId),
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
})
}
func (w *workerRepository) ListRecentWorkerStepRuns(tenantId, workerId string) ([]db.StepRunModel, error) {
func (w *workerAPIRepository) ListRecentWorkerStepRuns(tenantId, workerId string) ([]db.StepRunModel, error) {
return w.client.StepRun.FindMany(
db.StepRun.WorkerID.Equals(workerId),
db.StepRun.TenantID.Equals(tenantId),
@@ -75,7 +67,7 @@ func (w *workerRepository) ListRecentWorkerStepRuns(tenantId, workerId string) (
).Exec(context.Background())
}
func (r *workerRepository) ListWorkers(tenantId string, opts *repository.ListWorkersOpts) ([]*dbsqlc.ListWorkersWithStepCountRow, error) {
func (r *workerAPIRepository) ListWorkers(tenantId string, opts *repository.ListWorkersOpts) ([]*dbsqlc.ListWorkersWithStepCountRow, error) {
if err := r.v.Validate(opts); err != nil {
return nil, err
}
@@ -128,205 +120,199 @@ func (r *workerRepository) ListWorkers(tenantId string, opts *repository.ListWor
return workers, nil
}
func (w *workerRepository) CreateNewWorker(tenantId string, opts *repository.CreateWorkerOpts) (*db.WorkerModel, error) {
type workerEngineRepository struct {
pool *pgxpool.Pool
v validator.Validator
queries *dbsqlc.Queries
l *zerolog.Logger
}
func NewWorkerEngineRepository(pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.WorkerEngineRepository {
queries := dbsqlc.New()
return &workerEngineRepository{
pool: pool,
v: v,
queries: queries,
l: l,
}
}
func (w *workerEngineRepository) GetWorkerForEngine(tenantId, workerId string) (*dbsqlc.GetWorkerForEngineRow, error) {
return w.queries.GetWorkerForEngine(context.Background(), w.pool, dbsqlc.GetWorkerForEngineParams{
ID: sqlchelpers.UUIDFromStr(workerId),
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
})
}
func (w *workerEngineRepository) CreateNewWorker(tenantId string, opts *repository.CreateWorkerOpts) (*dbsqlc.Worker, error) {
if err := w.v.Validate(opts); err != nil {
return nil, err
}
txs := []db.PrismaTransaction{}
workerId := uuid.New().String()
createTx := w.client.Worker.CreateOne(
db.Worker.Tenant.Link(
db.Tenant.ID.Equals(tenantId),
),
db.Worker.Name.Set(opts.Name),
db.Worker.Dispatcher.Link(
db.Dispatcher.ID.Equals(opts.DispatcherId),
),
db.Worker.ID.Set(workerId),
db.Worker.MaxRuns.SetIfPresent(opts.MaxRuns),
).Tx()
txs = append(txs, createTx)
for _, svc := range opts.Services {
upsertServiceTx := w.client.Service.UpsertOne(
db.Service.TenantIDName(
db.Service.TenantID.Equals(tenantId),
db.Service.Name.Equals(svc),
),
).Create(
db.Service.Name.Set(svc),
db.Service.Tenant.Link(
db.Tenant.ID.Equals(tenantId),
),
db.Service.Workers.Link(
db.Worker.ID.Equals(workerId),
),
).Update(
db.Service.Workers.Link(
db.Worker.ID.Equals(workerId),
),
).Tx()
txs = append(txs, upsertServiceTx)
}
if len(opts.Actions) > 0 {
for _, action := range opts.Actions {
txs = append(txs, w.client.Action.UpsertOne(
db.Action.TenantIDActionID(
db.Action.TenantID.Equals(tenantId),
db.Action.ActionID.Equals(action),
),
).Create(
db.Action.ActionID.Set(action),
db.Action.Tenant.Link(
db.Tenant.ID.Equals(tenantId),
),
).Update().Tx())
// This is unfortunate but due to https://github.com/steebchen/prisma-client-go/issues/1095,
// we cannot set db.Worker.Actions.Link multiple times, and since Link required a unique action
// where clause, we have to do these in separate transactions
txs = append(txs, w.client.Worker.FindUnique(
db.Worker.ID.Equals(workerId),
).Update(
db.Worker.Actions.Link(
db.Action.TenantIDActionID(
db.Action.TenantID.Equals(tenantId),
db.Action.ActionID.Equals(action),
),
),
).Tx())
}
}
err := w.client.Prisma.Transaction(txs...).Exec(context.Background())
tx, err := w.pool.Begin(context.Background())
if err != nil {
return nil, err
}
return createTx.Result(), nil
defer deferRollback(context.Background(), w.l, tx.Rollback)
pgTenantId := sqlchelpers.UUIDFromStr(tenantId)
createParams := dbsqlc.CreateWorkerParams{
Tenantid: pgTenantId,
Dispatcherid: sqlchelpers.UUIDFromStr(opts.DispatcherId),
Name: opts.Name,
}
if opts.MaxRuns != nil {
createParams.MaxRuns = pgtype.Int4{
Int32: int32(*opts.MaxRuns),
Valid: true,
}
}
worker, err := w.queries.CreateWorker(context.Background(), tx, createParams)
if err != nil {
return nil, fmt.Errorf("could not create worker: %w", err)
}
svcUUIDs := make([]pgtype.UUID, len(opts.Services))
for i, svc := range opts.Services {
dbSvc, err := w.queries.UpsertService(context.Background(), tx, dbsqlc.UpsertServiceParams{
Name: svc,
Tenantid: pgTenantId,
})
if err != nil {
return nil, fmt.Errorf("could not upsert service: %w", err)
}
svcUUIDs[i] = dbSvc.ID
}
err = w.queries.LinkServicesToWorker(context.Background(), tx, dbsqlc.LinkServicesToWorkerParams{
Services: svcUUIDs,
Workerid: worker.ID,
})
if err != nil {
return nil, fmt.Errorf("could not link services to worker: %w", err)
}
actionUUIDs := make([]pgtype.UUID, len(opts.Actions))
for i, action := range opts.Actions {
dbAction, err := w.queries.UpsertAction(context.Background(), tx, dbsqlc.UpsertActionParams{
Action: action,
Tenantid: pgTenantId,
})
if err != nil {
return nil, fmt.Errorf("could not upsert action: %w", err)
}
actionUUIDs[i] = dbAction.ID
}
err = w.queries.LinkActionsToWorker(context.Background(), tx, dbsqlc.LinkActionsToWorkerParams{
Actionids: actionUUIDs,
Workerid: worker.ID,
})
if err != nil {
return nil, fmt.Errorf("could not link actions to worker: %w", err)
}
err = tx.Commit(context.Background())
if err != nil {
return nil, fmt.Errorf("could not commit transaction: %w", err)
}
return worker, nil
}
func (w *workerRepository) UpdateWorker(tenantId, workerId string, opts *repository.UpdateWorkerOpts) (*db.WorkerModel, error) {
func (w *workerEngineRepository) UpdateWorker(tenantId, workerId string, opts *repository.UpdateWorkerOpts) (*dbsqlc.Worker, error) {
if err := w.v.Validate(opts); err != nil {
return nil, err
}
txs := []db.PrismaTransaction{}
tx, err := w.pool.Begin(context.Background())
optionals := []db.WorkerSetParam{}
if err != nil {
return nil, err
}
defer deferRollback(context.Background(), w.l, tx.Rollback)
pgTenantId := sqlchelpers.UUIDFromStr(tenantId)
updateParams := dbsqlc.UpdateWorkerParams{
ID: sqlchelpers.UUIDFromStr(workerId),
}
if opts.Status != nil {
optionals = append(optionals, db.Worker.Status.Set(*opts.Status))
updateParams.Status = dbsqlc.NullWorkerStatus{
WorkerStatus: dbsqlc.WorkerStatus(*opts.Status),
Valid: true,
}
}
if opts.LastHeartbeatAt != nil {
optionals = append(optionals, db.Worker.LastHeartbeatAt.Set(*opts.LastHeartbeatAt))
updateParams.LastHeartbeatAt = sqlchelpers.TimestampFromTime(*opts.LastHeartbeatAt)
}
if opts.DispatcherId != nil {
optionals = append(optionals, db.Worker.Dispatcher.Link(
db.Dispatcher.ID.Equals(*opts.DispatcherId),
))
updateParams.DispatcherId = sqlchelpers.UUIDFromStr(*opts.DispatcherId)
}
worker, err := w.queries.UpdateWorker(context.Background(), tx, updateParams)
if err != nil {
return nil, fmt.Errorf("could not update worker: %w", err)
}
if len(opts.Actions) > 0 {
for _, action := range opts.Actions {
txs = append(txs, w.client.Action.UpsertOne(
db.Action.TenantIDActionID(
db.Action.TenantID.Equals(tenantId),
db.Action.ActionID.Equals(action),
),
).Create(
db.Action.ActionID.Set(action),
db.Action.Tenant.Link(
db.Tenant.ID.Equals(tenantId),
),
).Update().Tx())
actionUUIDs := make([]pgtype.UUID, len(opts.Actions))
// This is unfortunate but due to https://github.com/steebchen/prisma-client-go/issues/1095,
// we cannot set db.Worker.Actions.Link multiple times, and since Link required a unique action
// where clause, we have to do these in separate transactions
txs = append(txs, w.client.Worker.FindUnique(
db.Worker.ID.Equals(workerId),
).Update(
db.Worker.Actions.Link(
db.Action.TenantIDActionID(
db.Action.TenantID.Equals(tenantId),
db.Action.ActionID.Equals(action),
),
),
).Tx())
for i, action := range opts.Actions {
dbAction, err := w.queries.UpsertAction(context.Background(), tx, dbsqlc.UpsertActionParams{
Action: action,
Tenantid: pgTenantId,
})
if err != nil {
return nil, fmt.Errorf("could not upsert action: %w", err)
}
actionUUIDs[i] = dbAction.ID
}
err = w.queries.LinkActionsToWorker(context.Background(), tx, dbsqlc.LinkActionsToWorkerParams{
Actionids: actionUUIDs,
Workerid: sqlchelpers.UUIDFromStr(workerId),
})
if err != nil {
return nil, fmt.Errorf("could not link actions to worker: %w", err)
}
}
updateTx := w.client.Worker.FindUnique(
db.Worker.ID.Equals(workerId),
).Update(
optionals...,
).Tx()
txs = append(txs, updateTx)
err := w.client.Prisma.Transaction(txs...).Exec(context.Background())
err = tx.Commit(context.Background())
if err != nil {
return nil, err
return nil, fmt.Errorf("could not commit transaction: %w", err)
}
return updateTx.Result(), nil
return worker, nil
}
func (w *workerRepository) DeleteWorker(tenantId, workerId string) error {
_, err := w.client.Worker.FindUnique(
db.Worker.ID.Equals(workerId),
).Delete().Exec(context.Background())
return err
}
func (w *workerRepository) AddStepRun(tenantId, workerId, stepRunId string) error {
tx1 := w.client.Worker.FindUnique(
db.Worker.ID.Equals(workerId),
).Update(
db.Worker.StepRuns.Link(
db.StepRun.ID.Equals(stepRunId),
),
).Tx()
tx2 := w.client.StepRun.FindUnique(
db.StepRun.ID.Equals(stepRunId),
).Update(
db.StepRun.Status.Set(db.StepRunStatusAssigned),
).Tx()
err := w.client.Prisma.Transaction(tx1, tx2).Exec(context.Background())
return err
}
func (w *workerRepository) AddGetGroupKeyRun(tenantId, workerId, getGroupKeyRunId string) error {
tx1 := w.client.Worker.FindUnique(
db.Worker.ID.Equals(workerId),
).Update(
db.Worker.GroupKeyRuns.Link(
db.GetGroupKeyRun.ID.Equals(getGroupKeyRunId),
),
).Tx()
tx2 := w.client.GetGroupKeyRun.FindUnique(
db.GetGroupKeyRun.ID.Equals(getGroupKeyRunId),
).Update(
db.GetGroupKeyRun.Status.Set(db.StepRunStatusAssigned),
).Tx()
err := w.client.Prisma.Transaction(tx1, tx2).Exec(context.Background())
func (w *workerEngineRepository) DeleteWorker(tenantId, workerId string) error {
_, err := w.queries.DeleteWorker(context.Background(), w.pool, sqlchelpers.UUIDFromStr(workerId))
return err
}
+182 -132
View File
@@ -21,7 +21,7 @@ import (
"github.com/hatchet-dev/hatchet/internal/validator"
)
type workflowRepository struct {
type workflowAPIRepository struct {
client *db.PrismaClient
pool *pgxpool.Pool
v validator.Validator
@@ -29,10 +29,10 @@ type workflowRepository struct {
l *zerolog.Logger
}
func NewWorkflowRepository(client *db.PrismaClient, pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.WorkflowRepository {
func NewWorkflowRepository(client *db.PrismaClient, pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.WorkflowAPIRepository {
queries := dbsqlc.New()
return &workflowRepository{
return &workflowAPIRepository{
client: client,
v: v,
queries: queries,
@@ -41,7 +41,7 @@ func NewWorkflowRepository(client *db.PrismaClient, pool *pgxpool.Pool, v valida
}
}
func (r *workflowRepository) ListWorkflows(tenantId string, opts *repository.ListWorkflowsOpts) (*repository.ListWorkflowsResult, error) {
func (r *workflowAPIRepository) ListWorkflows(tenantId string, opts *repository.ListWorkflowsOpts) (*repository.ListWorkflowsResult, error) {
if err := r.v.Validate(opts); err != nil {
return nil, err
}
@@ -161,7 +161,94 @@ func (r *workflowRepository) ListWorkflows(tenantId string, opts *repository.Lis
return res, nil
}
func (r *workflowRepository) CreateNewWorkflow(tenantId string, opts *repository.CreateWorkflowVersionOpts) (*db.WorkflowVersionModel, error) {
func (r *workflowAPIRepository) GetWorkflowById(workflowId string) (*db.WorkflowModel, error) {
return r.client.Workflow.FindUnique(
db.Workflow.ID.Equals(workflowId),
).With(
defaultWorkflowPopulator()...,
).Exec(context.Background())
}
func (r *workflowAPIRepository) GetWorkflowByName(tenantId, workflowName string) (*db.WorkflowModel, error) {
return r.client.Workflow.FindUnique(
db.Workflow.TenantIDName(
db.Workflow.TenantID.Equals(tenantId),
db.Workflow.Name.Equals(workflowName),
),
).With(
defaultWorkflowPopulator()...,
).Exec(context.Background())
}
func (r *workflowAPIRepository) GetWorkflowVersionById(tenantId, workflowVersionId string) (*db.WorkflowVersionModel, error) {
return r.client.WorkflowVersion.FindUnique(
db.WorkflowVersion.ID.Equals(workflowVersionId),
).With(
defaultWorkflowVersionPopulator()...,
).Exec(context.Background())
}
func (r *workflowAPIRepository) DeleteWorkflow(tenantId, workflowId string) (*db.WorkflowModel, error) {
return r.client.Workflow.FindUnique(
db.Workflow.ID.Equals(workflowId),
).With(
defaultWorkflowPopulator()...,
).Delete().Exec(context.Background())
}
func (r *workflowAPIRepository) UpsertWorkflowDeploymentConfig(workflowId string, opts *repository.UpsertWorkflowDeploymentConfigOpts) (*db.WorkflowDeploymentConfigModel, error) {
if err := r.v.Validate(opts); err != nil {
return nil, err
}
// upsert the deployment config
deploymentConfig, err := r.client.WorkflowDeploymentConfig.UpsertOne(
db.WorkflowDeploymentConfig.WorkflowID.Equals(workflowId),
).Create(
db.WorkflowDeploymentConfig.Workflow.Link(
db.Workflow.ID.Equals(workflowId),
),
db.WorkflowDeploymentConfig.GitRepoName.Set(opts.GitRepoName),
db.WorkflowDeploymentConfig.GitRepoOwner.Set(opts.GitRepoOwner),
db.WorkflowDeploymentConfig.GitRepoBranch.Set(opts.GitRepoBranch),
db.WorkflowDeploymentConfig.GithubAppInstallation.Link(
db.GithubAppInstallation.ID.Equals(opts.GithubAppInstallationId),
),
).Update(
db.WorkflowDeploymentConfig.GitRepoName.Set(opts.GitRepoName),
db.WorkflowDeploymentConfig.GitRepoOwner.Set(opts.GitRepoOwner),
db.WorkflowDeploymentConfig.GitRepoBranch.Set(opts.GitRepoBranch),
db.WorkflowDeploymentConfig.GithubAppInstallation.Link(
db.GithubAppInstallation.ID.Equals(opts.GithubAppInstallationId),
),
).Exec(context.Background())
if err != nil {
return nil, err
}
return deploymentConfig, nil
}
type workflowEngineRepository struct {
pool *pgxpool.Pool
v validator.Validator
queries *dbsqlc.Queries
l *zerolog.Logger
}
func NewWorkflowEngineRepository(pool *pgxpool.Pool, v validator.Validator, l *zerolog.Logger) repository.WorkflowEngineRepository {
queries := dbsqlc.New()
return &workflowEngineRepository{
v: v,
queries: queries,
pool: pool,
l: l,
}
}
func (r *workflowEngineRepository) CreateNewWorkflow(tenantId string, opts *repository.CreateWorkflowVersionOpts) (*dbsqlc.GetWorkflowVersionForEngineRow, error) {
if err := r.v.Validate(opts); err != nil {
return nil, err
}
@@ -176,15 +263,13 @@ func (r *workflowRepository) CreateNewWorkflow(tenantId string, opts *repository
}
// preflight check to ensure the workflow doesn't already exist
workflow, err := r.client.Workflow.FindUnique(
db.Workflow.TenantIDName(
db.Workflow.TenantID.Equals(tenantId),
db.Workflow.Name.Equals(opts.Name),
),
).Exec(context.Background())
workflow, err := r.queries.GetWorkflowByName(context.Background(), r.pool, dbsqlc.GetWorkflowByNameParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Name: opts.Name,
})
if err != nil {
if !errors.Is(err, db.ErrNotFound) {
if !errors.Is(err, pgx.ErrNoRows) {
return nil, err
}
} else if workflow != nil {
@@ -252,20 +337,29 @@ func (r *workflowRepository) CreateNewWorkflow(tenantId string, opts *repository
return nil, err
}
workflowVersion, err := r.queries.GetWorkflowVersionForEngine(context.Background(), tx, dbsqlc.GetWorkflowVersionForEngineParams{
Tenantid: pgTenantId,
Ids: []pgtype.UUID{sqlchelpers.UUIDFromStr(workflowVersionId)},
})
if err != nil {
return nil, fmt.Errorf("failed to fetch workflow version: %w", err)
}
if len(workflowVersion) != 1 {
return nil, fmt.Errorf("expected 1 workflow version when creating new, got %d", len(workflowVersion))
}
err = tx.Commit(context.Background())
if err != nil {
return nil, err
}
return r.client.WorkflowVersion.FindUnique(
db.WorkflowVersion.ID.Equals(workflowVersionId),
).With(
defaultWorkflowVersionPopulator()...,
).Exec(context.Background())
return workflowVersion[0], nil
}
func (r *workflowRepository) CreateWorkflowVersion(tenantId string, opts *repository.CreateWorkflowVersionOpts) (*db.WorkflowVersionModel, error) {
func (r *workflowEngineRepository) CreateWorkflowVersion(tenantId string, opts *repository.CreateWorkflowVersionOpts) (*dbsqlc.GetWorkflowVersionForEngineRow, error) {
if err := r.v.Validate(opts); err != nil {
return nil, err
}
@@ -280,15 +374,13 @@ func (r *workflowRepository) CreateWorkflowVersion(tenantId string, opts *reposi
}
// preflight check to ensure the workflow already exists
workflow, err := r.client.Workflow.FindUnique(
db.Workflow.TenantIDName(
db.Workflow.TenantID.Equals(tenantId),
db.Workflow.Name.Equals(opts.Name),
),
).Exec(context.Background())
workflow, err := r.queries.GetWorkflowByName(context.Background(), r.pool, dbsqlc.GetWorkflowByNameParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Name: opts.Name,
})
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to fetch workflow: %w", err)
}
if workflow == nil {
@@ -306,97 +398,105 @@ func (r *workflowRepository) CreateWorkflowVersion(tenantId string, opts *reposi
defer deferRollback(context.Background(), r.l, tx.Rollback)
workflowId := sqlchelpers.UUIDFromStr(workflow.ID)
pgTenantId := sqlchelpers.UUIDFromStr(tenantId)
workflowVersionId, err := r.createWorkflowVersionTxs(context.Background(), tx, pgTenantId, workflowId, opts)
workflowVersionId, err := r.createWorkflowVersionTxs(context.Background(), tx, pgTenantId, workflow.ID, opts)
if err != nil {
return nil, err
}
workflowVersion, err := r.queries.GetWorkflowVersionForEngine(context.Background(), tx, dbsqlc.GetWorkflowVersionForEngineParams{
Tenantid: pgTenantId,
Ids: []pgtype.UUID{sqlchelpers.UUIDFromStr(workflowVersionId)},
})
if err != nil {
return nil, fmt.Errorf("failed to fetch workflow version: %w", err)
}
if len(workflowVersion) != 1 {
return nil, fmt.Errorf("expected 1 workflow version when creating version, got %d", len(workflowVersion))
}
err = tx.Commit(context.Background())
if err != nil {
return nil, err
}
return r.client.WorkflowVersion.FindUnique(
db.WorkflowVersion.ID.Equals(workflowVersionId),
).With(
defaultWorkflowVersionPopulator()...,
).Exec(context.Background())
return workflowVersion[0], nil
}
type createScheduleTxResult interface {
Result() *db.WorkflowTriggerScheduledRefModel
}
func (r *workflowRepository) CreateSchedules(
func (r *workflowEngineRepository) CreateSchedules(
tenantId, workflowVersionId string,
opts *repository.CreateWorkflowSchedulesOpts,
) ([]*db.WorkflowTriggerScheduledRefModel, error) {
) ([]*dbsqlc.WorkflowTriggerScheduledRef, error) {
if err := r.v.Validate(opts); err != nil {
return nil, err
}
txs := []db.PrismaTransaction{}
results := []createScheduleTxResult{}
for _, scheduledTrigger := range opts.ScheduledTriggers {
createTx := r.client.WorkflowTriggerScheduledRef.CreateOne(
db.WorkflowTriggerScheduledRef.Parent.Link(
db.WorkflowVersion.ID.Equals(workflowVersionId),
),
db.WorkflowTriggerScheduledRef.TriggerAt.Set(scheduledTrigger),
db.WorkflowTriggerScheduledRef.Input.SetIfPresent(opts.Input),
).Tx()
txs = append(txs, createTx)
results = append(results, createTx)
createParams := dbsqlc.CreateSchedulesParams{
Workflowrunid: sqlchelpers.UUIDFromStr(workflowVersionId),
Input: opts.Input,
Triggertimes: make([]pgtype.Timestamp, len(opts.ScheduledTriggers)),
}
err := r.client.Prisma.Transaction(txs...).Exec(context.Background())
for i, scheduledTrigger := range opts.ScheduledTriggers {
createParams.Triggertimes[i] = sqlchelpers.TimestampFromTime(scheduledTrigger)
}
return r.queries.CreateSchedules(context.Background(), r.pool, createParams)
}
func (r *workflowEngineRepository) GetLatestWorkflowVersion(tenantId, workflowId string) (*dbsqlc.GetWorkflowVersionForEngineRow, error) {
versionId, err := r.queries.GetWorkflowLatestVersion(context.Background(), r.pool, sqlchelpers.UUIDFromStr(workflowId))
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to fetch latest version: %w", err)
}
res := make([]*db.WorkflowTriggerScheduledRefModel, 0)
versions, err := r.queries.GetWorkflowVersionForEngine(context.Background(), r.pool, dbsqlc.GetWorkflowVersionForEngineParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Ids: []pgtype.UUID{versionId},
})
for _, result := range results {
res = append(res, result.Result())
if err != nil {
return nil, fmt.Errorf("failed to fetch workflow version: %w", err)
}
return res, nil
if len(versions) != 1 {
return nil, fmt.Errorf("expected 1 workflow version for latest, got %d", len(versions))
}
return versions[0], nil
}
func (r *workflowRepository) GetWorkflowById(workflowId string) (*db.WorkflowModel, error) {
return r.client.Workflow.FindUnique(
db.Workflow.ID.Equals(workflowId),
).With(
defaultWorkflowPopulator()...,
).Exec(context.Background())
func (r *workflowEngineRepository) GetWorkflowByName(tenantId, workflowName string) (*dbsqlc.Workflow, error) {
return r.queries.GetWorkflowByName(context.Background(), r.pool, dbsqlc.GetWorkflowByNameParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Name: workflowName,
})
}
func (r *workflowRepository) GetWorkflowByName(tenantId, workflowName string) (*db.WorkflowModel, error) {
return r.client.Workflow.FindUnique(
db.Workflow.TenantIDName(
db.Workflow.TenantID.Equals(tenantId),
db.Workflow.Name.Equals(workflowName),
),
).With(
defaultWorkflowPopulator()...,
).Exec(context.Background())
func (r *workflowEngineRepository) GetWorkflowVersionById(tenantId, workflowId string) (*dbsqlc.GetWorkflowVersionForEngineRow, error) {
versions, err := r.queries.GetWorkflowVersionForEngine(context.Background(), r.pool, dbsqlc.GetWorkflowVersionForEngineParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Ids: []pgtype.UUID{sqlchelpers.UUIDFromStr(workflowId)},
})
if err != nil {
return nil, fmt.Errorf("failed to fetch workflow version: %w", err)
}
if len(versions) != 1 {
return nil, fmt.Errorf("expected 1 workflow version when getting by id, got %d", len(versions))
}
return versions[0], nil
}
func (r *workflowRepository) GetScheduledById(tenantId, scheduleTriggerId string) (*db.WorkflowTriggerScheduledRefModel, error) {
return r.client.WorkflowTriggerScheduledRef.FindUnique(
db.WorkflowTriggerScheduledRef.ID.Equals(scheduleTriggerId),
).Exec(context.Background())
}
func (r *workflowRepository) ListWorkflowsForEvent(ctx context.Context, tenantId, eventKey string) ([]*dbsqlc.GetWorkflowVersionForEngineRow, error) {
func (r *workflowEngineRepository) ListWorkflowsForEvent(ctx context.Context, tenantId, eventKey string) ([]*dbsqlc.GetWorkflowVersionForEngineRow, error) {
ctx, span1 := telemetry.NewSpan(ctx, "db-list-workflows-for-event")
defer span1.End()
@@ -418,7 +518,7 @@ func (r *workflowRepository) ListWorkflowsForEvent(ctx context.Context, tenantId
span2.End()
ctx, span3 := telemetry.NewSpan(ctx, "db-get-workflow-versions-for-engine")
ctx, span3 := telemetry.NewSpan(ctx, "db-get-workflow-versions-for-engine") // nolint: ineffassign
defer span3.End()
workflows, err := r.queries.GetWorkflowVersionForEngine(context.Background(), r.pool, dbsqlc.GetWorkflowVersionForEngineParams{
@@ -433,7 +533,7 @@ func (r *workflowRepository) ListWorkflowsForEvent(ctx context.Context, tenantId
return workflows, nil
}
func (r *workflowRepository) createWorkflowVersionTxs(ctx context.Context, tx pgx.Tx, tenantId, workflowId pgtype.UUID, opts *repository.CreateWorkflowVersionOpts) (string, error) {
func (r *workflowEngineRepository) createWorkflowVersionTxs(ctx context.Context, tx pgx.Tx, tenantId, workflowId pgtype.UUID, opts *repository.CreateWorkflowVersionOpts) (string, error) {
workflowVersionId := uuid.New().String()
var version pgtype.Text
@@ -698,56 +798,6 @@ func (r *workflowRepository) createWorkflowVersionTxs(ctx context.Context, tx pg
return workflowVersionId, nil
}
func (r *workflowRepository) DeleteWorkflow(tenantId, workflowId string) (*db.WorkflowModel, error) {
return r.client.Workflow.FindUnique(
db.Workflow.ID.Equals(workflowId),
).With(
defaultWorkflowPopulator()...,
).Delete().Exec(context.Background())
}
func (r *workflowRepository) GetWorkflowVersionById(tenantId, workflowVersionId string) (*db.WorkflowVersionModel, error) {
return r.client.WorkflowVersion.FindUnique(
db.WorkflowVersion.ID.Equals(workflowVersionId),
).With(
defaultWorkflowVersionPopulator()...,
).Exec(context.Background())
}
func (r *workflowRepository) UpsertWorkflowDeploymentConfig(workflowId string, opts *repository.UpsertWorkflowDeploymentConfigOpts) (*db.WorkflowDeploymentConfigModel, error) {
if err := r.v.Validate(opts); err != nil {
return nil, err
}
// upsert the deployment config
deploymentConfig, err := r.client.WorkflowDeploymentConfig.UpsertOne(
db.WorkflowDeploymentConfig.WorkflowID.Equals(workflowId),
).Create(
db.WorkflowDeploymentConfig.Workflow.Link(
db.Workflow.ID.Equals(workflowId),
),
db.WorkflowDeploymentConfig.GitRepoName.Set(opts.GitRepoName),
db.WorkflowDeploymentConfig.GitRepoOwner.Set(opts.GitRepoOwner),
db.WorkflowDeploymentConfig.GitRepoBranch.Set(opts.GitRepoBranch),
db.WorkflowDeploymentConfig.GithubAppInstallation.Link(
db.GithubAppInstallation.ID.Equals(opts.GithubAppInstallationId),
),
).Update(
db.WorkflowDeploymentConfig.GitRepoName.Set(opts.GitRepoName),
db.WorkflowDeploymentConfig.GitRepoOwner.Set(opts.GitRepoOwner),
db.WorkflowDeploymentConfig.GitRepoBranch.Set(opts.GitRepoBranch),
db.WorkflowDeploymentConfig.GithubAppInstallation.Link(
db.GithubAppInstallation.ID.Equals(opts.GithubAppInstallationId),
),
).Exec(context.Background())
if err != nil {
return nil, err
}
return deploymentConfig, nil
}
func defaultWorkflowPopulator() []db.WorkflowRelationWith {
return []db.WorkflowRelationWith{
db.Workflow.Tags.Fetch(),

Some files were not shown because too many files have changed in this diff Show More