mirror of
https://github.com/PrivateCaptcha/PrivateCaptcha.git
synced 2026-02-11 08:18:52 -06:00
212 lines
3.8 KiB
Go
212 lines
3.8 KiB
Go
package session
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/gob"
|
|
"errors"
|
|
"log/slog"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/common"
|
|
)
|
|
|
|
// ErrSessionMissing is the sentinel error used to report that a session is
// missing. Compare against it with errors.Is.
var ErrSessionMissing = errors.New("session is missing")
|
|
|
|
// SessionKey identifies a single value stored in a session.
type SessionKey int

// Known session keys. The numeric value of each key comes from iota, so it
// is fixed by the declaration order. SessionData gob-encodes its values map
// keyed by SessionKey, so reordering or removing entries would change the
// meaning of already-encoded data; only append new keys directly above
// SESSION_KEYS_COUNT.
const (
	KeyLoginStep SessionKey = iota
	KeyUserID
	KeyUserEmail
	KeyTwoFactorCode
	KeyUserName
	KeyPersistent
	KeyNotificationID
	KeyReturnURL
	KeyTwoFactorCodeTimestamp
	// Add new fields _above_
	SESSION_KEYS_COUNT
)

// String returns a human-readable name for the key (the constant name
// without its "Key" prefix). Values that do not correspond to a named key,
// including SESSION_KEYS_COUNT, yield the generic "SessionKey".
func (key SessionKey) String() string {
	switch key {
	case KeyLoginStep:
		return "LoginStep"
	case KeyUserID:
		return "UserID"
	case KeyUserEmail:
		return "UserEmail"
	case KeyTwoFactorCode:
		return "TwoFactorCode"
	case KeyUserName:
		return "UserName"
	case KeyPersistent:
		return "Persistent"
	case KeyNotificationID:
		return "NotificationID"
	case KeyReturnURL:
		return "ReturnURL"
	case KeyTwoFactorCodeTimestamp:
		return "TwoFactorCodeTimestamp"
	default:
		return "SessionKey"
	}
}
|
|
|
|
// SessionValue is the type of a value stored in a session. It is an alias
// for the empty interface, so any Go value can be stored.
type SessionValue = any
|
|
|
|
type SessionData struct {
|
|
sid string
|
|
values map[SessionKey]SessionValue
|
|
lock sync.Mutex
|
|
}
|
|
|
|
func NewSessionData(sid string) *SessionData {
|
|
return &SessionData{
|
|
sid: sid,
|
|
values: make(map[SessionKey]SessionValue),
|
|
}
|
|
}
|
|
|
|
func (sd *SessionData) Size() int {
|
|
sd.lock.Lock()
|
|
defer sd.lock.Unlock()
|
|
return len(sd.values)
|
|
}
|
|
|
|
func (sd *SessionData) MarshalBinary() ([]byte, error) {
|
|
var buf bytes.Buffer
|
|
encoder := gob.NewEncoder(&buf)
|
|
|
|
sd.lock.Lock()
|
|
defer sd.lock.Unlock()
|
|
|
|
if err := encoder.Encode(sd.values); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return buf.Bytes(), nil
|
|
}
|
|
|
|
func (sd *SessionData) UnmarshalBinary(data []byte) error {
|
|
values := make(map[SessionKey]SessionValue, 0)
|
|
|
|
buf := bytes.NewBuffer(data)
|
|
decoder := gob.NewDecoder(buf)
|
|
|
|
if err := decoder.Decode(&values); err != nil {
|
|
return err
|
|
}
|
|
|
|
sd.lock.Lock()
|
|
sd.values = values
|
|
sd.lock.Unlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (sd *SessionData) Merge(from *SessionData) {
|
|
// Acquire locks in consistent order to prevent deadlock
|
|
first, second := sd, from
|
|
if sd.sid > from.sid {
|
|
first, second = from, sd
|
|
}
|
|
|
|
first.lock.Lock()
|
|
defer first.lock.Unlock()
|
|
|
|
second.lock.Lock()
|
|
defer second.lock.Unlock()
|
|
|
|
for key, value := range from.values {
|
|
if _, ok := sd.values[key]; !ok {
|
|
sd.values[key] = value
|
|
}
|
|
}
|
|
}
|
|
|
|
func (sd *SessionData) ID() string {
|
|
return sd.sid
|
|
}
|
|
|
|
func (sd *SessionData) Has(key SessionKey) bool {
|
|
sd.lock.Lock()
|
|
defer sd.lock.Unlock()
|
|
|
|
_, ok := sd.values[key]
|
|
return ok
|
|
}
|
|
|
|
func (sd *SessionData) set(key SessionKey, value SessionValue) {
|
|
sd.lock.Lock()
|
|
sd.values[key] = value
|
|
sd.lock.Unlock()
|
|
}
|
|
|
|
func (sd *SessionData) get(key SessionKey) (any, bool) {
|
|
sd.lock.Lock()
|
|
defer sd.lock.Unlock()
|
|
|
|
v, ok := sd.values[key]
|
|
return v, ok
|
|
}
|
|
|
|
func (sd *SessionData) delete(key SessionKey) {
|
|
sd.lock.Lock()
|
|
delete(sd.values, key)
|
|
sd.lock.Unlock()
|
|
}
|
|
|
|
type Session struct {
|
|
data *SessionData
|
|
store Store
|
|
}
|
|
|
|
func NewSession(data *SessionData, store Store) *Session {
|
|
return &Session{
|
|
data: data,
|
|
store: store,
|
|
}
|
|
}
|
|
|
|
func (s *Session) Merge(from *Session) {
|
|
s.data.Merge(from.data)
|
|
}
|
|
|
|
func (s *Session) Data() *SessionData {
|
|
return s.data
|
|
}
|
|
|
|
func (s *Session) Set(key SessionKey, value SessionValue) error {
|
|
s.data.set(key, value)
|
|
|
|
return s.store.Update(s)
|
|
}
|
|
|
|
func (s *Session) ID() string {
|
|
return s.data.ID()
|
|
}
|
|
|
|
func (s *Session) Get(ctx context.Context, key SessionKey) SessionValue {
|
|
v, ok := s.data.get(key)
|
|
if !ok {
|
|
slog.Log(ctx, common.LevelTrace, "Access to missing key in session", common.SessionIDAttr(s.data.ID()), "key", key.String())
|
|
}
|
|
|
|
return v
|
|
}
|
|
|
|
func (s *Session) Delete(key SessionKey) error {
|
|
s.data.delete(key)
|
|
|
|
return s.store.Update(s)
|
|
}
|
|
|
|
type Store interface {
|
|
Start(ctx context.Context, interval time.Duration)
|
|
Init(ctx context.Context, session *Session) error
|
|
Read(ctx context.Context, sid string, skipCache bool) (*Session, error)
|
|
Update(session *Session) error
|
|
Destroy(ctx context.Context, sid string) error
|
|
}
|