From 89128a60d221a41b807fa87f3aad505ed8cfb744 Mon Sep 17 00:00:00 2001 From: Taras Kushnir Date: Sun, 16 Nov 2025 13:47:59 +0200 Subject: [PATCH] Remove /twofactor endpoint and unite all login forms This is done to eliminate the redirect to /twofactor endpoint which potentially can land on another server in case load balancing is not sticky --- pkg/db/business_impl.go | 11 +++- pkg/db/session.go | 4 +- pkg/portal/csrf.go | 18 ++++-- pkg/portal/login.go | 28 ++++++--- pkg/portal/login_test.go | 2 +- pkg/portal/notification.go | 6 +- pkg/portal/org.go | 2 +- pkg/portal/register.go | 45 +++++++------- pkg/portal/render.go | 7 ++- pkg/portal/render_test.go | 16 +++-- pkg/portal/server.go | 3 +- pkg/portal/session_test.go | 17 +++++- pkg/portal/tests/helpers.go | 4 ++ pkg/portal/twofactor.go | 61 ++++++------------- pkg/portal/utils.go | 7 ++- pkg/session/common.go | 26 +++++++- pkg/session/manager.go | 45 ++++++++++++-- web/layouts/login/dashes.html | 26 ++++++++ web/layouts/login/login-contents.html | 10 +++ .../login/{form.html => login-form.html} | 0 web/layouts/login/login.html | 42 ++----------- web/layouts/login/register-contents.html | 11 ++++ .../form.html => login/register-form.html} | 2 +- .../{twofactor => login}/resend-error.html | 0 web/layouts/{twofactor => login}/resend.html | 0 web/layouts/login/scripts.html | 16 +++++ web/layouts/login/twofactor-contents.html | 11 ++++ .../form.html => login/twofactor-form.html} | 6 +- web/layouts/register/register.html | 59 ------------------ web/layouts/register/scripts.html | 12 ---- web/layouts/twofactor/scripts.html | 21 ------- web/layouts/twofactor/twofactor.html | 60 ------------------ 32 files changed, 280 insertions(+), 298 deletions(-) create mode 100644 web/layouts/login/dashes.html create mode 100644 web/layouts/login/login-contents.html rename web/layouts/login/{form.html => login-form.html} (100%) create mode 100644 web/layouts/login/register-contents.html rename web/layouts/{register/form.html => login/register-form.html} (97%) rename web/layouts/{twofactor => login}/resend-error.html (100%) rename web/layouts/{twofactor => login}/resend.html (100%) create mode 100644 web/layouts/login/twofactor-contents.html rename web/layouts/{twofactor/form.html => login/twofactor-form.html} (56%) delete mode 100644 web/layouts/register/register.html delete mode 100644 web/layouts/register/scripts.html delete mode 100644 web/layouts/twofactor/scripts.html delete mode 100644 web/layouts/twofactor/twofactor.html diff --git a/pkg/db/business_impl.go b/pkg/db/business_impl.go index f6d0d676..b42dfb51 100644 --- a/pkg/db/business_impl.go +++ b/pkg/db/business_impl.go @@ -361,11 +361,20 @@ func (impl *BusinessStoreImpl) CacheUserSession(ctx context.Context, data *sessi return impl.cache.Set(ctx, SessionCacheKey(data.ID()), data) } -func (impl *BusinessStoreImpl) RetrieveUserSession(ctx context.Context, sid string) (*session.SessionData, error) { +func (impl *BusinessStoreImpl) RetrieveUserSession(ctx context.Context, sid string, skipCache bool) (*session.SessionData, error) { if len(sid) == 0 { return nil, ErrInvalidInput } + if skipCache { + session, err := impl.doGetSessionbyID(ctx, sid) + if err == nil { + // yes, it's "skip READ from cache", not "skip cache ENTIRELY" + impl.cache.Set(ctx, SessionCacheKey(sid), session) + } + return session, err + } + reader := &StoreOneReader[string, session.SessionData]{ CacheKey: SessionCacheKey(sid), Cache: impl.cache, diff --git a/pkg/db/session.go b/pkg/db/session.go index 882fe482..d5ec3eff 100644 --- a/pkg/db/session.go +++ b/pkg/db/session.go @@ -51,8 +51,8 @@ func (ss *SessionStore) Init(ctx context.Context, session *session.Session) erro return ss.store.Impl().CacheUserSession(ctx, session.Data()) } -func (ss *SessionStore) Read(ctx context.Context, sid string) (*session.Session, error) { - sd, err := ss.store.Impl().RetrieveUserSession(ctx, sid) +func (ss *SessionStore) Read(ctx context.Context, sid string, skipCache bool) (*session.Session, error) { + sd, err := ss.store.Impl().RetrieveUserSession(ctx, sid, skipCache) if err != nil { if (err == ErrNegativeCacheHit) || (err == ErrCacheMiss) { return nil, session.ErrSessionMissing diff --git a/pkg/portal/csrf.go b/pkg/portal/csrf.go index 9181bb0e..0409b082 100644 --- a/pkg/portal/csrf.go +++ b/pkg/portal/csrf.go @@ -18,7 +18,12 @@ func (s *Server) CreateCsrfContext(user *dbgen.User) CsrfRenderContext { } func (s *Server) csrfUserEmailKeyFunc(w http.ResponseWriter, r *http.Request) string { - sess := s.Sessions.SessionStart(w, r) + // we're using session Get (and not Start) because we don't save session anywhere + sess, ok := s.Sessions.SessionGet(r) + if !ok { + return "" + } + ctx := r.Context() userEmail, ok := sess.Get(ctx, session.KeyUserEmail).(string) if !ok { @@ -29,7 +34,12 @@ func (s *Server) csrfUserEmailKeyFunc(w http.ResponseWriter, r *http.Request) st } func (s *Server) csrfUserIDKeyFunc(w http.ResponseWriter, r *http.Request) string { - sess := s.Sessions.SessionStart(w, r) + // we're using session Get (and not Start) because we don't save session anywhere + sess, ok := s.Sessions.SessionGet(r) + if !ok { + return "" + } + ctx := r.Context() userID, ok := sess.Get(ctx, session.KeyUserID).(int32) if !ok { @@ -62,10 +72,10 @@ func (s *Server) csrf(keyFunc CsrfKeyFunc) alice.Constructor { next.ServeHTTP(w, r) return } else { - slog.WarnContext(ctx, "Failed to verify CSRF token") + slog.WarnContext(ctx, "Failed to verify CSRF token", "path", r.URL.Path, "method", r.Method, "userID", userID) } } else { - slog.WarnContext(ctx, "CSRF token is missing") + slog.WarnContext(ctx, "CSRF token is missing", "path", r.URL.Path, "method", r.Method) } common.Redirect(s.RelURL(common.ExpiredEndpoint), http.StatusUnauthorized, w, r) diff --git a/pkg/portal/login.go b/pkg/portal/login.go index 97d9676d..2624fd5b 100644 --- a/pkg/portal/login.go +++ b/pkg/portal/login.go @@ -19,9 +19,10 @@ const ( loginStepSignInVerify = 1 loginStepSignUpVerify = 2 loginStepCompleted = 3 - loginFormTemplate = "login/form.html" loginTemplate = "login/login.html" + loginContentsTemplate = "login/login-contents.html" captchaVerificationFailed = "Captcha verification failed." + twofactorContentsTemplate = "login/twofactor-contents.html" ) var ( @@ -31,8 +32,12 @@ var ( type loginRenderContext struct { CsrfRenderContext CaptchaRenderContext + Email string EmailError string + CodeError string + NameError string CanRegister bool + IsRegister bool } type portalPropertyOwnerSource struct { @@ -84,14 +89,14 @@ func (s *Server) postLogin(w http.ResponseWriter, r *http.Request) { if len(captchaSolution) == 0 { slog.WarnContext(ctx, "Captcha solution field is empty") data.CaptchaError = "You need to solve captcha to login." - s.render(w, r, loginFormTemplate, data) + s.render(w, r, loginContentsTemplate, data) return } payload, err := s.PuzzleEngine.ParseSolutionPayload(ctx, []byte(captchaSolution)) if err != nil { data.CaptchaError = captchaVerificationFailed - s.render(w, r, loginFormTemplate, data) + s.render(w, r, loginContentsTemplate, data) return } @@ -100,7 +105,7 @@ func (s *Server) postLogin(w http.ResponseWriter, r *http.Request) { if err != nil || !verifyResult.Success() { slog.ErrorContext(ctx, "Failed to verify captcha", "verify", verifyResult.Error.String(), common.ErrAttr(err)) data.CaptchaError = captchaVerificationFailed - s.render(w, r, loginFormTemplate, data) + s.render(w, r, loginContentsTemplate, data) return } @@ -108,7 +113,7 @@ func (s *Server) postLogin(w http.ResponseWriter, r *http.Request) { if err = checkmail.ValidateFormat(email); err != nil { slog.WarnContext(ctx, "Failed to validate email format", common.ErrAttr(err)) data.EmailError = "Email address is not valid." - s.render(w, r, loginFormTemplate, data) + s.render(w, r, loginContentsTemplate, data) return } @@ -116,11 +121,11 @@ func (s *Server) postLogin(w http.ResponseWriter, r *http.Request) { if err != nil { slog.WarnContext(ctx, "Failed to find user by email", "email", email, common.ErrAttr(err)) data.EmailError = "User with such email does not exist." - s.render(w, r, loginFormTemplate, data) + s.render(w, r, loginContentsTemplate, data) return } - sess := s.Sessions.SessionStart(w, r) + sess, _ := s.Sessions.SessionStart(w, r) if step, ok := sess.Get(ctx, session.KeyLoginStep).(int); ok { if step == loginStepCompleted { slog.DebugContext(ctx, "User seem to be already logged in", "email", email) @@ -145,6 +150,13 @@ func (s *Server) postLogin(w http.ResponseWriter, r *http.Request) { _ = sess.Set(session.KeyUserName, user.Name) _ = sess.Set(session.KeyTwoFactorCode, code) _ = sess.Set(session.KeyUserID, user.ID) + // this is needed in case we will be routed to another server that does not have our session in memory + // (previously we persisted ONLY logged in sessions, but if we're rerouted during login, it will break) + // this should be OK now because we verified that user is a registered user AND they solved captcha + _ = sess.Set(session.KeyPersistent, true) - common.Redirect(s.RelURL(common.TwoFactorEndpoint), http.StatusOK, w, r) + data.Token = s.XSRF.Token(email) + data.Email = common.MaskEmail(email, '*') + + s.render(w, r, twofactorContentsTemplate, data) } diff --git a/pkg/portal/login_test.go b/pkg/portal/login_test.go index cf7e05bf..593b0d91 100644 --- a/pkg/portal/login_test.go +++ b/pkg/portal/login_test.go @@ -139,7 +139,7 @@ func TestPostLogin(t *testing.T) { rr = httptest.NewRecorder() server.postLogin(rr, req) - if rr.Code != http.StatusSeeOther { + if rr.Code != http.StatusOK { t.Errorf("Unexpected post login code: %v", rr.Code) } diff --git a/pkg/portal/notification.go b/pkg/portal/notification.go index d655a8c7..4833ba2d 100644 --- a/pkg/portal/notification.go +++ b/pkg/portal/notification.go @@ -24,7 +24,11 @@ func (s *Server) createSystemNotificationContext(ctx context.Context, sess *sess func (s *Server) dismissNotification(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - sess := s.Sessions.SessionStart(w, r) + sess, found := s.Sessions.SessionGet(r) + if !found { + http.Error(w, "", http.StatusBadRequest) + return + } id, value, err := common.IntPathArg(r, common.ParamID, s.IDHasher) if err == nil { diff --git a/pkg/portal/org.go b/pkg/portal/org.go index 1fba8250..7c78cfdb 100644 --- a/pkg/portal/org.go +++ b/pkg/portal/org.go @@ -253,7 +253,7 @@ func (s *Server) createOrgDashboardContext(ctx context.Context, orgID int32, ses func (s *Server) getPortal(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - sess := s.Sessions.SessionStart(w, r) + sess := s.Session(w, r) orgID, _, err := common.IntPathArg(r, common.ParamOrg, s.IDHasher) if err != nil { diff --git a/pkg/portal/register.go b/pkg/portal/register.go index 74519083..8b339c8d 100644 --- a/pkg/portal/register.go +++ b/pkg/portal/register.go @@ -24,29 +24,22 @@ var ( ) const ( - registerFormTemplate = "register/form.html" - registerTemplate = "register/register.html" - userNameErrorMessage = "Name contains invalid characters." + registerContentsTemplate = "login/register-contents.html" + userNameErrorMessage = "Name contains invalid characters." ) -type registerRenderContext struct { - CsrfRenderContext - CaptchaRenderContext - NameError string - EmailError string -} - func (s *Server) getRegister(w http.ResponseWriter, r *http.Request) (Model, string, error) { if !s.canRegister.Load() { return nil, "", errRegistrationDisabled } - return ®isterRenderContext{ + return &loginRenderContext{ CsrfRenderContext: CsrfRenderContext{ Token: s.XSRF.Token(""), }, CaptchaRenderContext: s.CreateCaptchaRenderContext(db.PortalRegisterSitekey), - }, registerTemplate, nil + IsRegister: true, + }, loginTemplate, nil } func isUserNameValid(name string) bool { @@ -88,15 +81,16 @@ func (s *Server) postRegister(w http.ResponseWriter, r *http.Request) { return } - data := ®isterRenderContext{ + data := &loginRenderContext{ CsrfRenderContext: CsrfRenderContext{ Token: s.XSRF.Token(""), }, CaptchaRenderContext: s.CreateCaptchaRenderContext(db.PortalRegisterSitekey), + IsRegister: true, } if _, termsAndConditions := r.Form[common.ParamTerms]; !termsAndConditions { - // it's error because they are marked 'required' on the frontend, so something went terribly wrong + // it's an error because they are marked 'required' on the frontend, so something went terribly wrong slog.ErrorContext(ctx, "Terms and conditions were not accepted") s.RedirectError(http.StatusBadRequest, w, r) return @@ -106,14 +100,14 @@ func (s *Server) postRegister(w http.ResponseWriter, r *http.Request) { if len(captchaSolution) == 0 { slog.WarnContext(ctx, "Captcha solution field is empty") data.CaptchaError = "You need to solve captcha to register." - s.render(w, r, registerFormTemplate, data) + s.render(w, r, registerContentsTemplate, data) return } payload, err := s.PuzzleEngine.ParseSolutionPayload(ctx, []byte(captchaSolution)) if err != nil { data.CaptchaError = captchaVerificationFailed - s.render(w, r, registerFormTemplate, data) + s.render(w, r, registerContentsTemplate, data) return } @@ -122,20 +116,20 @@ func (s *Server) postRegister(w http.ResponseWriter, r *http.Request) { if err != nil || !verifyResult.Success() { slog.ErrorContext(ctx, "Failed to verify captcha", "errors", verifyResult.Error.String(), common.ErrAttr(err)) data.CaptchaError = captchaVerificationFailed - s.render(w, r, registerFormTemplate, data) + s.render(w, r, registerContentsTemplate, data) return } name := strings.TrimSpace(r.FormValue(common.ParamName)) if len(name) < 3 { data.NameError = "Please use a longer name." - s.render(w, r, registerFormTemplate, data) + s.render(w, r, registerContentsTemplate, data) return } if !isUserNameValid(name) { data.NameError = userNameErrorMessage - s.render(w, r, registerFormTemplate, data) + s.render(w, r, registerContentsTemplate, data) return } @@ -143,14 +137,14 @@ func (s *Server) postRegister(w http.ResponseWriter, r *http.Request) { if err := checkmail.ValidateFormat(email); err != nil { slog.WarnContext(ctx, "Failed to validate email format", common.ErrAttr(err)) data.EmailError = "Email address is not valid." - s.render(w, r, registerFormTemplate, data) + s.render(w, r, registerContentsTemplate, data) return } if _, err := s.Store.Impl().FindUserByEmail(ctx, email); err == nil { slog.WarnContext(ctx, "User with such email already exists", "email", email) data.EmailError = "Such email is already registered. Login instead?" - s.render(w, r, registerFormTemplate, data) + s.render(w, r, registerContentsTemplate, data) return } @@ -163,17 +157,22 @@ func (s *Server) postRegister(w http.ResponseWriter, r *http.Request) { return } - sess := s.Sessions.SessionStart(w, r) + sess, _ := s.Sessions.SessionStart(w, r) ctx = context.WithValue(ctx, common.SessionIDContextKey, sess.ID()) _ = sess.Set(session.KeyLoginStep, loginStepSignUpVerify) _ = sess.Set(session.KeyUserEmail, email) _ = sess.Set(session.KeyUserName, name) _ = sess.Set(session.KeyTwoFactorCode, code) + // see comment in postLogin() why we have to use persistent here (although "registered user" argument does not apply) + _ = sess.Set(session.KeyPersistent, true) + + data.Token = s.XSRF.Token(email) + data.Email = common.MaskEmail(email, '*') slog.DebugContext(ctx, "Started 2FA registration flow", "email", email) - common.Redirect(s.RelURL(common.TwoFactorEndpoint), http.StatusOK, w, r) + s.render(w, r, twofactorContentsTemplate, data) } func createInternalTrial(plan billing.Plan, status string) *dbgen.CreateSubscriptionParams { diff --git a/pkg/portal/render.go b/pkg/portal/render.go index b679e832..298089b6 100644 --- a/pkg/portal/render.go +++ b/pkg/portal/render.go @@ -153,9 +153,10 @@ func (s *Server) render(w http.ResponseWriter, r *http.Request, name string, dat CDN: s.CDNURL, } - sess := s.Sessions.SessionStart(w, r) - if username, ok := sess.Get(ctx, session.KeyUserName).(string); ok { - reqCtx.UserName = username + if sess, found := s.Sessions.SessionGet(r); found { + if username, ok := sess.Get(ctx, session.KeyUserName).(string); ok { + reqCtx.UserName = username + } } out, err := s.RenderResponse(ctx, name, data, reqCtx) diff --git a/pkg/portal/render_test.go b/pkg/portal/render_test.go index f482846f..b98eddd9 100644 --- a/pkg/portal/render_test.go +++ b/pkg/portal/render_test.go @@ -78,14 +78,20 @@ func TestRenderHTML(t *testing.T) { model: &loginRenderContext{CsrfRenderContext: stubToken()}, }, { - path: []string{common.TwoFactorEndpoint}, - template: twofactorTemplate, - model: &twoFactorRenderContext{CsrfRenderContext: stubToken(), Email: "foo@bar.com"}, + path: []string{common.LoginEndpoint}, + template: twofactorContentsTemplate, + model: &loginRenderContext{CsrfRenderContext: stubToken(), Email: "foo@bar.com"}, }, { path: []string{common.RegisterEndpoint}, - template: registerTemplate, - model: ®isterRenderContext{CsrfRenderContext: stubToken()}, + template: loginTemplate, + model: &loginRenderContext{CsrfRenderContext: stubToken(), IsRegister: true}, + }, + // technically this is not needed (copy of the above), but it's an insurance against typos in case IsRegister will change + { + path: []string{common.RegisterEndpoint}, + template: registerContentsTemplate, + model: &loginRenderContext{CsrfRenderContext: stubToken(), IsRegister: true}, }, { path: []string{common.OrgEndpoint, common.NewEndpoint}, diff --git a/pkg/portal/server.go b/pkg/portal/server.go index ad6ba52f..81dbbbe9 100644 --- a/pkg/portal/server.go +++ b/pkg/portal/server.go @@ -252,7 +252,6 @@ func (s *Server) setupWithPrefix(router *http.ServeMux, rg *RouteGenerator, secu openRead := public.Append(s.maintenance, publicTimeout) router.Handle(rg.Get(common.LoginEndpoint), openRead.Then(common.Cached(s.Handler(s.getLogin)))) router.Handle(rg.Get(common.RegisterEndpoint), openRead.Then(common.Cached(s.Handler(s.getRegister)))) - router.Handle(rg.Get(common.TwoFactorEndpoint), openRead.ThenFunc(s.getTwoFactor)) router.Handle(rg.Get(common.ErrorEndpoint, arg(common.ParamCode)), public.ThenFunc(s.error)) router.Handle(rg.Get(common.ExpiredEndpoint), public.ThenFunc(s.expired)) router.Handle(rg.Get(common.LogoutEndpoint), public.ThenFunc(s.logout)) @@ -382,7 +381,7 @@ func (s *Server) private(next http.Handler) http.Handler { ) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - sess := s.Sessions.SessionStart(w, r) + sess, _ := s.Sessions.SessionStart(w, r) ctx := r.Context() ctx = context.WithValue(ctx, common.SessionIDContextKey, sess.ID()) diff --git a/pkg/portal/session_test.go b/pkg/portal/session_test.go index b937d632..0a861bae 100644 --- a/pkg/portal/session_test.go +++ b/pkg/portal/session_test.go @@ -17,7 +17,10 @@ func setupSessionSuite(ctx context.Context, manager *session.Manager, t *testing req := httptest.NewRequest("GET", "/settings", nil) w := httptest.NewRecorder() - sess := manager.SessionStart(w, req) + sess, started := manager.SessionStart(w, req) + if !started { + t.Error("session was not started") + } sess.Set(session.KeyUserName, t.Name()) sess.Set(session.KeyPersistent, true) @@ -77,7 +80,11 @@ func TestPersistentSession(t *testing.T) { req2.AddCookie(cookie) w2 := httptest.NewRecorder() - sess2 := manager.SessionStart(w2, req2) + sess2, started := manager.SessionStart(w2, req2) + + if started { + t.Error("new session was started") + } if sess1.ID() != sess2.ID() { t.Errorf("New session ID (%v) is different from original (%v)", sess2.ID(), sess1.ID()) @@ -115,7 +122,11 @@ func TestDeleteSession(t *testing.T) { req3 := httptest.NewRequest("GET", "/about", nil) req3.AddCookie(cookie) w3 := httptest.NewRecorder() - sess2 := manager.SessionStart(w3, req3) + sess2, started := manager.SessionStart(w3, req3) + + if !started { + t.Error("new session was not started") + } if sess1.ID() != sess2.ID() { t.Errorf("New session ID (%v) is different from original (%v)", sess2.ID(), sess1.ID()) diff --git a/pkg/portal/tests/helpers.go b/pkg/portal/tests/helpers.go index 35bb88fb..55708996 100644 --- a/pkg/portal/tests/helpers.go +++ b/pkg/portal/tests/helpers.go @@ -151,6 +151,10 @@ func AuthenticateSuite(ctx context.Context, email string, srv *http.ServeMux, xs return nil, fmt.Errorf("unexpected post twofactor code: %v", w.Code) } + if location, _ := w.Result().Location(); location.String() != "/" { + return nil, fmt.Errorf("unexpected redirect: %v", location) + } + slog.Log(ctx, common.LevelTrace, "Looks like we are authenticated", "code", w.Code) return cookie, nil diff --git a/pkg/portal/twofactor.go b/pkg/portal/twofactor.go index 603b676c..4293a8f1 100644 --- a/pkg/portal/twofactor.go +++ b/pkg/portal/twofactor.go @@ -11,47 +11,10 @@ import ( "github.com/PrivateCaptcha/PrivateCaptcha/pkg/session" ) -const ( - twofactorTemplate = "twofactor/twofactor.html" -) - var ( renderContextNothing = struct{}{} ) -type twoFactorRenderContext struct { - CsrfRenderContext - Email string - Error string -} - -func (s *Server) getTwoFactor(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - sess := s.Sessions.SessionStart(w, r) - if step, ok := sess.Get(ctx, session.KeyLoginStep).(int); !ok || ((step != loginStepSignInVerify) && (step != loginStepSignUpVerify)) { - slog.WarnContext(ctx, "User session is not valid", "step", step, "found", ok) - common.Redirect(s.RelURL(common.LoginEndpoint), http.StatusUnauthorized, w, r) - return - } - - email, ok := sess.Get(ctx, session.KeyUserEmail).(string) - if !ok { - slog.ErrorContext(ctx, "Failed to get email from session") - common.Redirect(s.RelURL(common.LoginEndpoint), http.StatusUnauthorized, w, r) - return - } - - data := &twoFactorRenderContext{ - CsrfRenderContext: CsrfRenderContext{ - Token: s.XSRF.Token(email), - }, - Email: common.MaskEmail(email, '*'), - } - - s.render(w, r, twofactorTemplate, data) -} - func (s *Server) postTwoFactor(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -62,9 +25,19 @@ func (s *Server) postTwoFactor(w http.ResponseWriter, r *http.Request) { return } - sess := s.Sessions.SessionStart(w, r) + sess, started := s.Sessions.SessionStart(w, r) ctx = context.WithValue(ctx, common.SessionIDContextKey, sess.ID()) + // we start session ONLY when session cookie is empty or when DB explicitly returned read error + // so "random" POST request to /twofactor might mean we access it from another node without this session + if started { + slog.DebugContext(ctx, "Attempting to reread potential stale session from DB", "started", started) + if dbSess, err := s.Sessions.RetrieveSession(ctx, sess.ID()); err == nil { + slog.InfoContext(ctx, "Using DB session instead for two factor") + sess.Merge(dbSess) + } + } + step, ok := sess.Get(ctx, session.KeyLoginStep).(int) if !ok || ((step != loginStepSignInVerify) && (step != loginStepSignUpVerify)) { slog.WarnContext(ctx, "User session is not valid", "step", step) @@ -86,7 +59,7 @@ func (s *Server) postTwoFactor(w http.ResponseWriter, r *http.Request) { return } - data := &twoFactorRenderContext{ + data := &loginRenderContext{ CsrfRenderContext: CsrfRenderContext{ Token: s.XSRF.Token(email), }, @@ -95,9 +68,9 @@ func (s *Server) postTwoFactor(w http.ResponseWriter, r *http.Request) { formCode := r.FormValue(common.ParamVerificationCode) if enteredCode, err := strconv.Atoi(formCode); (err != nil) || (enteredCode != sentCode) { - data.Error = "Code is not valid." + data.CodeError = "Code is not valid." slog.WarnContext(ctx, "Code verification failed", "actual", formCode, "expected", sentCode, common.ErrAttr(err)) - s.render(w, r, "twofactor/form.html", data) + s.render(w, r, "login/twofactor-form.html", data) return } @@ -145,7 +118,7 @@ func (s *Server) postTwoFactor(w http.ResponseWriter, r *http.Request) { func (s *Server) resend2fa(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - sess := s.Sessions.SessionStart(w, r) + sess, _ := s.Sessions.SessionStart(w, r) if step, ok := sess.Get(ctx, session.KeyLoginStep).(int); !ok || ((step != loginStepSignInVerify) && (step != loginStepSignUpVerify)) { slog.WarnContext(ctx, "User session is not valid", "step", step) common.Redirect(s.RelURL(common.LoginEndpoint), http.StatusUnauthorized, w, r) @@ -164,10 +137,10 @@ func (s *Server) resend2fa(w http.ResponseWriter, r *http.Request) { if err := s.Mailer.SendTwoFactor(ctx, email, code, r.UserAgent(), location); err != nil { slog.ErrorContext(ctx, "Failed to send email message", common.ErrAttr(err)) - s.render(w, r, "twofactor/resend-error.html", renderContextNothing) + s.render(w, r, "login/resend-error.html", renderContextNothing) return } _ = sess.Set(session.KeyTwoFactorCode, code) - s.render(w, r, "twofactor/resend.html", renderContextNothing) + s.render(w, r, "login/resend.html", renderContextNothing) } diff --git a/pkg/portal/utils.go b/pkg/portal/utils.go index 24938fd8..313b3591 100644 --- a/pkg/portal/utils.go +++ b/pkg/portal/utils.go @@ -125,7 +125,12 @@ func (s *Server) Session(w http.ResponseWriter, r *http.Request) *session.Sessio sess, ok := ctx.Value(common.SessionContextKey).(*session.Session) if !ok || (sess == nil) { slog.ErrorContext(ctx, "Failed to get session from context") - sess = s.Sessions.SessionStart(w, r) + var found bool + sess, found = s.Sessions.SessionGet(r) + if !found || (sess == nil) { + slog.ErrorContext(ctx, "Failed to get started session") + sess, _ = s.Sessions.SessionStart(w, r) + } } return sess diff --git a/pkg/session/common.go b/pkg/session/common.go index 3a5c2b28..8f9b4000 100644 --- a/pkg/session/common.go +++ b/pkg/session/common.go @@ -104,6 +104,26 @@ func (sd *SessionData) UnmarshalBinary(data []byte) error { return nil } +func (sd *SessionData) Merge(from *SessionData) { + // Acquire locks in consistent order to prevent deadlock + first, second := sd, from + if sd.sid > from.sid { + first, second = from, sd + } + + first.lock.Lock() + defer first.lock.Unlock() + + second.lock.Lock() + defer second.lock.Unlock() + + for key, value := range from.values { + if _, ok := sd.values[key]; !ok { + sd.values[key] = value + } + } +} + func (sd *SessionData) ID() string { return sd.sid } @@ -148,6 +168,10 @@ func NewSession(data *SessionData, store Store) *Session { } } +func (s *Session) Merge(from *Session) { + s.data.Merge(from.data) +} + func (s *Session) Data() *SessionData { return s.data } @@ -180,7 +204,7 @@ func (s *Session) Delete(key SessionKey) error { type Store interface { Start(ctx context.Context, interval time.Duration) Init(ctx context.Context, session *Session) error - Read(ctx context.Context, sid string) (*Session, error) + Read(ctx context.Context, sid string, skipCache bool) (*Session, error) Update(session *Session) error Destroy(ctx context.Context, sid string) error } diff --git a/pkg/session/manager.go b/pkg/session/manager.go index 93e7b883..25f523c5 100644 --- a/pkg/session/manager.go +++ b/pkg/session/manager.go @@ -28,7 +28,32 @@ func (m *Manager) Init(svc string, path string, interval time.Duration) { m.Store.Start(context.WithValue(context.Background(), common.ServiceContextKey, svc), interval) } -func (m *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session *Session) { +func (m *Manager) SessionGet(r *http.Request) (*Session, bool) { + cookie, err := r.Cookie(m.CookieName) + if err != nil || cookie.Value == "" { + return nil, false + } + + sid, _ := url.QueryUnescape(cookie.Value) + sslog := slog.With(common.SessionIDAttr(sid)) + + ctx := r.Context() + sslog.Log(ctx, common.LevelTrace, "Session cookie found in the request for start", "path", r.URL.Path, "method", r.Method) + session, err := m.Store.Read(ctx, sid, false /*skip cache*/) + if err != nil { + level := slog.LevelWarn + if err != ErrSessionMissing { + level = slog.LevelError + } + sslog.Log(ctx, level, "Failed to read session from store", common.ErrAttr(err)) + + return nil, false + } + + return session, true +} + +func (m *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session *Session, started bool) { cookie, err := r.Cookie(m.CookieName) ctx := r.Context() if err != nil || cookie.Value == "" { @@ -36,6 +61,7 @@ func (m *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session sid := m.sessionID() sslog := slog.With(common.SessionIDAttr(sid)) session = NewSession(NewSessionData(sid), m.Store) + started = true sslog.DebugContext(ctx, "Registering new session", "path", r.URL.Path, "method", r.Method) if err = m.Store.Init(ctx, session); err != nil { sslog.ErrorContext(ctx, "Failed to register session", common.SessionIDAttr(sid), common.ErrAttr(err)) @@ -54,21 +80,28 @@ func (m *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session sid, _ := url.QueryUnescape(cookie.Value) sslog := slog.With(common.SessionIDAttr(sid)) sslog.Log(ctx, common.LevelTrace, "Session cookie found in the request for start", "path", r.URL.Path, "method", r.Method) - session, err = m.Store.Read(ctx, sid) - if err == ErrSessionMissing { - sslog.WarnContext(ctx, "Session from cookie is missing") + session, err = m.Store.Read(ctx, sid, false /*skip cache*/) + if err != nil { + level := slog.LevelWarn + if err != ErrSessionMissing { + level = slog.LevelError + } + sslog.Log(ctx, level, "Failed to read session from store", common.ErrAttr(err)) session = NewSession(NewSessionData(sid), m.Store) + started = true sslog.DebugContext(ctx, "Registering new session", "path", r.URL.Path, "method", r.Method) if err = m.Store.Init(ctx, session); err != nil { sslog.ErrorContext(ctx, "Failed to register session with existing cookie", common.ErrAttr(err)) } - } else if err != nil { - sslog.ErrorContext(ctx, "Failed to read session from store", common.ErrAttr(err)) } } return } +func (m *Manager) RetrieveSession(ctx context.Context, sid string) (*Session, error) { + return m.Store.Read(ctx, sid, true /*skip cache*/) +} + func (m *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie(m.CookieName) if err != nil || cookie.Value == "" { diff --git a/web/layouts/login/dashes.html b/web/layouts/login/dashes.html new file mode 100644 index 00000000..9c7999ee --- /dev/null +++ b/web/layouts/login/dashes.html @@ -0,0 +1,26 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/layouts/login/login-contents.html b/web/layouts/login/login-contents.html new file mode 100644 index 00000000..4a724cf7 --- /dev/null +++ b/web/layouts/login/login-contents.html @@ -0,0 +1,10 @@ +
+

Sign in

+ {{ if .Params.CanRegister }}

Don’t have an account? Join now

{{ end }} +
+ +
+ {{template "login-form.html" .}} +
+ +{{ template "dashes.html" . }} diff --git a/web/layouts/login/form.html b/web/layouts/login/login-form.html similarity index 100% rename from web/layouts/login/form.html rename to web/layouts/login/login-form.html diff --git a/web/layouts/login/login.html b/web/layouts/login/login.html index 72fb9d90..d7e07ac1 100644 --- a/web/layouts/login/login.html +++ b/web/layouts/login/login.html @@ -13,42 +13,12 @@
-
-
-

Sign in

- {{ if .Params.CanRegister }}

Don’t have an account? Join now

{{ end }} -
- -
- {{template "form.html" .}} -
- - - - - - - - - - - - - - - - - - - - - - - - - - - +
+ {{ if .Params.IsRegister }} + {{ template "register-contents.html" . }} + {{ else }} + {{ template "login-contents.html" . }} + {{ end }}
diff --git a/web/layouts/login/register-contents.html b/web/layouts/login/register-contents.html new file mode 100644 index 00000000..a861192b --- /dev/null +++ b/web/layouts/login/register-contents.html @@ -0,0 +1,11 @@ +
+

Sign up

+ +

Already joined? Login now

+
+ +
+ {{template "register-form.html" .}} +
+ +{{ template "dashes.html" . }} diff --git a/web/layouts/register/form.html b/web/layouts/login/register-form.html similarity index 97% rename from web/layouts/register/form.html rename to web/layouts/login/register-form.html index 84130c9d..727dcbbd 100644 --- a/web/layouts/register/form.html +++ b/web/layouts/login/register-form.html @@ -52,7 +52,7 @@
-
diff --git a/web/layouts/register/register.html b/web/layouts/register/register.html deleted file mode 100644 index 8c0e8e7b..00000000 --- a/web/layouts/register/register.html +++ /dev/null @@ -1,59 +0,0 @@ -{{template "base.html" .}} - -{{define "title"}}Sign up{{end}} - -{{define "header"}}{{template "header-signed-out" .}}{{end}} -{{define "footer"}}{{template "footer-signed-out" .}}{{end}} - -{{define "body_class"}}pc-vertical-stretch{{end}} - -{{define "main"}} -
-
-
-
-
-
-
-

Sign up

- -

Already joined? Login now

-
- -
- {{template "form.html" .}} -
- - - - - - - - - - - - - - - - - - - - - - - - - - - -
-
-
-
-
-
-{{end}} diff --git a/web/layouts/register/scripts.html b/web/layouts/register/scripts.html deleted file mode 100644 index 5c6e985f..00000000 --- a/web/layouts/register/scripts.html +++ /dev/null @@ -1,12 +0,0 @@ -{{define "scripts"}} -{{template "default-scripts.html" .}} - - -{{end}} diff --git a/web/layouts/twofactor/scripts.html b/web/layouts/twofactor/scripts.html deleted file mode 100644 index b5a942d0..00000000 --- a/web/layouts/twofactor/scripts.html +++ /dev/null @@ -1,21 +0,0 @@ -{{define "scripts"}} -{{template "default-scripts.html" .}} - -{{end}} diff --git a/web/layouts/twofactor/twofactor.html b/web/layouts/twofactor/twofactor.html deleted file mode 100644 index c190ed89..00000000 --- a/web/layouts/twofactor/twofactor.html +++ /dev/null @@ -1,60 +0,0 @@ -{{template "base.html" .}} - -{{define "title"}}Verify{{end}} - -{{define "header"}}{{template "header-signed-out" .}}{{end}} -{{define "footer"}}{{template "footer-signed-out" .}}{{end}} - -{{define "body_class"}}pc-vertical-stretch{{end}} - -{{define "main"}} -
-
-
-
-
-
-
-

Verify your account

- -

Back to Login

-
- -
- {{template "form.html" .}} -
- - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-
-
-
-
-
-{{end}}