mirror of
https://github.com/btouchard/ackify-ce.git
synced 2026-02-07 06:19:37 -06:00
fix(db): use dbctx.GetQuerier in MagicLinkRepository for RLS support
MagicLinkRepository was bypassing RLS by using r.db directly instead of dbctx.GetQuerier(ctx, r.db). This meant queries ran outside the transaction with app.tenant_id set, causing RLS policies to not apply. All methods now use dbctx.GetQuerier to properly participate in the RLS transaction context.
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/application/services"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/dbctx"
|
||||
)
|
||||
|
||||
type magicLinkRepo struct {
|
||||
@@ -34,7 +35,7 @@ func (r *magicLinkRepo) CreateToken(ctx context.Context, token *models.MagicLink
|
||||
purpose = "login"
|
||||
}
|
||||
|
||||
return r.db.QueryRowContext(ctx, query,
|
||||
return dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query,
|
||||
token.TenantID, // Can be NULL for login requests
|
||||
token.Token,
|
||||
token.Email,
|
||||
@@ -61,7 +62,7 @@ func (r *magicLinkRepo) GetByToken(ctx context.Context, token string) (*models.M
|
||||
var usedByIP, usedByUserAgent, docID sql.NullString
|
||||
var tenantID sql.NullString
|
||||
|
||||
err := r.db.QueryRowContext(ctx, query, token).Scan(
|
||||
err := dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query, token).Scan(
|
||||
&t.ID,
|
||||
&tenantID,
|
||||
&t.Token,
|
||||
@@ -114,7 +115,7 @@ func (r *magicLinkRepo) MarkAsUsed(ctx context.Context, token string, ip string,
|
||||
WHERE token = $1 AND used_at IS NULL
|
||||
`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, token, ip, userAgent)
|
||||
result, err := dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, query, token, ip, userAgent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -136,7 +137,7 @@ func (r *magicLinkRepo) DeleteExpired(ctx context.Context) (int64, error) {
|
||||
WHERE expires_at < now() OR (created_at < now() - INTERVAL '7 days' AND used_at IS NULL)
|
||||
`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query)
|
||||
result, err := dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -152,7 +153,7 @@ func (r *magicLinkRepo) LogAttempt(ctx context.Context, attempt *models.MagicLin
|
||||
RETURNING id, attempted_at
|
||||
`
|
||||
|
||||
return r.db.QueryRowContext(ctx, query,
|
||||
return dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query,
|
||||
attempt.TenantID, // Can be NULL before authentication
|
||||
attempt.Email,
|
||||
attempt.Success,
|
||||
@@ -170,7 +171,7 @@ func (r *magicLinkRepo) CountRecentAttempts(ctx context.Context, email string, s
|
||||
WHERE email = $1 AND attempted_at > $2
|
||||
`
|
||||
|
||||
err := r.db.QueryRowContext(ctx, query, email, since).Scan(&count)
|
||||
err := dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query, email, since).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
@@ -182,6 +183,6 @@ func (r *magicLinkRepo) CountRecentAttemptsByIP(ctx context.Context, ip string,
|
||||
WHERE ip_address = $1 AND attempted_at > $2
|
||||
`
|
||||
|
||||
err := r.db.QueryRowContext(ctx, query, ip, since).Scan(&count)
|
||||
err := dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query, ip, since).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user