Simplify notifications code

This commit is contained in:
Taras Kushnir
2025-08-15 19:29:10 +03:00
parent 97af84cd07
commit 9a6dfdcd1f
16 changed files with 152 additions and 171 deletions

View File

@@ -192,15 +192,14 @@ func run(ctx context.Context, cfg common.ConfigStore, stderr io.Writer, listener
MaxLifetime: sessionStore.MaxLifetime(),
SecureCookie: (*certFileFlag != "") && (*keyFileFlag != ""),
},
PlanService: planService,
APIURL: apiURLConfig.URL(),
CDNURL: cdnURLConfig.URL(),
PuzzleEngine: apiServer,
Metrics: metrics,
Mailer: mailer,
Notifications: &portal.NotificationScheduler{Store: businessDB},
RateLimiter: ipRateLimiter,
DataCtx: dataCtx,
PlanService: planService,
APIURL: apiURLConfig.URL(),
CDNURL: cdnURLConfig.URL(),
PuzzleEngine: apiServer,
Metrics: metrics,
Mailer: mailer,
RateLimiter: ipRateLimiter,
DataCtx: dataCtx,
}
templatesBuilder := portal.NewTemplatesBuilder()

View File

@@ -25,7 +25,7 @@ const (
)
var (
templates = email.Templates()
templates = map[string]string{}
)
func homepage(w http.ResponseWriter, r *http.Request) {
@@ -107,6 +107,10 @@ func serveTemplate(name string) http.HandlerFunc {
func main() {
http.HandleFunc("/", homepage)
for _, tpl := range email.Templates() {
templates[tpl.Name()] = tpl.Content()
}
for k := range templates {
http.HandleFunc("/"+k, serveTemplate(k))
}

View File

@@ -2,6 +2,9 @@ package common
import (
"context"
"crypto/sha1"
"encoding/hex"
"sync"
"time"
)
@@ -16,11 +19,33 @@ type ScheduledNotification struct {
Subject string
Data interface{}
DateTime time.Time
TemplateName string
TemplateHash string
Persistent bool
}
type ScheduledNotifications interface {
Add(ctx context.Context, notification *ScheduledNotification) error
Remove(ctx context.Context, userID int32, referenceID string) error
func NewEmailTemplate(name, content string) *EmailTemplate {
return &EmailTemplate{name: name, content: content}
}
type EmailTemplate struct {
name string
hash string
mux sync.Mutex
content string
}
func (et *EmailTemplate) Name() string { return et.name }
func (et *EmailTemplate) Content() string { return et.content }
func (et *EmailTemplate) Hash() string {
et.mux.Lock()
defer et.mux.Unlock()
if len(et.hash) == 0 {
h := sha1.New()
h.Write([]byte(et.content))
et.hash = hex.EncodeToString(h.Sum(nil))
}
return et.hash
}

View File

@@ -2,6 +2,7 @@ package db
import (
"context"
"encoding/json"
"errors"
"log/slog"
"slices"
@@ -1563,13 +1564,11 @@ func (s *BusinessStoreImpl) CreateNewAccount(ctx context.Context, params *dbgen.
return user, org, nil
}
func (s *BusinessStoreImpl) CreateNotificationTemplate(ctx context.Context, name, tpl string) (*dbgen.NotificationTemplate, error) {
func (s *BusinessStoreImpl) CreateNotificationTemplate(ctx context.Context, name, tpl, hash string) (*dbgen.NotificationTemplate, error) {
if s.querier == nil {
return nil, ErrMaintenance
}
hash := EmailTemplateHash(tpl)
t, err := s.querier.CreateNotificationTemplate(ctx, &dbgen.CreateNotificationTemplateParams{
Name: name,
Content: tpl,
@@ -1600,8 +1599,8 @@ func (s *BusinessStoreImpl) RetrieveNotificationTemplate(ctx context.Context, te
return reader.Read(ctx)
}
func (s *BusinessStoreImpl) CreateUserNotification(ctx context.Context, params *dbgen.CreateUserNotificationParams) (*dbgen.UserNotification, error) {
if (params == nil) || (len(params.TemplateHash.String) == 0) || (len(params.ReferenceID) == 0) {
func (s *BusinessStoreImpl) CreateUserNotification(ctx context.Context, n *common.ScheduledNotification) (*dbgen.UserNotification, error) {
if (n == nil) || (len(n.TemplateHash) == 0) || (len(n.ReferenceID) == 0) {
return nil, ErrInvalidInput
}
@@ -1609,6 +1608,23 @@ func (s *BusinessStoreImpl) CreateUserNotification(ctx context.Context, params *
return nil, ErrMaintenance
}
payload, err := json.Marshal(n.Data)
if err != nil {
slog.ErrorContext(ctx, "Failed to serialize payload for notification", common.ErrAttr(err))
return nil, err
}
// NOTE: we don't add template to DB (again) because it should have been done with RegisterEmailTemplatesJob on startup
params := &dbgen.CreateUserNotificationParams{
UserID: Int(n.UserID),
ReferenceID: n.ReferenceID,
TemplateHash: Text(n.TemplateHash),
Subject: n.Subject,
Payload: payload,
ScheduledAt: Timestampz(n.DateTime),
Persistent: n.Persistent,
}
rlog := slog.With("userID", params.UserID.Int32, "refID", params.ReferenceID)
notif, err := s.querier.CreateUserNotification(ctx, params)

View File

@@ -2,7 +2,6 @@ package db
import (
"context"
"crypto/sha1"
"encoding/hex"
"log/slog"
"strings"
@@ -443,9 +442,3 @@ func (br *StoreBulkReader[TArg, TKey, T]) Read(ctx context.Context, args map[TAr
return cached, items, nil
}
func EmailTemplateHash(content string) string {
h := sha1.New()
h.Write([]byte(content))
return hex.EncodeToString(h.Sum(nil))
}

View File

@@ -1,5 +1,7 @@
package email
import "github.com/PrivateCaptcha/PrivateCaptcha/pkg/common"
type APIKeyContext struct {
APIKeyName string
APIKeyPrefix string
@@ -11,9 +13,13 @@ type APIKeyExpirationContext struct {
ExpireDays int
}
var (
APIKeyExirationTemplate = common.NewEmailTemplate("apikey-expiration", APIKeyExpirationHTMLTemplate)
APIKeyExpiredTemplate = common.NewEmailTemplate("apikey-expired", APIKeyExpiredHTMLTemplate)
)
const (
APIKeyExpirationTemplateName = "apikey-expiration"
APIKeyExpirationHTML = `<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
APIKeyExpirationHTMLTemplate = `<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html dir="ltr" lang="en">
<head>
<link rel="preload" as="image" href="{{.CDNURL}}/portal/img/pc-logo-dark.png" />
@@ -56,11 +62,8 @@ const (
</table>
</body>
</html>`
)
const (
APIKeyExpiredTemplateName = "apikey-expired"
APIKeyExpiredHTML = `<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
APIKeyExpiredHTMLTemplate = `<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html dir="ltr" lang="en">
<head>
<link rel="preload" as="image" href="{{.CDNURL}}/portal/img/pc-logo-dark.png" />

View File

@@ -1,14 +1,22 @@
package email
import "strings"
import (
"strings"
func Templates() map[string]string {
return map[string]string{
WelcomeTemplateName: WelcomeHTMLTemplate,
TwoFactorTemplateName: TwoFactorHTMLTemplate,
APIKeyExpirationTemplateName: APIKeyExpirationHTML,
APIKeyExpiredTemplateName: APIKeyExpiredHTML,
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/common"
)
var (
templates = []*common.EmailTemplate{
APIKeyExirationTemplate,
APIKeyExpiredTemplate,
WelcomeEmailTemplate,
TwoFactorEmailTemplate,
}
)
func Templates() []*common.EmailTemplate {
return templates
}
func CanBeHTML(s string) bool {

View File

@@ -16,8 +16,8 @@ func TestCanBeHTML(t *testing.T) {
{TwoFactorHTMLTemplate, true},
{welcomeTextTemplate, false},
{twoFactorTextTemplate, false},
{APIKeyExpirationHTML, true},
{APIKeyExpiredHTML, true},
{APIKeyExpirationHTMLTemplate, true},
{APIKeyExpiredHTMLTemplate, true},
}
for i, tc := range testCases {

View File

@@ -1,7 +1,12 @@
package email
import "github.com/PrivateCaptcha/PrivateCaptcha/pkg/common"
var (
TwoFactorEmailTemplate = common.NewEmailTemplate("twofactor", TwoFactorHTMLTemplate)
)
const (
TwoFactorTemplateName = "twofactor"
TwoFactorHTMLTemplate = `<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html dir="ltr" lang="en">
<head>

View File

@@ -1,7 +1,12 @@
package email
import "github.com/PrivateCaptcha/PrivateCaptcha/pkg/common"
var (
WelcomeEmailTemplate = common.NewEmailTemplate("welcome", WelcomeHTMLTemplate)
)
const (
WelcomeTemplateName = "welcome"
WelcomeHTMLTemplate = `<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html dir="ltr" lang="en">
<head>

View File

@@ -17,7 +17,7 @@ import (
)
type RegisterEmailTemplatesJob struct {
Templates map[string]string
Templates []*common.EmailTemplate
Store db.Implementor
}
@@ -33,9 +33,9 @@ func (j *RegisterEmailTemplatesJob) InitialPause() time.Duration {
func (j *RegisterEmailTemplatesJob) RunOnce(ctx context.Context) error {
var anyError error
for name, content := range j.Templates {
if _, err := j.Store.Impl().CreateNotificationTemplate(ctx, name, content); err != nil {
slog.ErrorContext(ctx, "Failed to upsert notification template", "name", name, common.ErrAttr(err))
for _, tpl := range j.Templates {
if _, err := j.Store.Impl().CreateNotificationTemplate(ctx, tpl.Name(), tpl.Content(), tpl.Hash()); err != nil {
slog.ErrorContext(ctx, "Failed to upsert notification template", "name", tpl.Name(), common.ErrAttr(err))
anyError = err
}
}
@@ -47,7 +47,7 @@ type UserEmailNotificationsJob struct {
// this is the "actual" interval since we will be running as a DB-locked distributed job
RunInterval time.Duration
Store db.Implementor
Templates map[string]string
Templates []*common.EmailTemplate
Sender email.Sender
ChunkSize common.ConfigItem
EmailFrom common.ConfigItem
@@ -90,27 +90,17 @@ func groupNotificationsByTemplate(ctx context.Context, notifications []*dbgen.Ge
return result
}
type indexedNotificationTemplate struct {
name string
hash string
content string
}
func indexTemplates(ctx context.Context, nameToContentTplMap map[string]string) map[string]*indexedNotificationTemplate {
templates := make(map[string]*indexedNotificationTemplate)
for name, content := range nameToContentTplMap {
hash := db.EmailTemplateHash(content)
if _, ok := templates[hash]; ok {
slog.ErrorContext(ctx, "Found two templates with the same hash", "hash", hash, "name", name)
func indexTemplates(ctx context.Context, templates []*common.EmailTemplate) map[string]*common.EmailTemplate {
tplMap := make(map[string]*common.EmailTemplate)
for _, tpl := range templates {
hash := tpl.Hash()
if _, ok := tplMap[hash]; ok {
slog.ErrorContext(ctx, "Found two templates with the same hash", "hash", hash, "name", tpl.Name())
continue
}
templates[hash] = &indexedNotificationTemplate{
name: name,
hash: hash,
content: content,
}
tplMap[hash] = tpl
}
return templates
return tplMap
}
type preparedNotificationTemplate struct {
@@ -120,15 +110,15 @@ type preparedNotificationTemplate struct {
}
func (j *UserEmailNotificationsJob) retrieveTemplate(ctx context.Context,
templates map[string]*indexedNotificationTemplate,
templates map[string]*common.EmailTemplate,
templateHash string) (*preparedNotificationTemplate, error) {
hlog := slog.With("hash", templateHash)
var content string
var name string
itpl, ok := templates[templateHash]
if ok {
content = itpl.content
name = itpl.name
content = itpl.Content()
name = itpl.Name()
} else {
hlog.WarnContext(ctx, "Template is not found locally")
if dbTemplate, err := j.Store.Impl().RetrieveNotificationTemplate(ctx, templateHash); err == nil {

View File

@@ -2,23 +2,14 @@ package portal
import (
"context"
"encoding/json"
"errors"
"log/slog"
"net/http"
"strconv"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/common"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/db"
dbgen "github.com/PrivateCaptcha/PrivateCaptcha/pkg/db/generated"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/email"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/session"
)
var (
errEmailTemplateNotFound = errors.New("template with such name does not exist")
)
func (s *Server) createSystemNotificationContext(ctx context.Context, sess *common.Session) systemNotificationContext {
renderCtx := systemNotificationContext{}
@@ -55,57 +46,3 @@ func (s *Server) dismissNotification(w http.ResponseWriter, r *http.Request) {
http.Error(w, "", http.StatusBadRequest)
}
}
type NotificationScheduler struct {
Store db.Implementor
}
var _ common.ScheduledNotifications = (*NotificationScheduler)(nil)
func (ns *NotificationScheduler) Add(ctx context.Context, n *common.ScheduledNotification) error {
_, err := ns.AddEx(ctx, n)
return err
}
func (ns *NotificationScheduler) AddEx(ctx context.Context, n *common.ScheduledNotification) (*dbgen.UserNotification, error) {
templates := email.Templates()
template, ok := templates[n.TemplateName]
if !ok {
slog.ErrorContext(ctx, "Notification template with such name does not exist", "name", n.TemplateName)
return nil, errEmailTemplateNotFound
}
payload, err := json.Marshal(n.Data)
if err != nil {
slog.ErrorContext(ctx, "Failed to serialize payload for notification", common.ErrAttr(err))
return nil, err
}
// NOTE: we don't add template to DB (again) because it should have been done with RegisterEmailTemplatesJob on startup
params := &dbgen.CreateUserNotificationParams{
UserID: db.Int(n.UserID),
ReferenceID: n.ReferenceID,
TemplateHash: db.Text(db.EmailTemplateHash(template)),
Subject: n.Subject,
Payload: payload,
ScheduledAt: db.Timestampz(n.DateTime),
Persistent: n.Persistent,
}
notif, err := ns.Store.Impl().CreateUserNotification(ctx, params)
if err != nil {
slog.ErrorContext(ctx, "Failed to add scheduled notification", common.ErrAttr(err))
return nil, err
}
return notif, nil
}
func (ns *NotificationScheduler) Remove(ctx context.Context, userID int32, referenceID string) error {
if err := ns.Store.Impl().DeletePendingUserNotification(ctx, userID, referenceID); err != nil {
slog.ErrorContext(ctx, "Failed to delete scheduled notification", common.ErrAttr(err))
return err
}
return nil
}

View File

@@ -7,8 +7,6 @@ import (
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/common"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/config"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/db"
dbgen "github.com/PrivateCaptcha/PrivateCaptcha/pkg/db/generated"
db_tests "github.com/PrivateCaptcha/PrivateCaptcha/pkg/db/tests"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/email"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/maintenance"
@@ -39,17 +37,16 @@ func TestUserNotificationsJob(t *testing.T) {
const referenceID = "referenceID"
hash := db.EmailTemplateHash(email.TwoFactorHTMLTemplate)
params := &dbgen.CreateUserNotificationParams{
UserID: db.Int(user.ID),
n := &common.ScheduledNotification{
UserID: user.ID,
ReferenceID: referenceID,
TemplateHash: db.Text(hash),
TemplateHash: email.TwoFactorEmailTemplate.Hash(),
Subject: "subject",
Payload: []byte("{}"),
ScheduledAt: db.Timestampz(tnow.Add(-10 * time.Minute)),
Data: map[string]int{},
DateTime: tnow.Add(-10 * time.Minute),
Persistent: false,
}
if _, err := store.Impl().CreateUserNotification(ctx, params); err != nil {
if _, err := store.Impl().CreateUserNotification(ctx, n); err != nil {
t.Fatal(err)
}
@@ -105,17 +102,15 @@ func TestDeleteSentNotifications(t *testing.T) {
Subject: "subject",
Data: map[string]int{},
DateTime: tnow.Add(-10 * time.Minute),
TemplateName: email.TwoFactorTemplateName,
TemplateHash: email.TwoFactorEmailTemplate.Hash(),
}
scheduler := &NotificationScheduler{Store: store}
notif, err := scheduler.AddEx(ctx, sn)
notif, err := store.Impl().CreateUserNotification(ctx, sn)
if err != nil {
t.Fatal(err)
}
if _, err := scheduler.AddEx(ctx, sn); err == nil {
if _, err := store.Impl().CreateUserNotification(ctx, sn); err == nil {
t.Fatal("Shouldn't create a notification with the same referenceID")
}
@@ -128,7 +123,7 @@ func TestDeleteSentNotifications(t *testing.T) {
}
// should be able to create again (unlike before)
if _, err := scheduler.AddEx(ctx, sn); err != nil {
if _, err := store.Impl().CreateUserNotification(ctx, sn); err != nil {
t.Fatal(err)
}
}
@@ -155,25 +150,23 @@ func TestDeleteScheduledNotification(t *testing.T) {
Subject: "subject",
Data: map[string]int{},
DateTime: tnow.Add(-10 * time.Minute),
TemplateName: email.TwoFactorTemplateName,
TemplateHash: email.TwoFactorEmailTemplate.Hash(),
}
scheduler := &NotificationScheduler{Store: store}
if _, err := scheduler.AddEx(ctx, sn); err != nil {
if _, err := store.Impl().CreateUserNotification(ctx, sn); err != nil {
t.Fatal(err)
}
if _, err := scheduler.AddEx(ctx, sn); err == nil {
if _, err := store.Impl().CreateUserNotification(ctx, sn); err == nil {
t.Fatal("Shouldn't create a notification with the same referenceID")
}
if err := scheduler.Remove(ctx, user.ID, sn.ReferenceID); err != nil {
if err := store.Impl().DeletePendingUserNotification(ctx, user.ID, sn.ReferenceID); err != nil {
t.Fatal(err)
}
// should be able to create again (unlike before)
if _, err := scheduler.AddEx(ctx, sn); err != nil {
if _, err := store.Impl().CreateUserNotification(ctx, sn); err != nil {
t.Fatal(err)
}
}

View File

@@ -108,7 +108,6 @@ type Server struct {
canRegister atomic.Bool
SettingsTabs []*SettingsTab
RateLimiter ratelimit.HTTPRateLimiter
Notifications common.ScheduledNotifications
RenderConstants interface{}
Jobs Jobs
PlatformCtx interface{}

View File

@@ -118,13 +118,12 @@ func TestMain(m *testing.M) {
Store: sessionStore,
MaxLifetime: sessionStore.MaxLifetime(),
},
Mailer: &email.StubMailer{},
RateLimiter: &ratelimit.StubRateLimiter{Header: cfg.Get(common.RateLimitHeaderKey).Value()},
PuzzleEngine: &fakePuzzleEngine{result: &puzzle.VerifyResult{Error: puzzle.VerifyNoError}},
Notifications: &NotificationScheduler{Store: store},
Metrics: monitoring.NewStub(),
PlanService: planService,
DataCtx: dataCtx,
Mailer: &email.StubMailer{},
RateLimiter: &ratelimit.StubRateLimiter{Header: cfg.Get(common.RateLimitHeaderKey).Value()},
PuzzleEngine: &fakePuzzleEngine{result: &puzzle.VerifyResult{Error: puzzle.VerifyNoError}},
Metrics: monitoring.NewStub(),
PlanService: planService,
DataCtx: dataCtx,
}
ctx := context.TODO()

View File

@@ -433,7 +433,7 @@ func createAPIKeyExpirationNotification(key *dbgen.APIKey, userKey *userAPIKey)
ExpireDays: apiKeyExpirationNotificationDays,
},
DateTime: key.ExpiresAt.Time.AddDate(0, 0, -apiKeyExpirationNotificationDays),
TemplateName: email.APIKeyExpirationTemplateName,
TemplateHash: email.APIKeyExirationTemplate.Hash(),
Persistent: false,
}
}
@@ -454,7 +454,7 @@ func createAPIKeyExpiredNotification(key *dbgen.APIKey, userKey *userAPIKey) *co
APIKeySettingsPath: fmt.Sprintf("%s?%s=%s", common.SettingsEndpoint, common.ParamTab, common.APIKeysEndpoint),
},
DateTime: key.ExpiresAt.Time,
TemplateName: email.APIKeyExpiredTemplateName,
TemplateHash: email.APIKeyExpiredTemplate.Hash(),
Persistent: false,
}
}
@@ -503,12 +503,14 @@ func (s *Server) postAPIKeySettings(w http.ResponseWriter, r *http.Request) (Mod
if days > apiKeyExpirationNotificationDays {
go common.RunAdHocFunc(common.CopyTraceID(ctx, context.Background()), func(bctx context.Context) error {
return s.Notifications.Add(bctx, createAPIKeyExpirationNotification(newKey, userKey))
_, err := s.Store.Impl().CreateUserNotification(bctx, createAPIKeyExpirationNotification(newKey, userKey))
return err
})
}
go common.RunAdHocFunc(common.CopyTraceID(ctx, context.Background()), func(bctx context.Context) error {
return s.Notifications.Add(bctx, createAPIKeyExpiredNotification(newKey, userKey))
_, err := s.Store.Impl().CreateUserNotification(bctx, createAPIKeyExpiredNotification(newKey, userKey))
return err
})
} else {
slog.ErrorContext(ctx, "Failed to create API key", common.ErrAttr(err))
@@ -540,11 +542,14 @@ func (s *Server) deleteAPIKey(w http.ResponseWriter, r *http.Request) {
}
go common.RunAdHocFunc(common.CopyTraceID(ctx, context.Background()), func(bctx context.Context) error {
return s.Notifications.Remove(bctx, user.ID, apiKeyExpirationReference(int32(keyID)))
})
go common.RunAdHocFunc(common.CopyTraceID(ctx, context.Background()), func(bctx context.Context) error {
return s.Notifications.Remove(bctx, user.ID, apiKeyExpiredReference(int32(keyID)))
var anyError error
if err := s.Store.Impl().DeletePendingUserNotification(ctx, user.ID, apiKeyExpirationReference(int32(keyID))); err != nil {
anyError = err
}
if err := s.Store.Impl().DeletePendingUserNotification(ctx, user.ID, apiKeyExpiredReference(int32(keyID))); err != nil {
anyError = err
}
return anyError
})
w.WriteHeader(http.StatusOK)