From f120f7ecbd3bf7b9ee00b0fa600cf1868bab1c2b Mon Sep 17 00:00:00 2001 From: Abhishek Shroff Date: Sun, 20 Jul 2025 15:11:07 +0530 Subject: [PATCH] [server][auth] Use struct instead of interface --- .../api/authenticator/authenticator.go | 8 +-- server/internal/api/v1/auth/password.go | 4 +- server/internal/api/v1/auth/routes.go | 4 +- server/internal/api/v1/auth/token.go | 2 +- server/internal/api/v1/user/bootstrap.go | 2 +- server/internal/api/v1/user/keys.go | 4 +- server/internal/api/webdav/handler.go | 2 +- server/internal/auth/api_key.go | 45 +++++++-------- server/internal/auth/authorization.go | 55 ++++++++++++++----- server/internal/auth/password.go | 15 ++--- server/internal/auth/password_reset.go | 12 +--- server/internal/auth/token.go | 15 ++--- server/internal/command/user/keys/cmd.go | 4 +- 13 files changed, 89 insertions(+), 83 deletions(-) diff --git a/server/internal/api/authenticator/authenticator.go b/server/internal/api/authenticator/authenticator.go index c6234e55..3abdb402 100644 --- a/server/internal/api/authenticator/authenticator.go +++ b/server/internal/api/authenticator/authenticator.go @@ -18,12 +18,12 @@ var errCredentialsInvalid = core.NewError(http.StatusUnauthorized, "credentials_ const keyAuth = "auth" -func GetAuth(c *gin.Context) auth.Auth { +func GetAuth(c *gin.Context) *auth.Auth { val, ok := c.Get(keyAuth) if !ok { return nil } - return val.(auth.Auth) + return val.(*auth.Auth) } func GetFileSystem(c *gin.Context) *core.FileSystem { return GetAuth(c).GetFileSystem(db.Get(c), pgtype.UUID{}) @@ -37,7 +37,7 @@ func Require(c *gin.Context) { } } -func getAuth(c *gin.Context) (auth.Auth, error) { +func getAuth(c *gin.Context) (*auth.Auth, error) { db := db.Get(c.Request.Context()) if header := c.Request.Header.Get("Authorization"); header == "" { if cookie, err := c.Request.Cookie("api_token"); err == nil { @@ -56,7 +56,7 @@ func getAuth(c *gin.Context) (auth.Auth, error) { return nil, errCredentialsInvalid } else if authHeader, ok := checkAuthHeader(header, "basic"); ok { if keyIDStr, keyStr, ok := decodeBasicAuth(authHeader); ok { - var a auth.Auth + var a *auth.Auth var err error if keyIDStr == "" { diff --git a/server/internal/api/v1/auth/password.go b/server/internal/api/v1/auth/password.go index 93cabeea..9615a840 100644 --- a/server/internal/api/v1/auth/password.go +++ b/server/internal/api/v1/auth/password.go @@ -27,7 +27,7 @@ func handlePasswordLogin(c *gin.Context) { panic(core.NewError(http.StatusBadRequest, "missing_params", "Email or password not specified")) } - loginAndSendResponse(c, func(db db.TxHandler) (auth.Auth, string, error) { + loginAndSendResponse(c, func(db db.TxHandler) (*auth.Auth, string, error) { return auth.PerformPasswordLogin(db, params.Email, params.Password) }) } @@ -92,7 +92,7 @@ func handlePasswordReset(c *gin.Context) { panic(core.NewError(http.StatusBadRequest, "missing_params", "Missing Parameters")) } - loginAndSendResponse(c, func(db db.TxHandler) (auth.Auth, string, error) { + loginAndSendResponse(c, func(db db.TxHandler) (*auth.Auth, string, error) { return auth.ResetUserPassword(db, params.Email, params.Token, params.Password) }) } diff --git a/server/internal/api/v1/auth/routes.go b/server/internal/api/v1/auth/routes.go index a479021c..d815da79 100644 --- a/server/internal/api/v1/auth/routes.go +++ b/server/internal/api/v1/auth/routes.go @@ -79,12 +79,12 @@ func getInstanceURL(req *http.Request, path string) string { return uri.String() } -func loginAndSendResponse(c *gin.Context, loginFn func(db.TxHandler) (auth.Auth, string, error)) { +func loginAndSendResponse(c *gin.Context, loginFn func(db.TxHandler) (*auth.Auth, string, error)) { var err error var key string var response responses.Bootstrap if err := db.RunInTx(c.Request.Context(), func(db db.TxHandler) error { - var a auth.Auth + var a *auth.Auth if a, key, err = loginFn(db); err != nil { if errors.Is(err, auth.ErrCredentialsInvalid) { return core.NewError(http.StatusUnauthorized, "credentials_invalid", "invalid credentials") diff --git a/server/internal/api/v1/auth/token.go b/server/internal/api/v1/auth/token.go index 4f8ec575..e7a5716b 100644 --- a/server/internal/api/v1/auth/token.go +++ b/server/internal/api/v1/auth/token.go @@ -20,7 +20,7 @@ func handleTokenLogin(c *gin.Context) { panic(core.NewError(http.StatusBadRequest, "missing_params", "login token not specified")) } - loginAndSendResponse(c, func(db db.TxHandler) (auth.Auth, string, error) { + loginAndSendResponse(c, func(db db.TxHandler) (*auth.Auth, string, error) { return auth.PerformTokenLogin(db, params.Token) }) } diff --git a/server/internal/api/v1/user/bootstrap.go b/server/internal/api/v1/user/bootstrap.go index 910c124a..6ba44ab5 100644 --- a/server/internal/api/v1/user/bootstrap.go +++ b/server/internal/api/v1/user/bootstrap.go @@ -32,7 +32,7 @@ func handleBootstrapRoute(c *gin.Context) { } } -func Bootstrap(ctx context.Context, auth auth.Auth, since int64) (responses.Bootstrap, error) { +func Bootstrap(ctx context.Context, auth *auth.Auth, since int64) (responses.Bootstrap, error) { if !auth.HasScope("bookmarks:list") || !auth.HasScope("users:list") || !auth.HasScope("profile:read") { return responses.Bootstrap{}, core.ErrInsufficientScope } diff --git a/server/internal/api/v1/user/keys.go b/server/internal/api/v1/user/keys.go index c0f5cd6c..536c29af 100644 --- a/server/internal/api/v1/user/keys.go +++ b/server/internal/api/v1/user/keys.go @@ -82,7 +82,9 @@ func handleKeysGenerate(c *gin.Context) { expires.Time = params.Expires expires.Valid = true } - if id, key, token, err := auth.GenerateAPIKey(db.Get(c.Request.Context()), a.UserID(), expires, params.Description, params.Scopes); err != nil { + if a, err := auth.NewAuth(a, expires, params.Scopes); err != nil { + panic(err) + } else if id, key, token, err := auth.GenerateAPIKey(db.Get(c.Request.Context()), a, params.Description); err != nil { panic(err) } else { c.JSON(200, gin.H{ diff --git a/server/internal/api/webdav/handler.go b/server/internal/api/webdav/handler.go index 5be17006..8b97e192 100644 --- a/server/internal/api/webdav/handler.go +++ b/server/internal/api/webdav/handler.go @@ -44,7 +44,7 @@ func (h *handler) HandleRequest(c *gin.Context) { if keyIDStr, keyStr, ok := c.Request.BasicAuth(); ok { ctx := c.Request.Context() db := db.Get(ctx) - var a auth.Auth + var a *auth.Auth var err error if keyIDStr == "" { diff --git a/server/internal/auth/api_key.go b/server/internal/auth/api_key.go index 386e14fd..0ba3ee65 100644 --- a/server/internal/auth/api_key.go +++ b/server/internal/auth/api_key.go @@ -63,7 +63,7 @@ func scanAPIKey(row pgx.CollectableRow) (APIKey, error) { return apiKey, err } -func ReadAPIToken(db db.Handler, encodedKey string) (Auth, error) { +func ReadAPIToken(db db.Handler, encodedKey string) (*Auth, error) { if b, err := b64Encoder.DecodeString(encodedKey); err != nil { return nil, ErrCredentialsInvalid } else if len(b) < 16 { @@ -74,15 +74,7 @@ func ReadAPIToken(db db.Handler, encodedKey string) (Auth, error) { } } -func GenerateAPIKey(db db.Handler, userID int32, expires pgtype.Timestamptz, description string, scopes []string) (id, key, token string, err error) { - if id, key, err := generateAPIKey(db, userID, expires, description, scopes); err != nil { - return "", "", "", err - } else { - return id.String(), b32Encoder.EncodeToString(key), b64Encoder.EncodeToString(append(id[:], key...)), nil - } -} - -func ReadAPIKey(db db.Handler, keyIDStr, keyStr string) (Auth, error) { +func ReadAPIKey(db db.Handler, keyIDStr, keyStr string) (*Auth, error) { if keyID, err := uuid.Parse(keyIDStr); err != nil { return nil, err } else if key, err := b32Encoder.DecodeString(keyStr); err != nil { @@ -92,39 +84,40 @@ func ReadAPIKey(db db.Handler, keyIDStr, keyStr string) (Auth, error) { } } -func readAPIKey(db db.Handler, keyID uuid.UUID, key []byte) (Auth, error) { +func readAPIKey(db db.Handler, keyID uuid.UUID, key []byte) (auth *Auth, err error) { const q = `SELECT k.expires, u.id, u.permissions, u.home, k.scopes FROM api_keys k JOIN users u ON k.user_id = u.id WHERE k.id = $1 AND k.hash = $2` hash := sha256.Sum256(key) row := db.QueryRow(q, keyID, hash[:]) - var expires pgtype.Timestamp - var auth auth - err := row.Scan(&expires, &auth.userID, &auth.userPermissions, &auth.homeID, &auth.scopes) + auth = new(Auth) + err = row.Scan(&auth.expires, &auth.userID, &auth.userPermissions, &auth.homeID, &auth.scopes) if err != nil { if errors.Is(err, pgx.ErrNoRows) { err = ErrCredentialsInvalid } - return nil, err - } else if expires.Valid && time.Now().After(expires.Time) { - return nil, ErrCredentialsInvalid + } else if auth.expires.Valid && time.Now().After(auth.expires.Time) { + err = ErrCredentialsInvalid } - return auth, nil + return } -func generateAPIKey(db db.Handler, userID int32, expires pgtype.Timestamptz, description string, scopes []string) (id uuid.UUID, key []byte, err error) { +func GenerateAPIKey(db db.Handler, auth *Auth, description string) (string, string, string, error) { const q = `INSERT INTO api_keys(id, expires, user_id, hash, description, scopes) VALUES (@id, @expires, @user_id, @hash, @description, @scopes)` - id, _ = uuid.NewV7() - key = generateSecureKey(apiKeyLength) + id, _ := uuid.NewV7() + key := generateSecureKey(apiKeyLength) hash := sha256.Sum256(key) args := pgx.NamedArgs{ "id": id, - "expires": expires, - "user_id": userID, + "expires": auth.expires, + "user_id": auth.userID, "hash": hash[:], "description": description, - "scopes": scopes, + "scopes": auth.scopes, } - _, err = db.Exec(q, args) - return + + if _, err := db.Exec(q, args); err != nil { + return "", "", "", err + } + return id.String(), b32Encoder.EncodeToString(key), b64Encoder.EncodeToString(append(id[:], key...)), nil } diff --git a/server/internal/auth/authorization.go b/server/internal/auth/authorization.go index 3b2767a5..772ce2f3 100644 --- a/server/internal/auth/authorization.go +++ b/server/internal/auth/authorization.go @@ -1,6 +1,7 @@ package auth import ( + "errors" "strings" "codeberg.org/shroff/phylum/server/internal/core" @@ -8,29 +9,57 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -type Auth interface { - UserID() int32 - UserPermissions() core.UserPermissions - HasScope(scope string) bool - GetFileSystem(db db.Handler, rootOverride pgtype.UUID) *core.FileSystem -} - -type auth struct { +type Auth struct { userID int32 userPermissions core.UserPermissions - homeID pgtype.UUID // TODO: Make sure this is specified everywhere + homeID pgtype.UUID + expires pgtype.Timestamptz scopes []string } -func (a auth) UserID() int32 { +func NewSUAuth(user core.User) *Auth { + return &Auth{ + userID: user.ID, + userPermissions: user.Permissions, + homeID: user.Home, + expires: pgtype.Timestamptz{}, + scopes: []string{"*"}, + } +} + +func NewAuth(auth *Auth, expires pgtype.Timestamptz, scopes []string) (*Auth, error) { + if len(scopes) == 0 { + return nil, errors.New("must specify at least one scope") + } + + // Make sure the generated key doesn't expire after the key that generated it + if auth.expires.Valid && (!expires.Valid || expires.Time.After(auth.expires.Time)) { + expires = auth.expires + } + for _, s := range scopes { + if !auth.HasScope(s) { + return nil, errors.New("cannot grant scopes that are not present") + } + } + + return &Auth{ + userID: auth.userID, + userPermissions: auth.userPermissions, + homeID: auth.homeID, + expires: expires, + scopes: scopes, + }, nil +} + +func (a *Auth) UserID() int32 { return a.userID } -func (a auth) UserPermissions() core.UserPermissions { +func (a *Auth) UserPermissions() core.UserPermissions { return a.userPermissions } -func (a auth) HasScope(scope string) bool { +func (a *Auth) HasScope(scope string) bool { for _, s := range a.scopes { if s == "*" { return true @@ -39,7 +68,7 @@ func (a auth) HasScope(scope string) bool { return false } -func (a auth) GetFileSystem(db db.Handler, rootOverride pgtype.UUID) *core.FileSystem { +func (a *Auth) GetFileSystem(db db.Handler, rootOverride pgtype.UUID) *core.FileSystem { pathRoot := rootOverride if !pathRoot.Valid { pathRoot = a.homeID diff --git a/server/internal/auth/password.go b/server/internal/auth/password.go index b500188d..675317cd 100644 --- a/server/internal/auth/password.go +++ b/server/internal/auth/password.go @@ -6,7 +6,6 @@ import ( "codeberg.org/shroff/phylum/server/internal/core" "codeberg.org/shroff/phylum/server/internal/db" - "github.com/jackc/pgx/v5/pgtype" ) type PasswordBackend interface { @@ -15,7 +14,7 @@ type PasswordBackend interface { UpdateUserPassword(db db.Handler, email, password string) error } -func VerifyUserPassword(d db.Handler, email, password string) (Auth, error) { +func VerifyUserPassword(d db.Handler, email, password string) (*Auth, error) { email = strings.ToLower(email) if b, err := passwordBackend.VerifyUserPassword(d, email, password); err != nil { return nil, err @@ -29,20 +28,14 @@ func VerifyUserPassword(d db.Handler, email, password string) (Auth, error) { return err }) } - auth := auth{ - userID: user.ID, - homeID: user.Home, - userPermissions: user.Permissions, - scopes: []string{"*"}, - } - return auth, err + return NewSUAuth(user), err } -func PerformPasswordLogin(db db.TxHandler, email, password string) (auth Auth, apiToken string, err error) { +func PerformPasswordLogin(db db.TxHandler, email, password string) (auth *Auth, apiToken string, err error) { if auth, err = VerifyUserPassword(db, email, password); err != nil { return } else { - _, _, apiToken, err = GenerateAPIKey(db, auth.UserID(), pgtype.Timestamptz{}, "Login - Password", []string{"*"}) + _, _, apiToken, err = GenerateAPIKey(db, auth, "Login - Password") return } } diff --git a/server/internal/auth/password_reset.go b/server/internal/auth/password_reset.go index 0d2ed5f9..49535cfb 100644 --- a/server/internal/auth/password_reset.go +++ b/server/internal/auth/password_reset.go @@ -10,7 +10,6 @@ import ( "codeberg.org/shroff/phylum/server/internal/core" "codeberg.org/shroff/phylum/server/internal/db" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" ) func CreateResetToken(db db.TxHandler, email string) (core.User, string, error) { @@ -35,7 +34,7 @@ func CreateResetToken(db db.TxHandler, email string) (core.User, string, error) } -func ResetUserPassword(db db.TxHandler, email, resetToken, password string) (a Auth, apiToken string, err error) { +func ResetUserPassword(db db.TxHandler, email, resetToken, password string) (auth *Auth, apiToken string, err error) { if passwordBackend == nil || !passwordBackend.SupportsPasswordUpdate() { err = errors.New("password update not supported") return @@ -76,13 +75,8 @@ func ResetUserPassword(db db.TxHandler, email, resetToken, password string) (a A return } - a = auth{ - userID: user.ID, - homeID: user.Home, - userPermissions: user.Permissions, - scopes: []string{"*"}, - } - _, _, apiToken, err = GenerateAPIKey(db, user.ID, pgtype.Timestamptz{}, "Login - Password Reset", []string{"*"}) + auth = NewSUAuth(user) + _, _, apiToken, err = GenerateAPIKey(db, auth, "Login - Password Reset") return } diff --git a/server/internal/auth/token.go b/server/internal/auth/token.go index d4d886da..9dce0dc3 100644 --- a/server/internal/auth/token.go +++ b/server/internal/auth/token.go @@ -10,10 +10,9 @@ import ( "codeberg.org/shroff/phylum/server/internal/db" "github.com/google/uuid" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" ) -func PerformTokenLogin(db db.TxHandler, encodedToken string) (Auth, string, error) { +func PerformTokenLogin(db db.TxHandler, encodedToken string) (*Auth, string, error) { if b, err := b32Encoder.DecodeString(encodedToken); err != nil { return nil, "", ErrCredentialsInvalid } else if len(b) < 16 { @@ -53,7 +52,7 @@ func CreateLoginToken(db db.TxHandler, email string) (core.User, string, error) return user, b32Encoder.EncodeToString(append(tokenID[:], token...)), nil } -func performTokenLogin(db db.TxHandler, tokenID uuid.UUID, token []byte) (a Auth, apiToken string, err error) { +func performTokenLogin(db db.TxHandler, tokenID uuid.UUID, token []byte) (auth *Auth, apiToken string, err error) { const q = "DELETE FROM pending_logins WHERE token_id = @token_id AND token_hash = @token_hash AND user_id IS NOT NULL RETURNING user_id, expires" hash := sha256.Sum256([]byte(token)) args := pgx.NamedArgs{ @@ -75,15 +74,9 @@ func performTokenLogin(db db.TxHandler, tokenID uuid.UUID, token []byte) (a Auth return } else if user, err = core.UserByID(db, userID); err != nil { return - } else if _, _, apiToken, err = GenerateAPIKey(db, userID, pgtype.Timestamptz{}, "Login - Token", []string{"*"}); err != nil { - return } else { - a = auth{ - userID: user.ID, - homeID: user.Home, - userPermissions: user.Permissions, - scopes: []string{"*"}, - } + auth = NewSUAuth(user) + _, _, apiToken, err = GenerateAPIKey(db, auth, "Login - Token") return } } diff --git a/server/internal/command/user/keys/cmd.go b/server/internal/command/user/keys/cmd.go index fc1a3b2e..49ae1b99 100644 --- a/server/internal/command/user/keys/cmd.go +++ b/server/internal/command/user/keys/cmd.go @@ -79,7 +79,9 @@ func setupGenerateCommand() *cobra.Command { } if err := db.Get(context.Background()).RunInTx(func(db db.TxHandler) error { - if id, key, token, err := auth.GenerateAPIKey(db, u.ID, expires, description, scopes); err != nil { + if a, err := auth.NewAuth(auth.NewSUAuth(*u), expires, scopes); err != nil { + return err + } else if id, key, token, err := auth.GenerateAPIKey(db, a, description); err != nil { return err } else { fmt.Println("Key ID:", id)