diff --git a/db/upgrade/upgrade.go b/db/upgrade/upgrade.go index 7513bdb2..fe41e3cd 100644 --- a/db/upgrade/upgrade.go +++ b/db/upgrade/upgrade.go @@ -5102,7 +5102,7 @@ var upgradeFunctions = map[string]func(ctx context.Context, tx pgx.Tx) (string, `, ar.moduleName, ar.attributeId.String())); err != nil { return "", err } - if err := pgIndex.SetAutoFkiForAttribute_tx(tx, ar.relationId, + if err := pgIndex.SetAutoFkiForAttribute_tx(db.Ctx, tx, ar.relationId, ar.attributeId, (ar.content == "1:1")); err != nil { return "", err diff --git a/handler/websocket/websocket.go b/handler/websocket/websocket.go index acf39131..152bda8d 100644 --- a/handler/websocket/websocket.go +++ b/handler/websocket/websocket.go @@ -9,6 +9,7 @@ import ( "r3/bruteforce" "r3/cache" "r3/cluster" + "r3/config" "r3/handler" "r3/log" "r3/login/login_session" @@ -16,6 +17,7 @@ import ( "r3/types" "sync" "sync/atomic" + "time" "github.com/gofrs/uuid" "github.com/gorilla/websocket" @@ -127,22 +129,31 @@ func Handler(w http.ResponseWriter, r *http.Request) { func (hub *hubType) start() { - var removeClient = func(client *clientType) { - if _, exists := hub.clients[client]; exists { - log.Info(handlerContext, fmt.Sprintf("disconnecting client at %s", client.address)) - if !client.ioFailure.Load() { - client.write_mx.Lock() - client.ws.WriteMessage(websocket.CloseMessage, []byte{}) - client.write_mx.Unlock() - } - client.ws.Close() - client.ctxCancel() - delete(hub.clients, client) + var removeClient = func(client *clientType, wasKicked bool) { + if _, exists := hub.clients[client]; !exists { + return + } + if !client.ioFailure.Load() { + client.write_mx.Lock() + client.ws.WriteMessage(websocket.CloseMessage, []byte{}) + client.write_mx.Unlock() + } + client.ws.Close() + client.ctxCancel() + delete(hub.clients, client) + + // run DB calls in async func as they must not block hub operations during heavy DB load + go func() { + if wasKicked { + log.Info(handlerContext, fmt.Sprintf("kicked client (login ID %d) at %s", client.loginId, client.address)) + } else { + log.Info(handlerContext, fmt.Sprintf("disconnected client (login ID %d) at %s", client.loginId, client.address)) + } if err := login_session.LogRemove(client.id); err != nil { log.Error(handlerContext, "failed to remove login session log", err) } - } + }() } for { @@ -152,11 +163,11 @@ func (hub *hubType) start() { hub.clients[client] = true case client := <-hub.clientDel: - removeClient(client) + removeClient(client, false) case event := <-cluster.WebsocketClientEvents: - // prepare json message for client based on event content + // prepare json message for client(s) based on event content var err error = nil jsonMsg := []byte{} // message back to client singleRecipient := false // message is only sent to single recipient (first valid one) @@ -196,7 +207,8 @@ func (hub *hubType) start() { } if err != nil { - log.Error(handlerContext, "could not prepare unrequested transaction", err) + // run DB calls in async func as they must not block hub operations during heavy DB load + go log.Error(handlerContext, "could not prepare unrequested transaction", err) continue } @@ -214,8 +226,7 @@ func (hub *hubType) start() { // disconnect and do not send message if kicked if event.Content == "kick" || (event.Content == "kickNonAdmin" && !client.admin) { - log.Info(handlerContext, fmt.Sprintf("kicking client (login ID %d)", client.loginId)) - removeClient(client) + removeClient(client, true) continue } @@ -284,7 +295,10 @@ func (client *clientType) handleTransaction(reqTransJson json.RawMessage) json.R if !authRequest { // execute non-authentication transaction - resTrans = request.ExecTransaction(client.ctx, client.address, client.loginId, + ctx, _ := context.WithTimeout(client.ctx, + time.Duration(int64(config.GetUint64("dbTimeoutDataWs")))*time.Second) + + resTrans = request.ExecTransaction(ctx, client.address, client.loginId, client.admin, client.device, client.noAuth, reqTrans, resTrans) } else { diff --git a/login/login.go b/login/login.go index 4d5bf0d4..81f22dc3 100644 --- a/login/login.go +++ b/login/login.go @@ -1,6 +1,7 @@ package login import ( + "context" "errors" "fmt" "math/rand" @@ -258,11 +259,11 @@ func Set_tx(tx pgx.Tx, id int64, loginTemplateId pgtype.Int8, ldapId pgtype.Int4 return 0, err } } - s, err := login_setting.Get(pgtype.Int8{}, loginTemplateId) + s, err := login_setting.Get_tx(db.Ctx, tx, pgtype.Int8{}, loginTemplateId) if err != nil { return 0, err } - if err := login_setting.Set_tx(tx, pgtype.Int8{Int64: id, Valid: true}, pgtype.Int8{}, s, true); err != nil { + if err := login_setting.Set_tx(db.Ctx, tx, pgtype.Int8{Int64: id, Valid: true}, pgtype.Int8{}, s, true); err != nil { return 0, err } } else { @@ -276,7 +277,7 @@ func Set_tx(tx pgx.Tx, id int64, loginTemplateId pgtype.Int8, ldapId pgtype.Int4 } if pass != "" { - if err := SetSaltHash_tx(tx, salt, hash, id); err != nil { + if err := SetSaltHash_tx(db.Ctx, tx, salt, hash, id); err != nil { return 0, err } } @@ -324,8 +325,8 @@ func Set_tx(tx pgx.Tx, id int64, loginTemplateId pgtype.Int8, ldapId pgtype.Int4 return id, setRoleIds_tx(tx, id, roleIds) } -func SetSaltHash_tx(tx pgx.Tx, salt pgtype.Text, hash pgtype.Text, id int64) error { - _, err := tx.Exec(db.Ctx, ` +func SetSaltHash_tx(ctx context.Context, tx pgx.Tx, salt pgtype.Text, hash pgtype.Text, id int64) error { + _, err := tx.Exec(ctx, ` UPDATE instance.login SET salt = $1, hash = $2 WHERE id = $3 @@ -366,7 +367,7 @@ func GetByRole(roleId uuid.UUID) ([]types.Login, error) { // get names for public lookups for non-admins // returns slice of up to 10 logins -func GetNames(id int64, idsExclude []int64, byString string, noLdapAssign bool) ([]types.Login, error) { +func GetNames_tx(ctx context.Context, tx pgx.Tx, id int64, idsExclude []int64, byString string, noLdapAssign bool) ([]types.Login, error) { names := make([]types.Login, 0) var qb tools.QueryBuilder @@ -409,7 +410,7 @@ func GetNames(id int64, idsExclude []int64, byString string, noLdapAssign bool) return names, err } - rows, err := db.Pool.Query(db.Ctx, query, qb.GetParaValues()...) + rows, err := tx.Query(ctx, query, qb.GetParaValues()...) if err != nil { return names, err } @@ -426,18 +427,18 @@ func GetNames(id int64, idsExclude []int64, byString string, noLdapAssign bool) } // user creatable fixed (permanent) tokens for less sensitive access permissions -func DelTokenFixed(loginId int64, id int64) error { - _, err := db.Pool.Exec(db.Ctx, ` +func DelTokenFixed_tx(ctx context.Context, tx pgx.Tx, loginId int64, id int64) error { + _, err := tx.Exec(ctx, ` DELETE FROM instance.login_token_fixed WHERE login_id = $1 AND id = $2 `, loginId, id) return err } -func GetTokensFixed(loginId int64) ([]types.LoginTokenFixed, error) { +func GetTokensFixed_tx(ctx context.Context, tx pgx.Tx, loginId int64) ([]types.LoginTokenFixed, error) { tokens := make([]types.LoginTokenFixed, 0) - rows, err := db.Pool.Query(db.Ctx, ` + rows, err := tx.Query(ctx, ` SELECT id, name, context, token, date_create FROM instance.login_token_fixed WHERE login_id = $1 @@ -459,11 +460,11 @@ func GetTokensFixed(loginId int64) ([]types.LoginTokenFixed, error) { } return tokens, nil } -func SetTokenFixed_tx(tx pgx.Tx, loginId int64, name string, context string) (string, error) { +func SetTokenFixed_tx(ctx context.Context, tx pgx.Tx, loginId int64, name string, context string) (string, error) { min, max := 32, 48 tokenFixed := tools.RandStringRunes(rand.Intn(max-min+1) + min) - if _, err := tx.Exec(db.Ctx, ` + if _, err := tx.Exec(ctx, ` INSERT INTO instance.login_token_fixed (login_id,token,name,context,date_create) VALUES ($1,$2,$3,$4,$5) `, loginId, tokenFixed, name, context, tools.GetTimeUnix()); err != nil { diff --git a/login/login_check/login_check.go b/login/login_check/login_check.go index 36cbf199..fb454428 100644 --- a/login/login_check/login_check.go +++ b/login/login_check/login_check.go @@ -1,9 +1,9 @@ package login_check import ( + "context" "fmt" "r3/config" - "r3/db" "r3/tools" "regexp" @@ -11,11 +11,11 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -func Password(tx pgx.Tx, loginId int64, pwOld string) error { +func Password(ctx context.Context, tx pgx.Tx, loginId int64, pwOld string) error { var salt, hash string var ldapId pgtype.Int4 - if err := tx.QueryRow(db.Ctx, ` + if err := tx.QueryRow(ctx, ` SELECT salt, hash, ldap_id FROM instance.login WHERE active diff --git a/login/login_clientEvent/login_clientEvent.go b/login/login_clientEvent/login_clientEvent.go index 19469d14..843f75b1 100644 --- a/login/login_clientEvent/login_clientEvent.go +++ b/login/login_clientEvent/login_clientEvent.go @@ -1,15 +1,15 @@ package login_clientEvent import ( - "r3/db" + "context" "r3/types" "github.com/gofrs/uuid" "github.com/jackc/pgx/v5" ) -func Del_tx(tx pgx.Tx, loginId int64, clientEventId uuid.UUID) error { - _, err := tx.Exec(db.Ctx, ` +func Del_tx(ctx context.Context, tx pgx.Tx, loginId int64, clientEventId uuid.UUID) error { + _, err := tx.Exec(ctx, ` DELETE FROM instance.login_client_event WHERE login_id = $1 AND client_event_id = $2 @@ -17,10 +17,10 @@ func Del_tx(tx pgx.Tx, loginId int64, clientEventId uuid.UUID) error { return err } -func Get(loginId int64) (map[uuid.UUID]types.LoginClientEvent, error) { +func Get_tx(ctx context.Context, tx pgx.Tx, loginId int64) (map[uuid.UUID]types.LoginClientEvent, error) { lceIdMap := make(map[uuid.UUID]types.LoginClientEvent) - rows, err := db.Pool.Query(db.Ctx, ` + rows, err := tx.Query(ctx, ` SELECT client_event_id, hotkey_modifier1, hotkey_modifier2, hotkey_char FROM instance.login_client_event WHERE login_id = $1 @@ -41,10 +41,10 @@ func Get(loginId int64) (map[uuid.UUID]types.LoginClientEvent, error) { return lceIdMap, nil } -func Set_tx(tx pgx.Tx, loginId int64, clientEventId uuid.UUID, lce types.LoginClientEvent) error { +func Set_tx(ctx context.Context, tx pgx.Tx, loginId int64, clientEventId uuid.UUID, lce types.LoginClientEvent) error { exists := false - if err := tx.QueryRow(db.Ctx, ` + if err := tx.QueryRow(ctx, ` SELECT EXISTS( SELECT client_event_id FROM instance.login_client_event @@ -57,14 +57,14 @@ func Set_tx(tx pgx.Tx, loginId int64, clientEventId uuid.UUID, lce types.LoginCl var err error if exists { - _, err = tx.Exec(db.Ctx, ` + _, err = tx.Exec(ctx, ` UPDATE instance.login_client_event SET hotkey_modifier1 = $1, hotkey_modifier2 = $2, hotkey_char = $3 WHERE login_id = $4 AND client_event_id = $5 `, lce.HotkeyModifier1, lce.HotkeyModifier2, lce.HotkeyChar, loginId, clientEventId) } else { - _, err = tx.Exec(db.Ctx, ` + _, err = tx.Exec(ctx, ` INSERT INTO instance.login_client_event ( login_id, client_event_id, hotkey_modifier1, hotkey_modifier2, hotkey_char) VALUES ($1,$2,$3,$4,$5) diff --git a/login/login_keys/login_keys.go b/login/login_keys/login_keys.go index c7035c07..29beae64 100644 --- a/login/login_keys/login_keys.go +++ b/login/login_keys/login_keys.go @@ -86,11 +86,11 @@ func GetPublic(ctx context.Context, relationId uuid.UUID, return keys, nil } -func Reset_tx(tx pgx.Tx, loginId int64) error { +func Reset_tx(ctx context.Context, tx pgx.Tx, loginId int64) error { cache.Schema_mx.RLock() defer cache.Schema_mx.RUnlock() - if _, err := tx.Exec(db.Ctx, ` + if _, err := tx.Exec(ctx, ` UPDATE instance.login SET key_private_enc = NULL, key_private_enc_backup = NULL, key_public = NULL WHERE id = $1 @@ -101,7 +101,7 @@ func Reset_tx(tx pgx.Tx, loginId int64) error { // delete unusable data keys for _, rel := range cache.RelationIdMap { if rel.Encryption { - if _, err := tx.Exec(db.Ctx, fmt.Sprintf(` + if _, err := tx.Exec(ctx, fmt.Sprintf(` DELETE FROM instance_e2ee."%s" WHERE login_id = $1 `, schema.GetEncKeyTableName(rel.Id)), loginId); err != nil { @@ -112,10 +112,10 @@ func Reset_tx(tx pgx.Tx, loginId int64) error { return nil } -func Store_tx(tx pgx.Tx, loginId int64, privateKeyEnc string, +func Store_tx(ctx context.Context, tx pgx.Tx, loginId int64, privateKeyEnc string, privateKeyEncBackup string, publicKey string) error { - _, err := tx.Exec(db.Ctx, ` + _, err := tx.Exec(ctx, ` UPDATE instance.login SET key_private_enc = $1, key_private_enc_backup = $2, key_public = $3 WHERE id = $4 @@ -124,9 +124,9 @@ func Store_tx(tx pgx.Tx, loginId int64, privateKeyEnc string, return err } -func StorePrivate_tx(tx pgx.Tx, loginId int64, privateKeyEnc string) error { +func StorePrivate_tx(ctx context.Context, tx pgx.Tx, loginId int64, privateKeyEnc string) error { - _, err := tx.Exec(db.Ctx, ` + _, err := tx.Exec(ctx, ` UPDATE instance.login SET key_private_enc = $1 WHERE id = $2 diff --git a/login/login_setting/login_setting.go b/login/login_setting/login_setting.go index cea87222..cccab4e8 100644 --- a/login/login_setting/login_setting.go +++ b/login/login_setting/login_setting.go @@ -1,16 +1,16 @@ package login_setting import ( + "context" "errors" "fmt" - "r3/db" "r3/types" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" ) -func Get(loginId pgtype.Int8, loginTemplateId pgtype.Int8) (types.Settings, error) { +func Get_tx(ctx context.Context, tx pgx.Tx, loginId pgtype.Int8, loginTemplateId pgtype.Int8) (types.Settings, error) { var s types.Settings if (loginId.Valid && loginTemplateId.Valid) || (!loginId.Valid && !loginTemplateId.Valid) { @@ -25,7 +25,7 @@ func Get(loginId pgtype.Int8, loginTemplateId pgtype.Int8) (types.Settings, erro entryName = "login_template_id" } - err := db.Pool.QueryRow(db.Ctx, fmt.Sprintf(` + err := tx.QueryRow(ctx, fmt.Sprintf(` SELECT language_code, date_format, sunday_first_dow, font_size, borders_all, borders_squared, header_captions, header_modules, spacing, dark, hint_update_version, mobile_scroll_form, @@ -52,7 +52,7 @@ func Get(loginId pgtype.Int8, loginTemplateId pgtype.Int8) (types.Settings, erro return s, err } -func Set_tx(tx pgx.Tx, loginId pgtype.Int8, loginTemplateId pgtype.Int8, s types.Settings, isNew bool) error { +func Set_tx(ctx context.Context, tx pgx.Tx, loginId pgtype.Int8, loginTemplateId pgtype.Int8, s types.Settings, isNew bool) error { if (loginId.Valid && loginTemplateId.Valid) || (!loginId.Valid && !loginTemplateId.Valid) { return errors.New("settings can only be applied for either login or login template") @@ -67,7 +67,7 @@ func Set_tx(tx pgx.Tx, loginId pgtype.Int8, loginTemplateId pgtype.Int8, s types } if isNew { - if _, err := tx.Exec(db.Ctx, fmt.Sprintf(` + if _, err := tx.Exec(ctx, fmt.Sprintf(` INSERT INTO instance.login_setting (%s, language_code, date_format, sunday_first_dow, font_size, borders_all, borders_squared, header_captions, header_modules, spacing, dark, @@ -88,7 +88,7 @@ func Set_tx(tx pgx.Tx, loginId pgtype.Int8, loginTemplateId pgtype.Int8, s types return err } } else { - if _, err := tx.Exec(db.Ctx, fmt.Sprintf(` + if _, err := tx.Exec(ctx, fmt.Sprintf(` UPDATE instance.login_setting SET language_code = $1, date_format = $2, sunday_first_dow = $3, font_size = $4, borders_all = $5, borders_squared = $6, @@ -115,7 +115,7 @@ func Set_tx(tx pgx.Tx, loginId pgtype.Int8, loginTemplateId pgtype.Int8, s types // update full text search dictionaries if !isNew { - if _, err := tx.Exec(db.Ctx, fmt.Sprintf(` + if _, err := tx.Exec(ctx, fmt.Sprintf(` DELETE FROM instance.login_search_dict WHERE %s = $1 `, entryName), entryId); err != nil { @@ -124,7 +124,7 @@ func Set_tx(tx pgx.Tx, loginId pgtype.Int8, loginTemplateId pgtype.Int8, s types } for i, dictName := range s.SearchDictionaries { - if _, err := tx.Exec(db.Ctx, fmt.Sprintf(` + if _, err := tx.Exec(ctx, fmt.Sprintf(` INSERT INTO instance.login_search_dict (%s, position, name) VALUES ($1, $2, $3) `, entryName), entryId, i, dictName); err != nil { diff --git a/login/login_template/login_template.go b/login/login_template/login_template.go index d2afa0f6..8edae00d 100644 --- a/login/login_template/login_template.go +++ b/login/login_template/login_template.go @@ -1,8 +1,8 @@ package login_template import ( + "context" "fmt" - "r3/db" "r3/login/login_setting" "r3/types" @@ -10,8 +10,8 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -func Del_tx(tx pgx.Tx, id int64) error { - _, err := tx.Exec(db.Ctx, ` +func Del_tx(ctx context.Context, tx pgx.Tx, id int64) error { + _, err := tx.Exec(ctx, ` DELETE FROM instance.login_template WHERE id = $1 AND name <> 'GLOBAL' -- protect global default @@ -19,7 +19,7 @@ func Del_tx(tx pgx.Tx, id int64) error { return err } -func Get(byId int64) ([]types.LoginTemplateAdmin, error) { +func Get_tx(ctx context.Context, tx pgx.Tx, byId int64) ([]types.LoginTemplateAdmin, error) { templates := make([]types.LoginTemplateAdmin, 0) sqlParams := make([]interface{}, 0) @@ -29,7 +29,7 @@ func Get(byId int64) ([]types.LoginTemplateAdmin, error) { sqlWhere = "WHERE id = $1" } - rows, err := db.Pool.Query(db.Ctx, fmt.Sprintf(` + rows, err := tx.Query(ctx, fmt.Sprintf(` SELECT id, name, comment FROM instance.login_template %s @@ -49,7 +49,8 @@ func Get(byId int64) ([]types.LoginTemplateAdmin, error) { rows.Close() for i, _ := range templates { - templates[i].Settings, err = login_setting.Get( + templates[i].Settings, err = login_setting.Get_tx( + ctx, tx, pgtype.Int8{}, pgtype.Int8{Int64: templates[i].Id, Valid: true}) @@ -60,11 +61,11 @@ func Get(byId int64) ([]types.LoginTemplateAdmin, error) { return templates, nil } -func Set_tx(tx pgx.Tx, t types.LoginTemplateAdmin) (int64, error) { +func Set_tx(ctx context.Context, tx pgx.Tx, t types.LoginTemplateAdmin) (int64, error) { isNew := t.Id == 0 if isNew { - if err := tx.QueryRow(db.Ctx, ` + if err := tx.QueryRow(ctx, ` INSERT INTO instance.login_template (name, comment) VALUES ($1,$2) RETURNING id @@ -72,7 +73,7 @@ func Set_tx(tx pgx.Tx, t types.LoginTemplateAdmin) (int64, error) { return t.Id, err } } else { - if _, err := tx.Exec(db.Ctx, ` + if _, err := tx.Exec(ctx, ` UPDATE instance.login_template SET name = $1, comment = $2 WHERE id = $3 @@ -82,7 +83,7 @@ func Set_tx(tx pgx.Tx, t types.LoginTemplateAdmin) (int64, error) { } } - return t.Id, login_setting.Set_tx(tx, + return t.Id, login_setting.Set_tx(ctx, tx, pgtype.Int8{}, pgtype.Int8{Int64: t.Id, Valid: true}, t.Settings, isNew) diff --git a/login/login_widget/login_widget.go b/login/login_widget/login_widget.go index 631b7f3c..2911b4ac 100644 --- a/login/login_widget/login_widget.go +++ b/login/login_widget/login_widget.go @@ -1,7 +1,7 @@ package login_widget import ( - "r3/db" + "context" "r3/types" "github.com/gofrs/uuid" @@ -9,10 +9,10 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -func Get(loginId int64) ([]types.LoginWidgetGroup, error) { +func Get_tx(ctx context.Context, tx pgx.Tx, loginId int64) ([]types.LoginWidgetGroup, error) { groups := make([]types.LoginWidgetGroup, 0) - rows, err := db.Pool.Query(db.Ctx, ` + rows, err := tx.Query(ctx, ` SELECT g.id, g.title, w.widget_id, w.module_id, w.content FROM instance.login_widget_group AS g LEFT JOIN instance.login_widget_group_item AS w ON w.login_widget_group_id = g.id @@ -61,9 +61,9 @@ func Get(loginId int64) ([]types.LoginWidgetGroup, error) { return groups, nil } -func Set_tx(tx pgx.Tx, loginId int64, groups []types.LoginWidgetGroup) error { +func Set_tx(ctx context.Context, tx pgx.Tx, loginId int64, groups []types.LoginWidgetGroup) error { - if _, err := tx.Exec(db.Ctx, ` + if _, err := tx.Exec(ctx, ` DELETE FROM instance.login_widget_group WHERE login_id = $1 `, loginId); err != nil { @@ -73,7 +73,7 @@ func Set_tx(tx pgx.Tx, loginId int64, groups []types.LoginWidgetGroup) error { for posGroup, g := range groups { var groupId uuid.UUID - if err := tx.QueryRow(db.Ctx, ` + if err := tx.QueryRow(ctx, ` INSERT INTO instance.login_widget_group (login_id, title, position) VALUES ($1,$2,$3) RETURNING id @@ -82,7 +82,7 @@ func Set_tx(tx pgx.Tx, loginId int64, groups []types.LoginWidgetGroup) error { } for posItem, w := range g.Items { - if _, err := tx.Exec(db.Ctx, ` + if _, err := tx.Exec(ctx, ` INSERT INTO instance.login_widget_group_item ( login_widget_group_id, position, widget_id, module_id, content) VALUES ($1,$2,$3,$4,$5) diff --git a/request/request.go b/request/request.go index e0732a73..40811f00 100644 --- a/request/request.go +++ b/request/request.go @@ -12,22 +12,15 @@ import ( "r3/handler" "r3/log" "r3/types" - "time" "github.com/jackc/pgx/v5" ) // executes a websocket transaction with multiple requests within a single DB transaction -func ExecTransaction(ctxClient context.Context, address string, loginId int64, isAdmin bool, +func ExecTransaction(ctx context.Context, address string, loginId int64, isAdmin bool, device types.WebsocketClientDevice, isNoAuth bool, reqTrans types.RequestTransaction, resTrans types.ResponseTransaction) types.ResponseTransaction { - // start transaction - ctx, ctxCancel := context.WithTimeout(ctxClient, - time.Duration(int64(config.GetUint64("dbTimeoutDataWs")))*time.Second) - - defer ctxCancel() - // run in a loop as there is an error case where it needs to be repeated runAgainNewCache := false for runOnce := true; runOnce || runAgainNewCache; runOnce = false { @@ -43,6 +36,7 @@ func ExecTransaction(ctxClient context.Context, address string, loginId int64, i if err := tx.Conn().DeallocateAll(ctx); err != nil { log.Error("websocket", "failed to deallocate DB connection", err) resTrans.Error = handler.ErrGeneral + tx.Rollback(ctx) return resTrans } runAgainNewCache = false @@ -52,6 +46,8 @@ func ExecTransaction(ctxClient context.Context, address string, loginId int64, i // set local transaction configuration parameters // these are used by system functions, such as instance.get_login_id() if err := db.SetSessionConfig_tx(ctx, tx, loginId); err != nil { + tx.Rollback(ctx) + log.Error("websocket", fmt.Sprintf("TRANSACTION %d, transaction config failure (login ID %d)", reqTrans.TransactionNr, loginId), err) @@ -140,9 +136,9 @@ func Exec_tx(ctx context.Context, tx pgx.Tx, address string, loginId int64, isAd case "clientEvent": switch action { case "exec": - return clientEventExecFatClient(reqJson, loginId, address) + return clientEventExecFatClient_tx(ctx, tx, reqJson, loginId, address) case "get": - return clientEventGetFatClient(loginId) + return clientEventGetFatClient_tx(ctx, tx, loginId) } } return nil, errors.New(handler.ErrUnauthorized) @@ -179,7 +175,7 @@ func Exec_tx(ctx context.Context, tx pgx.Tx, address string, loginId int64, isAd case "feedback": switch action { case "send": - return FeedbackSend_tx(tx, reqJson) + return FeedbackSend(reqJson) } case "file": switch action { @@ -189,33 +185,33 @@ func Exec_tx(ctx context.Context, tx pgx.Tx, address string, loginId int64, isAd case "login": switch action { case "getNames": - return LoginGetNames(reqJson) + return LoginGetNames_tx(ctx, tx, reqJson) case "delTokenFixed": - return LoginDelTokenFixed(reqJson, loginId) + return LoginDelTokenFixed_tx(ctx, tx, reqJson, loginId) case "getTokensFixed": - return LoginGetTokensFixed(loginId) + return LoginGetTokensFixed_tx(ctx, tx, loginId) case "setTokenFixed": - return LoginSetTokenFixed_tx(tx, reqJson, loginId) + return LoginSetTokenFixed_tx(ctx, tx, reqJson, loginId) } case "loginClientEvent": switch action { case "del": - return loginClientEventDel_tx(tx, reqJson, loginId) + return loginClientEventDel_tx(ctx, tx, reqJson, loginId) case "get": - return loginClientEventGet(loginId) + return loginClientEventGet_tx(ctx, tx, loginId) case "set": - return loginClientEventSet_tx(tx, reqJson, loginId) + return loginClientEventSet_tx(ctx, tx, reqJson, loginId) } case "loginKeys": switch action { case "getPublic": return LoginKeysGetPublic(ctx, reqJson) case "reset": - return LoginKeysReset_tx(tx, loginId) + return LoginKeysReset_tx(ctx, tx, loginId) case "store": - return LoginKeysStore_tx(tx, reqJson, loginId) + return LoginKeysStore_tx(ctx, tx, reqJson, loginId) case "storePrivate": - return LoginKeysStorePrivate_tx(tx, reqJson, loginId) + return LoginKeysStorePrivate_tx(ctx, tx, reqJson, loginId) } case "loginPassword": switch action { @@ -223,34 +219,34 @@ func Exec_tx(ctx context.Context, tx pgx.Tx, address string, loginId int64, isAd if isNoAuth { return nil, errors.New(handler.ErrUnauthorized) } - return loginPasswortSet_tx(tx, reqJson, loginId) + return loginPasswortSet_tx(ctx, tx, reqJson, loginId) } case "loginSetting": switch action { case "get": - return LoginSettingsGet(loginId) + return LoginSettingsGet(ctx, tx, loginId) case "set": if isNoAuth { return nil, errors.New(handler.ErrUnauthorized) } - return LoginSettingsSet_tx(tx, reqJson, loginId) + return LoginSettingsSet_tx(ctx, tx, reqJson, loginId) } case "loginWidgetGroups": switch action { case "get": - return LoginWidgetGroupsGet(loginId) + return LoginWidgetGroupsGet_tx(ctx, tx, loginId) case "set": - return LoginWidgetGroupsSet_tx(tx, reqJson, loginId) + return LoginWidgetGroupsSet_tx(ctx, tx, reqJson, loginId) } case "lookup": switch action { case "get": - return lookupGet(reqJson, loginId) + return lookupGet_tx(ctx, tx, reqJson, loginId) } case "pgFunction": switch action { case "exec": // user may exec non-trigger backend function, available to frontend - return PgFunctionExec_tx(tx, reqJson, true) + return PgFunctionExec_tx(ctx, tx, reqJson, true) } } @@ -448,11 +444,11 @@ func Exec_tx(ctx context.Context, tx pgx.Tx, address string, loginId int64, isAd case "loginTemplate": switch action { case "del": - return LoginTemplateDel_tx(tx, reqJson) + return LoginTemplateDel_tx(ctx, tx, reqJson) case "get": - return LoginTemplateGet(reqJson) + return LoginTemplateGet_tx(ctx, tx, reqJson) case "set": - return LoginTemplateSet_tx(tx, reqJson) + return LoginTemplateSet_tx(ctx, tx, reqJson) } case "mailAccount": switch action { @@ -525,20 +521,18 @@ func Exec_tx(ctx context.Context, tx pgx.Tx, address string, loginId int64, isAd case "pgFunction": switch action { case "del": - return PgFunctionDel_tx(tx, reqJson) + return PgFunctionDel_tx(ctx, tx, reqJson) case "execAny": // admin may exec any non-trigger backend function - return PgFunctionExec_tx(tx, reqJson, false) + return PgFunctionExec_tx(ctx, tx, reqJson, false) case "set": - return PgFunctionSet_tx(tx, reqJson) + return PgFunctionSet_tx(ctx, tx, reqJson) } case "pgIndex": switch action { case "del": - return PgIndexDel_tx(tx, reqJson) - case "get": - return PgIndexGet(reqJson) + return PgIndexDel_tx(ctx, tx, reqJson) case "set": - return PgIndexSet_tx(tx, reqJson) + return PgIndexSet_tx(ctx, tx, reqJson) } case "pgTrigger": switch action { diff --git a/request/request_clientEvent.go b/request/request_clientEvent.go index 5ed51bc0..66dc7741 100644 --- a/request/request_clientEvent.go +++ b/request/request_clientEvent.go @@ -1,11 +1,11 @@ package request import ( + "context" "encoding/json" "fmt" "r3/cache" "r3/cluster" - "r3/db" "r3/handler" "r3/login/login_clientEvent" "r3/schema/clientEvent" @@ -37,7 +37,7 @@ func clientEventSet_tx(tx pgx.Tx, reqJson json.RawMessage) (interface{}, error) } // fat client requests -func clientEventGetFatClient(loginId int64) (interface{}, error) { +func clientEventGetFatClient_tx(ctx context.Context, tx pgx.Tx, loginId int64) (interface{}, error) { var err error var res struct { @@ -48,7 +48,7 @@ func clientEventGetFatClient(loginId int64) (interface{}, error) { res.ClientEventIdMapLogin = make(map[uuid.UUID]types.LoginClientEvent) // collect login client events for login (currently only used to enable and overwrite hotkeys) - res.ClientEventIdMapLogin, err = login_clientEvent.Get(loginId) + res.ClientEventIdMapLogin, err = login_clientEvent.Get_tx(ctx, tx, loginId) if err != nil { return nil, err } @@ -74,7 +74,7 @@ func clientEventGetFatClient(loginId int64) (interface{}, error) { cache.Schema_mx.RUnlock() return res, nil } -func clientEventExecFatClient(reqJson json.RawMessage, loginId int64, address string) (interface{}, error) { +func clientEventExecFatClient_tx(ctx context.Context, tx pgx.Tx, reqJson json.RawMessage, loginId int64, address string) (interface{}, error) { var req struct { Id uuid.UUID `json:"id"` @@ -119,7 +119,7 @@ func clientEventExecFatClient(reqJson json.RawMessage, loginId int64, address st } var returnIf interface{} - err := db.Pool.QueryRow(db.Ctx, fmt.Sprintf(`SELECT "%s"."%s"(%s)`, mod.Name, fnc.Name, strings.Join(placeholders, ",")), + err := tx.QueryRow(ctx, fmt.Sprintf(`SELECT "%s"."%s"(%s)`, mod.Name, fnc.Name, strings.Join(placeholders, ",")), req.Arguments...).Scan(&returnIf) return nil, err diff --git a/request/request_feedback.go b/request/request_feedback.go index d1649b43..f6cf0402 100644 --- a/request/request_feedback.go +++ b/request/request_feedback.go @@ -5,10 +5,9 @@ import ( "r3/repo" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5" ) -func FeedbackSend_tx(tx pgx.Tx, reqJson json.RawMessage) (interface{}, error) { +func FeedbackSend(reqJson json.RawMessage) (interface{}, error) { var req struct { Code int `json:"code"` diff --git a/request/request_login.go b/request/request_login.go index c18719bd..fe7c84bf 100644 --- a/request/request_login.go +++ b/request/request_login.go @@ -1,6 +1,7 @@ package request import ( + "context" "encoding/base32" "encoding/json" "r3/cluster" @@ -13,7 +14,7 @@ import ( ) // user requests -func LoginGetNames(reqJson json.RawMessage) (interface{}, error) { +func LoginGetNames_tx(ctx context.Context, tx pgx.Tx, reqJson json.RawMessage) (interface{}, error) { var req struct { ByString string `json:"byString"` @@ -25,21 +26,21 @@ func LoginGetNames(reqJson json.RawMessage) (interface{}, error) { if err := json.Unmarshal(reqJson, &req); err != nil { return nil, err } - return login.GetNames(req.Id, req.IdsExclude, req.ByString, req.NoLdapAssign) + return login.GetNames_tx(ctx, tx, req.Id, req.IdsExclude, req.ByString, req.NoLdapAssign) } -func LoginDelTokenFixed(reqJson json.RawMessage, loginId int64) (interface{}, error) { +func LoginDelTokenFixed_tx(ctx context.Context, tx pgx.Tx, reqJson json.RawMessage, loginId int64) (interface{}, error) { var req struct { Id int64 `json:"id"` } if err := json.Unmarshal(reqJson, &req); err != nil { return nil, err } - return nil, login.DelTokenFixed(loginId, req.Id) + return nil, login.DelTokenFixed_tx(ctx, tx, loginId, req.Id) } -func LoginGetTokensFixed(loginId int64) (interface{}, error) { - return login.GetTokensFixed(loginId) +func LoginGetTokensFixed_tx(ctx context.Context, tx pgx.Tx, loginId int64) (interface{}, error) { + return login.GetTokensFixed_tx(ctx, tx, loginId) } -func LoginSetTokenFixed_tx(tx pgx.Tx, reqJson json.RawMessage, loginId int64) (interface{}, error) { +func LoginSetTokenFixed_tx(ctx context.Context, tx pgx.Tx, reqJson json.RawMessage, loginId int64) (interface{}, error) { var ( err error @@ -56,7 +57,7 @@ func LoginSetTokenFixed_tx(tx pgx.Tx, reqJson json.RawMessage, loginId int64) (i if err := json.Unmarshal(reqJson, &req); err != nil { return nil, err } - res.TokenFixed, err = login.SetTokenFixed_tx(tx, loginId, req.Name, req.Context) + res.TokenFixed, err = login.SetTokenFixed_tx(ctx, tx, loginId, req.Name, req.Context) res.TokenFixedB32 = base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString([]byte(res.TokenFixed)) return res, err diff --git a/request/request_login_clientEvent.go b/request/request_login_clientEvent.go index 0a900b61..7ca51808 100644 --- a/request/request_login_clientEvent.go +++ b/request/request_login_clientEvent.go @@ -1,6 +1,7 @@ package request import ( + "context" "encoding/json" "r3/login/login_clientEvent" "r3/types" @@ -9,21 +10,21 @@ import ( "github.com/jackc/pgx/v5" ) -func loginClientEventDel_tx(tx pgx.Tx, reqJson json.RawMessage, loginId int64) (interface{}, error) { +func loginClientEventDel_tx(ctx context.Context, tx pgx.Tx, reqJson json.RawMessage, loginId int64) (interface{}, error) { var req struct { ClientEventId uuid.UUID `json:"clientEventId"` } if err := json.Unmarshal(reqJson, &req); err != nil { return nil, err } - return nil, login_clientEvent.Del_tx(tx, loginId, req.ClientEventId) + return nil, login_clientEvent.Del_tx(ctx, tx, loginId, req.ClientEventId) } -func loginClientEventGet(loginId int64) (interface{}, error) { - return login_clientEvent.Get(loginId) +func loginClientEventGet_tx(ctx context.Context, tx pgx.Tx, loginId int64) (interface{}, error) { + return login_clientEvent.Get_tx(ctx, tx, loginId) } -func loginClientEventSet_tx(tx pgx.Tx, reqJson json.RawMessage, loginId int64) (interface{}, error) { +func loginClientEventSet_tx(ctx context.Context, tx pgx.Tx, reqJson json.RawMessage, loginId int64) (interface{}, error) { var req struct { ClientEventId uuid.UUID `json:"clientEventId"` LoginClientEvent types.LoginClientEvent `json:"loginClientEvent"` @@ -31,5 +32,5 @@ func loginClientEventSet_tx(tx pgx.Tx, reqJson json.RawMessage, loginId int64) ( if err := json.Unmarshal(reqJson, &req); err != nil { return nil, err } - return nil, login_clientEvent.Set_tx(tx, loginId, req.ClientEventId, req.LoginClientEvent) + return nil, login_clientEvent.Set_tx(ctx, tx, loginId, req.ClientEventId, req.LoginClientEvent) } diff --git a/request/request_login_keys.go b/request/request_login_keys.go index e16c8145..fad56dec 100644 --- a/request/request_login_keys.go +++ b/request/request_login_keys.go @@ -23,11 +23,11 @@ func LoginKeysGetPublic(ctx context.Context, reqJson json.RawMessage) (interface return login_keys.GetPublic(ctx, req.RelationId, req.RecordIds, req.LoginIds) } -func LoginKeysReset_tx(tx pgx.Tx, loginId int64) (interface{}, error) { - return nil, login_keys.Reset_tx(tx, loginId) +func LoginKeysReset_tx(ctx context.Context, tx pgx.Tx, loginId int64) (interface{}, error) { + return nil, login_keys.Reset_tx(ctx, tx, loginId) } -func LoginKeysStore_tx(tx pgx.Tx, reqJson json.RawMessage, loginId int64) (interface{}, error) { +func LoginKeysStore_tx(ctx context.Context, tx pgx.Tx, reqJson json.RawMessage, loginId int64) (interface{}, error) { var req struct { PrivateKeyEnc string `json:"privateKeyEnc"` @@ -38,11 +38,11 @@ func LoginKeysStore_tx(tx pgx.Tx, reqJson json.RawMessage, loginId int64) (inter if err := json.Unmarshal(reqJson, &req); err != nil { return nil, err } - return nil, login_keys.Store_tx(tx, loginId, + return nil, login_keys.Store_tx(ctx, tx, loginId, req.PrivateKeyEnc, req.PrivateKeyEncBackup, req.PublicKey) } -func LoginKeysStorePrivate_tx(tx pgx.Tx, reqJson json.RawMessage, loginId int64) (interface{}, error) { +func LoginKeysStorePrivate_tx(ctx context.Context, tx pgx.Tx, reqJson json.RawMessage, loginId int64) (interface{}, error) { var req struct { PrivateKeyEnc string `json:"privateKeyEnc"` @@ -51,5 +51,5 @@ func LoginKeysStorePrivate_tx(tx pgx.Tx, reqJson json.RawMessage, loginId int64) if err := json.Unmarshal(reqJson, &req); err != nil { return nil, err } - return nil, login_keys.StorePrivate_tx(tx, loginId, req.PrivateKeyEnc) + return nil, login_keys.StorePrivate_tx(ctx, tx, loginId, req.PrivateKeyEnc) } diff --git a/request/request_login_password.go b/request/request_login_password.go index 0e4d493d..6afb25bd 100644 --- a/request/request_login_password.go +++ b/request/request_login_password.go @@ -1,6 +1,7 @@ package request import ( + "context" "encoding/json" "fmt" "r3/login" @@ -9,7 +10,7 @@ import ( "github.com/jackc/pgx/v5" ) -func loginPasswortSet_tx(tx pgx.Tx, reqJson json.RawMessage, loginId int64) (interface{}, error) { +func loginPasswortSet_tx(ctx context.Context, tx pgx.Tx, reqJson json.RawMessage, loginId int64) (interface{}, error) { var req struct { PwNew0 string `json:"pwNew0"` @@ -24,7 +25,7 @@ func loginPasswortSet_tx(tx pgx.Tx, reqJson json.RawMessage, loginId int64) (int return nil, fmt.Errorf("invalid input") } - if err := login_check.Password(tx, loginId, req.PwOld); err != nil { + if err := login_check.Password(ctx, tx, loginId, req.PwOld); err != nil { return nil, err } if err := login_check.PasswordComplexity(req.PwNew0); err != nil { @@ -32,5 +33,5 @@ func loginPasswortSet_tx(tx pgx.Tx, reqJson json.RawMessage, loginId int64) (int } salt, hash := login.GenerateSaltHash(req.PwNew0) - return nil, login.SetSaltHash_tx(tx, salt, hash, loginId) + return nil, login.SetSaltHash_tx(ctx, tx, salt, hash, loginId) } diff --git a/request/request_login_setting.go b/request/request_login_setting.go index 29932a1c..ea8626af 100644 --- a/request/request_login_setting.go +++ b/request/request_login_setting.go @@ -1,6 +1,7 @@ package request import ( + "context" "encoding/json" "r3/login/login_setting" "r3/types" @@ -9,18 +10,18 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -func LoginSettingsGet(loginId int64) (interface{}, error) { - return login_setting.Get( +func LoginSettingsGet(ctx context.Context, tx pgx.Tx, loginId int64) (interface{}, error) { + return login_setting.Get_tx(ctx, tx, pgtype.Int8{Int64: loginId, Valid: true}, pgtype.Int8{}) } -func LoginSettingsSet_tx(tx pgx.Tx, reqJson json.RawMessage, loginId int64) (interface{}, error) { +func LoginSettingsSet_tx(ctx context.Context, tx pgx.Tx, reqJson json.RawMessage, loginId int64) (interface{}, error) { var req types.Settings if err := json.Unmarshal(reqJson, &req); err != nil { return nil, err } - return nil, login_setting.Set_tx(tx, + return nil, login_setting.Set_tx(ctx, tx, pgtype.Int8{Int64: loginId, Valid: true}, pgtype.Int8{}, req, false) diff --git a/request/request_login_template.go b/request/request_login_template.go index b3eeb38d..f9e20166 100644 --- a/request/request_login_template.go +++ b/request/request_login_template.go @@ -1,6 +1,7 @@ package request import ( + "context" "encoding/json" "r3/login/login_template" "r3/types" @@ -8,7 +9,7 @@ import ( "github.com/jackc/pgx/v5" ) -func LoginTemplateDel_tx(tx pgx.Tx, reqJson json.RawMessage) (interface{}, error) { +func LoginTemplateDel_tx(ctx context.Context, tx pgx.Tx, reqJson json.RawMessage) (interface{}, error) { var req struct { Id int64 `json:"id"` } @@ -16,9 +17,9 @@ func LoginTemplateDel_tx(tx pgx.Tx, reqJson json.RawMessage) (interface{}, error if err := json.Unmarshal(reqJson, &req); err != nil { return nil, err } - return nil, login_template.Del_tx(tx, req.Id) + return nil, login_template.Del_tx(ctx, tx, req.Id) } -func LoginTemplateGet(reqJson json.RawMessage) (interface{}, error) { +func LoginTemplateGet_tx(ctx context.Context, tx pgx.Tx, reqJson json.RawMessage) (interface{}, error) { var req struct { ById int64 `json:"byId"` } @@ -26,13 +27,13 @@ func LoginTemplateGet(reqJson json.RawMessage) (interface{}, error) { if err := json.Unmarshal(reqJson, &req); err != nil { return nil, err } - return login_template.Get(req.ById) + return login_template.Get_tx(ctx, tx, req.ById) } -func LoginTemplateSet_tx(tx pgx.Tx, reqJson json.RawMessage) (interface{}, error) { +func LoginTemplateSet_tx(ctx context.Context, tx pgx.Tx, reqJson json.RawMessage) (interface{}, error) { var req types.LoginTemplateAdmin if err := json.Unmarshal(reqJson, &req); err != nil { return nil, err } - return login_template.Set_tx(tx, req) + return login_template.Set_tx(ctx, tx, req) } diff --git a/request/request_login_widgets.go b/request/request_login_widgets.go index 07c46e4a..3884fd2b 100644 --- a/request/request_login_widgets.go +++ b/request/request_login_widgets.go @@ -1,6 +1,7 @@ package request import ( + "context" "encoding/json" "r3/login/login_widget" "r3/types" @@ -8,14 +9,14 @@ import ( "github.com/jackc/pgx/v5" ) -func LoginWidgetGroupsGet(loginId int64) (interface{}, error) { - return login_widget.Get(loginId) +func LoginWidgetGroupsGet_tx(ctx context.Context, tx pgx.Tx, loginId int64) (interface{}, error) { + return login_widget.Get_tx(ctx, tx, loginId) } -func LoginWidgetGroupsSet_tx(tx pgx.Tx, reqJson json.RawMessage, loginId int64) (interface{}, error) { +func LoginWidgetGroupsSet_tx(ctx context.Context, tx pgx.Tx, reqJson json.RawMessage, loginId int64) (interface{}, error) { var req []types.LoginWidgetGroup if err := json.Unmarshal(reqJson, &req); err != nil { return nil, err } - return nil, login_widget.Set_tx(tx, loginId, req) + return nil, login_widget.Set_tx(ctx, tx, loginId, req) } diff --git a/request/request_lookups.go b/request/request_lookups.go index edb4dd5a..fa90f6bc 100644 --- a/request/request_lookups.go +++ b/request/request_lookups.go @@ -1,16 +1,17 @@ package request import ( + "context" "encoding/json" "fmt" "r3/cache" "r3/config" - "r3/db" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" ) -func lookupGet(reqJson json.RawMessage, loginId int64) (interface{}, error) { +func lookupGet_tx(ctx context.Context, tx pgx.Tx, reqJson json.RawMessage, loginId int64) (interface{}, error) { var req struct { Name string `json:"name"` @@ -34,7 +35,7 @@ func lookupGet(reqJson json.RawMessage, loginId int64) (interface{}, error) { case "loginHasClient": var hasClient bool - err := db.Pool.QueryRow(db.Ctx, ` + err := tx.QueryRow(ctx, ` SELECT EXISTS( SELECT * FROM instance.login_token_fixed @@ -53,7 +54,7 @@ func lookupGet(reqJson json.RawMessage, loginId int64) (interface{}, error) { Public pgtype.Text `json:"public"` } - err := db.Pool.QueryRow(db.Ctx, ` + err := tx.QueryRow(ctx, ` SELECT key_private_enc, key_private_enc_backup, key_public FROM instance.login WHERE id = $1 diff --git a/request/request_pgFunction.go b/request/request_pgFunction.go index d578c4ab..b95295e9 100644 --- a/request/request_pgFunction.go +++ b/request/request_pgFunction.go @@ -1,10 +1,10 @@ package request import ( + "context" "encoding/json" "fmt" "r3/cache" - "r3/db" "r3/handler" "r3/schema/pgFunction" "r3/types" @@ -14,7 +14,7 @@ import ( "github.com/jackc/pgx/v5" ) -func PgFunctionDel_tx(tx pgx.Tx, reqJson json.RawMessage) (interface{}, error) { +func PgFunctionDel_tx(ctx context.Context, tx pgx.Tx, reqJson json.RawMessage) (interface{}, error) { var req struct { Id uuid.UUID `json:"id"` @@ -23,10 +23,10 @@ func PgFunctionDel_tx(tx pgx.Tx, reqJson json.RawMessage) (interface{}, error) { if err := json.Unmarshal(reqJson, &req); err != nil { return nil, err } - return nil, pgFunction.Del_tx(tx, req.Id) + return nil, pgFunction.Del_tx(ctx, tx, req.Id) } -func PgFunctionExec_tx(tx pgx.Tx, reqJson json.RawMessage, onlyFrontendFnc bool) (interface{}, error) { +func PgFunctionExec_tx(ctx context.Context, tx pgx.Tx, reqJson json.RawMessage, onlyFrontendFnc bool) (interface{}, error) { cache.Schema_mx.RLock() defer cache.Schema_mx.RUnlock() @@ -58,7 +58,7 @@ func PgFunctionExec_tx(tx pgx.Tx, reqJson json.RawMessage, onlyFrontendFnc bool) } var returnIf interface{} - if err := tx.QueryRow(db.Ctx, fmt.Sprintf(` + if err := tx.QueryRow(ctx, fmt.Sprintf(` SELECT "%s"."%s"(%s) `, mod.Name, fnc.Name, strings.Join(placeholders, ",")), req.Args...).Scan(&returnIf); err != nil { @@ -68,12 +68,12 @@ func PgFunctionExec_tx(tx pgx.Tx, reqJson json.RawMessage, onlyFrontendFnc bool) return returnIf, nil } -func PgFunctionSet_tx(tx pgx.Tx, reqJson json.RawMessage) (interface{}, error) { +func PgFunctionSet_tx(ctx context.Context, tx pgx.Tx, reqJson json.RawMessage) (interface{}, error) { var req types.PgFunction if err := json.Unmarshal(reqJson, &req); err != nil { return nil, err } - return nil, pgFunction.Set_tx(tx, req) + return nil, pgFunction.Set_tx(ctx, tx, req) } diff --git a/request/request_pgIndex.go b/request/request_pgIndex.go index 396a8258..15520ee8 100644 --- a/request/request_pgIndex.go +++ b/request/request_pgIndex.go @@ -1,6 +1,7 @@ package request import ( + "context" "encoding/json" "r3/schema/pgIndex" "r3/types" @@ -9,27 +10,17 @@ import ( "github.com/jackc/pgx/v5" ) -func PgIndexDel_tx(tx pgx.Tx, reqJson json.RawMessage) (interface{}, error) { +func PgIndexDel_tx(ctx context.Context, tx pgx.Tx, reqJson json.RawMessage) (interface{}, error) { var req struct { Id uuid.UUID `json:"id"` } if err := json.Unmarshal(reqJson, &req); err != nil { return nil, err } - return nil, pgIndex.Del_tx(tx, req.Id) + return nil, pgIndex.Del_tx(ctx, tx, req.Id) } -func PgIndexGet(reqJson json.RawMessage) (interface{}, error) { - var req struct { - RelationId uuid.UUID `json:"relationId"` - } - if err := json.Unmarshal(reqJson, &req); err != nil { - return nil, err - } - return pgIndex.Get(req.RelationId) -} - -func PgIndexSet_tx(tx pgx.Tx, reqJson json.RawMessage) (interface{}, error) { +func PgIndexSet_tx(ctx context.Context, tx pgx.Tx, reqJson json.RawMessage) (interface{}, error) { var req types.PgIndex if err := json.Unmarshal(reqJson, &req); err != nil { @@ -39,5 +30,5 @@ func PgIndexSet_tx(tx pgx.Tx, reqJson json.RawMessage) (interface{}, error) { req.AutoFki = false req.PrimaryKey = false - return nil, pgIndex.Set_tx(tx, req) + return nil, pgIndex.Set_tx(ctx, tx, req) } diff --git a/schema/attribute/attribute.go b/schema/attribute/attribute.go index 908639f5..d9f3fc75 100644 --- a/schema/attribute/attribute.go +++ b/schema/attribute/attribute.go @@ -37,7 +37,7 @@ func Del_tx(tx pgx.Tx, id uuid.UUID) error { // delete FK index if relationship attribute if schema.IsContentRelationship(content) { - if err := pgIndex.DelAutoFkiForAttribute_tx(tx, id); err != nil { + if err := pgIndex.DelAutoFkiForAttribute_tx(db.Ctx, tx, id); err != nil { return err } } @@ -259,10 +259,10 @@ func Set_tx(tx pgx.Tx, atr types.Attribute) error { // rebuild foreign key index if content changed (as in 1:1 -> n:1) // this also adds/removes unique constraint, if required if atr.Content != contentEx { - if err := pgIndex.DelAutoFkiForAttribute_tx(tx, atr.Id); err != nil { + if err := pgIndex.DelAutoFkiForAttribute_tx(db.Ctx, tx, atr.Id); err != nil { return err } - if err := pgIndex.SetAutoFkiForAttribute_tx(tx, atr.RelationId, atr.Id, (atr.Content == "1:1")); err != nil { + if err := pgIndex.SetAutoFkiForAttribute_tx(db.Ctx, tx, atr.RelationId, atr.Id, (atr.Content == "1:1")); err != nil { return err } } @@ -420,7 +420,7 @@ func Set_tx(tx pgx.Tx, atr types.Attribute) error { // create PK PG index reference for new attributes if isNew { - if err := pgIndex.SetPrimaryKeyForAttribute_tx(tx, atr.RelationId, atr.Id); err != nil { + if err := pgIndex.SetPrimaryKeyForAttribute_tx(db.Ctx, tx, atr.RelationId, atr.Id); err != nil { return err } } @@ -468,7 +468,7 @@ func Set_tx(tx pgx.Tx, atr types.Attribute) error { } if isNew { // add automatic FK index for new attributes - if err := pgIndex.SetAutoFkiForAttribute_tx(tx, atr.RelationId, + if err := pgIndex.SetAutoFkiForAttribute_tx(db.Ctx, tx, atr.RelationId, atr.Id, (atr.Content == "1:1")); err != nil { return err diff --git a/schema/pgFunction/pgFunction.go b/schema/pgFunction/pgFunction.go index 3d6e8f6f..e89041ea 100644 --- a/schema/pgFunction/pgFunction.go +++ b/schema/pgFunction/pgFunction.go @@ -1,6 +1,7 @@ package pgFunction import ( + "context" "errors" "fmt" "r3/db" @@ -17,20 +18,20 @@ import ( "github.com/jackc/pgx/v5" ) -func Del_tx(tx pgx.Tx, id uuid.UUID) error { +func Del_tx(ctx context.Context, tx pgx.Tx, id uuid.UUID) error { nameMod, nameEx, _, _, err := schema.GetPgFunctionDetailsById_tx(tx, id) if err != nil { return err } - if _, err := tx.Exec(db.Ctx, fmt.Sprintf(` + if _, err := tx.Exec(ctx, fmt.Sprintf(` DROP FUNCTION "%s"."%s" `, nameMod, nameEx)); err != nil { return err } - if _, err := tx.Exec(db.Ctx, ` + if _, err := tx.Exec(ctx, ` DELETE FROM app.pg_function WHERE id = $1 `, id); err != nil { @@ -126,7 +127,7 @@ func getSchedules_tx(tx pgx.Tx, pgFunctionId uuid.UUID) ([]types.PgFunctionSched return schedules, nil } -func Set_tx(tx pgx.Tx, fnc types.PgFunction) error { +func Set_tx(ctx context.Context, tx pgx.Tx, fnc types.PgFunction) error { if err := check.DbIdentifier(fnc.Name); err != nil { return err @@ -174,7 +175,7 @@ func Set_tx(tx pgx.Tx, fnc types.PgFunction) error { return errors.New("cannot convert between trigger and non-trigger function") } - if _, err := tx.Exec(db.Ctx, ` + if _, err := tx.Exec(ctx, ` UPDATE app.pg_function SET name = $1, code_args = $2, code_function = $3, code_returns = $4, is_frontend_exec = $5, volatility = $6 @@ -192,7 +193,7 @@ func Set_tx(tx pgx.Tx, fnc types.PgFunction) error { if !fnc.IsTrigger { // drop non-trigger function because function arguments can change // two functions with the same name but different interfaces can exist (overloading) - if _, err := tx.Exec(db.Ctx, fmt.Sprintf(`DROP FUNCTION "%s"."%s"`, nameMod, nameEx)); err != nil { + if _, err := tx.Exec(ctx, fmt.Sprintf(`DROP FUNCTION "%s"."%s"`, nameMod, nameEx)); err != nil { return err } } else { @@ -201,7 +202,7 @@ func Set_tx(tx pgx.Tx, fnc types.PgFunction) error { // we cannot drop trigger functions without recreating triggers // renaming changes the function name in the trigger and allows us to replace it // as triggers do not take arguments, overloading is not a problem - if _, err := tx.Exec(db.Ctx, fmt.Sprintf(` + if _, err := tx.Exec(ctx, fmt.Sprintf(` ALTER FUNCTION "%s"."%s" RENAME TO "%s" `, nameMod, nameEx, fnc.Name)); err != nil { return err @@ -209,7 +210,7 @@ func Set_tx(tx pgx.Tx, fnc types.PgFunction) error { } } } else { - if _, err := tx.Exec(db.Ctx, ` + if _, err := tx.Exec(ctx, ` INSERT INTO app.pg_function (id, module_id, name, code_args, code_function, code_returns, is_frontend_exec, is_login_sync, is_trigger, volatility) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10) @@ -233,7 +234,7 @@ func Set_tx(tx pgx.Tx, fnc types.PgFunction) error { s.AtDay = schema.GetValidAtDay(s.IntervalType, s.AtDay) if known { - if _, err := tx.Exec(db.Ctx, ` + if _, err := tx.Exec(ctx, ` UPDATE app.pg_function_schedule SET at_second = $1, at_minute = $2, at_hour = $3, at_day = $4, interval_type = $5, interval_value = $6 @@ -244,7 +245,7 @@ func Set_tx(tx pgx.Tx, fnc types.PgFunction) error { return err } } else { - if _, err := tx.Exec(db.Ctx, ` + if _, err := tx.Exec(ctx, ` INSERT INTO app.pg_function_schedule ( id, pg_function_id, at_second, at_minute, at_hour, at_day, interval_type, interval_value @@ -255,7 +256,7 @@ func Set_tx(tx pgx.Tx, fnc types.PgFunction) error { return err } - if _, err := tx.Exec(db.Ctx, ` + if _, err := tx.Exec(ctx, ` INSERT INTO instance.schedule ( pg_function_schedule_id,date_attempt,date_success ) @@ -267,7 +268,7 @@ func Set_tx(tx pgx.Tx, fnc types.PgFunction) error { scheduleIds = append(scheduleIds, s.Id) } - if _, err := tx.Exec(db.Ctx, ` + if _, err := tx.Exec(ctx, ` DELETE FROM app.pg_function_schedule WHERE pg_function_id = $1 AND id <> ALL($2) @@ -286,7 +287,7 @@ func Set_tx(tx pgx.Tx, fnc types.PgFunction) error { return fmt.Errorf("failed to process entity IDs, %s", err) } - _, err = tx.Exec(db.Ctx, fmt.Sprintf(` + _, err = tx.Exec(ctx, fmt.Sprintf(` CREATE OR REPLACE FUNCTION "%s"."%s"(%s) RETURNS %s LANGUAGE plpgsql %s AS %s `, nameMod, fnc.Name, fnc.CodeArgs, fnc.CodeReturns, fnc.Volatility, fnc.CodeFunction)) @@ -344,7 +345,7 @@ func RecreateAffectedBy_tx(tx pgx.Tx, entity string, entityId uuid.UUID) error { if err != nil { return err } - if err := Set_tx(tx, f); err != nil { + if err := Set_tx(db.Ctx, tx, f); err != nil { return err } } diff --git a/schema/pgIndex/pgIndex.go b/schema/pgIndex/pgIndex.go index 1df158a4..cf2498bd 100644 --- a/schema/pgIndex/pgIndex.go +++ b/schema/pgIndex/pgIndex.go @@ -1,6 +1,7 @@ package pgIndex import ( + "context" "errors" "fmt" "r3/db" @@ -13,12 +14,12 @@ import ( "github.com/jackc/pgx/v5" ) -func DelAutoFkiForAttribute_tx(tx pgx.Tx, attributeId uuid.UUID) error { +func DelAutoFkiForAttribute_tx(ctx context.Context, tx pgx.Tx, attributeId uuid.UUID) error { // get ID of automatically created FK index for relationship attribute var pgIndexId uuid.UUID - err := tx.QueryRow(db.Ctx, ` + err := tx.QueryRow(ctx, ` SELECT i.id FROM app.pg_index AS i INNER JOIN app.pg_index_attribute AS a ON a.pg_index_id = i.id @@ -37,9 +38,9 @@ func DelAutoFkiForAttribute_tx(tx pgx.Tx, attributeId uuid.UUID) error { } // delete auto FK index for attribute - return Del_tx(tx, pgIndexId) + return Del_tx(ctx, tx, pgIndexId) } -func Del_tx(tx pgx.Tx, id uuid.UUID) error { +func Del_tx(ctx context.Context, tx pgx.Tx, id uuid.UUID) error { moduleName, _, err := schema.GetPgIndexNamesById_tx(tx, id) if err != nil { @@ -48,13 +49,13 @@ func Del_tx(tx pgx.Tx, id uuid.UUID) error { // can also be deleted by cascaded entity (relation/attribute) // drop if it still exists - if _, err := tx.Exec(db.Ctx, fmt.Sprintf(` + if _, err := tx.Exec(ctx, fmt.Sprintf(` DROP INDEX IF EXISTS "%s"."%s" `, moduleName, schema.GetPgIndexName(id))); err != nil { return err } - _, err = tx.Exec(db.Ctx, `DELETE FROM app.pg_index WHERE id = $1`, id) + _, err = tx.Exec(ctx, `DELETE FROM app.pg_index WHERE id = $1`, id) return err } @@ -123,8 +124,8 @@ func GetAttributes(pgIndexId uuid.UUID) ([]types.PgIndexAttribute, error) { return attributes, nil } -func SetAutoFkiForAttribute_tx(tx pgx.Tx, relationId uuid.UUID, attributeId uuid.UUID, noDuplicates bool) error { - return Set_tx(tx, types.PgIndex{ +func SetAutoFkiForAttribute_tx(ctx context.Context, tx pgx.Tx, relationId uuid.UUID, attributeId uuid.UUID, noDuplicates bool) error { + return Set_tx(ctx, tx, types.PgIndex{ Id: uuid.Nil, RelationId: relationId, AutoFki: true, @@ -132,7 +133,7 @@ func SetAutoFkiForAttribute_tx(tx pgx.Tx, relationId uuid.UUID, attributeId uuid NoDuplicates: noDuplicates, PrimaryKey: false, Attributes: []types.PgIndexAttribute{ - types.PgIndexAttribute{ + { AttributeId: attributeId, Position: 0, OrderAsc: true, @@ -140,8 +141,8 @@ func SetAutoFkiForAttribute_tx(tx pgx.Tx, relationId uuid.UUID, attributeId uuid }, }) } -func SetPrimaryKeyForAttribute_tx(tx pgx.Tx, relationId uuid.UUID, attributeId uuid.UUID) error { - return Set_tx(tx, types.PgIndex{ +func SetPrimaryKeyForAttribute_tx(ctx context.Context, tx pgx.Tx, relationId uuid.UUID, attributeId uuid.UUID) error { + return Set_tx(ctx, tx, types.PgIndex{ Id: uuid.Nil, RelationId: relationId, AutoFki: false, @@ -149,7 +150,7 @@ func SetPrimaryKeyForAttribute_tx(tx pgx.Tx, relationId uuid.UUID, attributeId u NoDuplicates: true, PrimaryKey: true, Attributes: []types.PgIndexAttribute{ - types.PgIndexAttribute{ + { AttributeId: attributeId, Position: 0, OrderAsc: true, @@ -157,7 +158,7 @@ func SetPrimaryKeyForAttribute_tx(tx pgx.Tx, relationId uuid.UUID, attributeId u }, }) } -func Set_tx(tx pgx.Tx, pgi types.PgIndex) error { +func Set_tx(ctx context.Context, tx pgx.Tx, pgi types.PgIndex) error { if len(pgi.Attributes) == 0 { return errors.New("cannot create index without attributes") @@ -196,7 +197,7 @@ func Set_tx(tx pgx.Tx, pgi types.PgIndex) error { } // insert pg index references - if _, err := tx.Exec(db.Ctx, ` + if _, err := tx.Exec(ctx, ` INSERT INTO app.pg_index (id, relation_id, attribute_id_dict, method, no_duplicates, auto_fki, primary_key) VALUES ($1,$2,$3,$4,$5,$6,$7) @@ -205,7 +206,7 @@ func Set_tx(tx pgx.Tx, pgi types.PgIndex) error { return err } for position, atr := range pgi.Attributes { - if _, err := tx.Exec(db.Ctx, ` + if _, err := tx.Exec(ctx, ` INSERT INTO app.pg_index_attribute ( pg_index_id, attribute_id, position, order_asc) VALUES ($1,$2,$3,$4) @@ -270,7 +271,7 @@ func Set_tx(tx pgx.Tx, pgi types.PgIndex) error { indexType = "UNIQUE INDEX" } - _, err = tx.Exec(db.Ctx, fmt.Sprintf(` + _, err = tx.Exec(ctx, fmt.Sprintf(` CREATE %s "%s" ON "%s"."%s" USING %s `, indexType, schema.GetPgIndexName(pgi.Id), modName, relName, indexDef)) diff --git a/transfer/transfer_delete/transfer_delete.go b/transfer/transfer_delete/transfer_delete.go index 0c1626ef..832baf24 100644 --- a/transfer/transfer_delete/transfer_delete.go +++ b/transfer/transfer_delete/transfer_delete.go @@ -191,7 +191,7 @@ func deleteRelationPgIndexes_tx(tx pgx.Tx, moduleId uuid.UUID, relations []types } for _, id := range idsDelete { log.Info("transfer", fmt.Sprintf("del PG index %s", id.String())) - if err := pgIndex.Del_tx(tx, id); err != nil { + if err := pgIndex.Del_tx(db.Ctx, tx, id); err != nil { return err } } @@ -577,7 +577,7 @@ func deletePgFunctions_tx(tx pgx.Tx, moduleId uuid.UUID, pgFunctions []types.PgF } for _, id := range idsDelete { log.Info("transfer", fmt.Sprintf("del PG function %s", id.String())) - if err := pgFunction.Del_tx(tx, id); err != nil { + if err := pgFunction.Del_tx(db.Ctx, tx, id); err != nil { return err } } diff --git a/transfer/transfer_import.go b/transfer/transfer_import.go index a5372a93..0a851e97 100644 --- a/transfer/transfer_import.go +++ b/transfer/transfer_import.go @@ -1,6 +1,7 @@ package transfer import ( + "context" "encoding/base64" "encoding/json" "errors" @@ -137,7 +138,7 @@ func ImportFromFiles(filePathsImport []string) error { } } - if err := importModule_tx(tx, m, firstRun, lastRun, idMapSkipped); err != nil { + if err := importModule_tx(ctx, tx, m, firstRun, lastRun, idMapSkipped); err != nil { return err } @@ -181,7 +182,7 @@ func ImportFromFiles(filePathsImport []string) error { return cluster.SchemaChanged(true, moduleIdsUpdated) } -func importModule_tx(tx pgx.Tx, mod types.Module, firstRun bool, lastRun bool, +func importModule_tx(ctx context.Context, tx pgx.Tx, mod types.Module, firstRun bool, lastRun bool, idMapSkipped map[uuid.UUID]types.Void) error { // we use a sensible import order to avoid conflicts but some cannot be avoided: @@ -379,7 +380,7 @@ func importModule_tx(tx pgx.Tx, mod types.Module, firstRun bool, lastRun bool, } log.Info("transfer", fmt.Sprintf("set PG function %s", e.Id)) - if err := importCheckResultAndApply(tx, pgFunction.Set_tx(tx, e), e.Id, idMapSkipped); err != nil { + if err := importCheckResultAndApply(tx, pgFunction.Set_tx(ctx, tx, e), e.Id, idMapSkipped); err != nil { return err } } @@ -413,7 +414,7 @@ func importModule_tx(tx pgx.Tx, mod types.Module, firstRun bool, lastRun bool, } log.Info("transfer", fmt.Sprintf("set index %s", e.Id)) - if err := importCheckResultAndApply(tx, pgIndex.Set_tx(tx, e), e.Id, idMapSkipped); err != nil { + if err := importCheckResultAndApply(tx, pgIndex.Set_tx(ctx, tx, e), e.Id, idMapSkipped); err != nil { return err } }