[server][auth] Use struct instead of interface

This commit is contained in:
Abhishek Shroff
2025-07-20 15:11:07 +05:30
parent 75c0a642d9
commit f120f7ecbd
13 changed files with 89 additions and 83 deletions

View File

@@ -18,12 +18,12 @@ var errCredentialsInvalid = core.NewError(http.StatusUnauthorized, "credentials_
const keyAuth = "auth"
func GetAuth(c *gin.Context) auth.Auth {
func GetAuth(c *gin.Context) *auth.Auth {
val, ok := c.Get(keyAuth)
if !ok {
return nil
}
return val.(auth.Auth)
return val.(*auth.Auth)
}
func GetFileSystem(c *gin.Context) *core.FileSystem {
return GetAuth(c).GetFileSystem(db.Get(c), pgtype.UUID{})
@@ -37,7 +37,7 @@ func Require(c *gin.Context) {
}
}
func getAuth(c *gin.Context) (auth.Auth, error) {
func getAuth(c *gin.Context) (*auth.Auth, error) {
db := db.Get(c.Request.Context())
if header := c.Request.Header.Get("Authorization"); header == "" {
if cookie, err := c.Request.Cookie("api_token"); err == nil {
@@ -56,7 +56,7 @@ func getAuth(c *gin.Context) (auth.Auth, error) {
return nil, errCredentialsInvalid
} else if authHeader, ok := checkAuthHeader(header, "basic"); ok {
if keyIDStr, keyStr, ok := decodeBasicAuth(authHeader); ok {
var a auth.Auth
var a *auth.Auth
var err error
if keyIDStr == "" {

View File

@@ -27,7 +27,7 @@ func handlePasswordLogin(c *gin.Context) {
panic(core.NewError(http.StatusBadRequest, "missing_params", "Email or password not specified"))
}
loginAndSendResponse(c, func(db db.TxHandler) (auth.Auth, string, error) {
loginAndSendResponse(c, func(db db.TxHandler) (*auth.Auth, string, error) {
return auth.PerformPasswordLogin(db, params.Email, params.Password)
})
}
@@ -92,7 +92,7 @@ func handlePasswordReset(c *gin.Context) {
panic(core.NewError(http.StatusBadRequest, "missing_params", "Missing Parameters"))
}
loginAndSendResponse(c, func(db db.TxHandler) (auth.Auth, string, error) {
loginAndSendResponse(c, func(db db.TxHandler) (*auth.Auth, string, error) {
return auth.ResetUserPassword(db, params.Email, params.Token, params.Password)
})
}

View File

@@ -79,12 +79,12 @@ func getInstanceURL(req *http.Request, path string) string {
return uri.String()
}
func loginAndSendResponse(c *gin.Context, loginFn func(db.TxHandler) (auth.Auth, string, error)) {
func loginAndSendResponse(c *gin.Context, loginFn func(db.TxHandler) (*auth.Auth, string, error)) {
var err error
var key string
var response responses.Bootstrap
if err := db.RunInTx(c.Request.Context(), func(db db.TxHandler) error {
var a auth.Auth
var a *auth.Auth
if a, key, err = loginFn(db); err != nil {
if errors.Is(err, auth.ErrCredentialsInvalid) {
return core.NewError(http.StatusUnauthorized, "credentials_invalid", "invalid credentials")

View File

@@ -20,7 +20,7 @@ func handleTokenLogin(c *gin.Context) {
panic(core.NewError(http.StatusBadRequest, "missing_params", "login token not specified"))
}
loginAndSendResponse(c, func(db db.TxHandler) (auth.Auth, string, error) {
loginAndSendResponse(c, func(db db.TxHandler) (*auth.Auth, string, error) {
return auth.PerformTokenLogin(db, params.Token)
})
}

View File

@@ -32,7 +32,7 @@ func handleBootstrapRoute(c *gin.Context) {
}
}
func Bootstrap(ctx context.Context, auth auth.Auth, since int64) (responses.Bootstrap, error) {
func Bootstrap(ctx context.Context, auth *auth.Auth, since int64) (responses.Bootstrap, error) {
if !auth.HasScope("bookmarks:list") || !auth.HasScope("users:list") || !auth.HasScope("profile:read") {
return responses.Bootstrap{}, core.ErrInsufficientScope
}

View File

@@ -82,7 +82,9 @@ func handleKeysGenerate(c *gin.Context) {
expires.Time = params.Expires
expires.Valid = true
}
if id, key, token, err := auth.GenerateAPIKey(db.Get(c.Request.Context()), a.UserID(), expires, params.Description, params.Scopes); err != nil {
if a, err := auth.NewAuth(a, expires, params.Scopes); err != nil {
panic(err)
} else if id, key, token, err := auth.GenerateAPIKey(db.Get(c.Request.Context()), a, params.Description); err != nil {
panic(err)
} else {
c.JSON(200, gin.H{

View File

@@ -44,7 +44,7 @@ func (h *handler) HandleRequest(c *gin.Context) {
if keyIDStr, keyStr, ok := c.Request.BasicAuth(); ok {
ctx := c.Request.Context()
db := db.Get(ctx)
var a auth.Auth
var a *auth.Auth
var err error
if keyIDStr == "" {

View File

@@ -63,7 +63,7 @@ func scanAPIKey(row pgx.CollectableRow) (APIKey, error) {
return apiKey, err
}
func ReadAPIToken(db db.Handler, encodedKey string) (Auth, error) {
func ReadAPIToken(db db.Handler, encodedKey string) (*Auth, error) {
if b, err := b64Encoder.DecodeString(encodedKey); err != nil {
return nil, ErrCredentialsInvalid
} else if len(b) < 16 {
@@ -74,15 +74,7 @@ func ReadAPIToken(db db.Handler, encodedKey string) (Auth, error) {
}
}
func GenerateAPIKey(db db.Handler, userID int32, expires pgtype.Timestamptz, description string, scopes []string) (id, key, token string, err error) {
if id, key, err := generateAPIKey(db, userID, expires, description, scopes); err != nil {
return "", "", "", err
} else {
return id.String(), b32Encoder.EncodeToString(key), b64Encoder.EncodeToString(append(id[:], key...)), nil
}
}
func ReadAPIKey(db db.Handler, keyIDStr, keyStr string) (Auth, error) {
func ReadAPIKey(db db.Handler, keyIDStr, keyStr string) (*Auth, error) {
if keyID, err := uuid.Parse(keyIDStr); err != nil {
return nil, err
} else if key, err := b32Encoder.DecodeString(keyStr); err != nil {
@@ -92,39 +84,40 @@ func ReadAPIKey(db db.Handler, keyIDStr, keyStr string) (Auth, error) {
}
}
func readAPIKey(db db.Handler, keyID uuid.UUID, key []byte) (Auth, error) {
func readAPIKey(db db.Handler, keyID uuid.UUID, key []byte) (auth *Auth, err error) {
const q = `SELECT k.expires, u.id, u.permissions, u.home, k.scopes FROM api_keys k JOIN users u ON k.user_id = u.id WHERE k.id = $1 AND k.hash = $2`
hash := sha256.Sum256(key)
row := db.QueryRow(q, keyID, hash[:])
var expires pgtype.Timestamp
var auth auth
err := row.Scan(&expires, &auth.userID, &auth.userPermissions, &auth.homeID, &auth.scopes)
auth = new(Auth)
err = row.Scan(&auth.expires, &auth.userID, &auth.userPermissions, &auth.homeID, &auth.scopes)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
err = ErrCredentialsInvalid
}
return nil, err
} else if expires.Valid && time.Now().After(expires.Time) {
return nil, ErrCredentialsInvalid
} else if auth.expires.Valid && time.Now().After(auth.expires.Time) {
err = ErrCredentialsInvalid
}
return auth, nil
return
}
func generateAPIKey(db db.Handler, userID int32, expires pgtype.Timestamptz, description string, scopes []string) (id uuid.UUID, key []byte, err error) {
func GenerateAPIKey(db db.Handler, auth *Auth, description string) (string, string, string, error) {
const q = `INSERT INTO api_keys(id, expires, user_id, hash, description, scopes) VALUES (@id, @expires, @user_id, @hash, @description, @scopes)`
id, _ = uuid.NewV7()
key = generateSecureKey(apiKeyLength)
id, _ := uuid.NewV7()
key := generateSecureKey(apiKeyLength)
hash := sha256.Sum256(key)
args := pgx.NamedArgs{
"id": id,
"expires": expires,
"user_id": userID,
"expires": auth.expires,
"user_id": auth.userID,
"hash": hash[:],
"description": description,
"scopes": scopes,
"scopes": auth.scopes,
}
_, err = db.Exec(q, args)
return
if _, err := db.Exec(q, args); err != nil {
return "", "", "", err
}
return id.String(), b32Encoder.EncodeToString(key), b64Encoder.EncodeToString(append(id[:], key...)), nil
}

View File

@@ -1,6 +1,7 @@
package auth
import (
"errors"
"strings"
"codeberg.org/shroff/phylum/server/internal/core"
@@ -8,29 +9,57 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)
type Auth interface {
UserID() int32
UserPermissions() core.UserPermissions
HasScope(scope string) bool
GetFileSystem(db db.Handler, rootOverride pgtype.UUID) *core.FileSystem
}
type auth struct {
type Auth struct {
userID int32
userPermissions core.UserPermissions
homeID pgtype.UUID // TODO: Make sure this is specified everywhere
homeID pgtype.UUID
expires pgtype.Timestamptz
scopes []string
}
func (a auth) UserID() int32 {
func NewSUAuth(user core.User) *Auth {
return &Auth{
userID: user.ID,
userPermissions: user.Permissions,
homeID: user.Home,
expires: pgtype.Timestamptz{},
scopes: []string{"*"},
}
}
func NewAuth(auth *Auth, expires pgtype.Timestamptz, scopes []string) (*Auth, error) {
if len(scopes) == 0 {
return nil, errors.New("must specify at least one scope")
}
// Make sure the generated key doesn't expire after the key that generated it
if auth.expires.Valid && (!expires.Valid || expires.Time.After(auth.expires.Time)) {
expires = auth.expires
}
for _, s := range scopes {
if !auth.HasScope(s) {
return nil, errors.New("cannot grant scopes that are not present")
}
}
return &Auth{
userID: auth.userID,
userPermissions: auth.userPermissions,
homeID: auth.homeID,
expires: expires,
scopes: scopes,
}, nil
}
func (a *Auth) UserID() int32 {
return a.userID
}
func (a auth) UserPermissions() core.UserPermissions {
func (a *Auth) UserPermissions() core.UserPermissions {
return a.userPermissions
}
func (a auth) HasScope(scope string) bool {
func (a *Auth) HasScope(scope string) bool {
for _, s := range a.scopes {
if s == "*" {
return true
@@ -39,7 +68,7 @@ func (a auth) HasScope(scope string) bool {
return false
}
func (a auth) GetFileSystem(db db.Handler, rootOverride pgtype.UUID) *core.FileSystem {
func (a *Auth) GetFileSystem(db db.Handler, rootOverride pgtype.UUID) *core.FileSystem {
pathRoot := rootOverride
if !pathRoot.Valid {
pathRoot = a.homeID

View File

@@ -6,7 +6,6 @@ import (
"codeberg.org/shroff/phylum/server/internal/core"
"codeberg.org/shroff/phylum/server/internal/db"
"github.com/jackc/pgx/v5/pgtype"
)
type PasswordBackend interface {
@@ -15,7 +14,7 @@ type PasswordBackend interface {
UpdateUserPassword(db db.Handler, email, password string) error
}
func VerifyUserPassword(d db.Handler, email, password string) (Auth, error) {
func VerifyUserPassword(d db.Handler, email, password string) (*Auth, error) {
email = strings.ToLower(email)
if b, err := passwordBackend.VerifyUserPassword(d, email, password); err != nil {
return nil, err
@@ -29,20 +28,14 @@ func VerifyUserPassword(d db.Handler, email, password string) (Auth, error) {
return err
})
}
auth := auth{
userID: user.ID,
homeID: user.Home,
userPermissions: user.Permissions,
scopes: []string{"*"},
}
return auth, err
return NewSUAuth(user), err
}
func PerformPasswordLogin(db db.TxHandler, email, password string) (auth Auth, apiToken string, err error) {
func PerformPasswordLogin(db db.TxHandler, email, password string) (auth *Auth, apiToken string, err error) {
if auth, err = VerifyUserPassword(db, email, password); err != nil {
return
} else {
_, _, apiToken, err = GenerateAPIKey(db, auth.UserID(), pgtype.Timestamptz{}, "Login - Password", []string{"*"})
_, _, apiToken, err = GenerateAPIKey(db, auth, "Login - Password")
return
}
}

View File

@@ -10,7 +10,6 @@ import (
"codeberg.org/shroff/phylum/server/internal/core"
"codeberg.org/shroff/phylum/server/internal/db"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
)
func CreateResetToken(db db.TxHandler, email string) (core.User, string, error) {
@@ -35,7 +34,7 @@ func CreateResetToken(db db.TxHandler, email string) (core.User, string, error)
}
func ResetUserPassword(db db.TxHandler, email, resetToken, password string) (a Auth, apiToken string, err error) {
func ResetUserPassword(db db.TxHandler, email, resetToken, password string) (auth *Auth, apiToken string, err error) {
if passwordBackend == nil || !passwordBackend.SupportsPasswordUpdate() {
err = errors.New("password update not supported")
return
@@ -76,13 +75,8 @@ func ResetUserPassword(db db.TxHandler, email, resetToken, password string) (a A
return
}
a = auth{
userID: user.ID,
homeID: user.Home,
userPermissions: user.Permissions,
scopes: []string{"*"},
}
_, _, apiToken, err = GenerateAPIKey(db, user.ID, pgtype.Timestamptz{}, "Login - Password Reset", []string{"*"})
auth = NewSUAuth(user)
_, _, apiToken, err = GenerateAPIKey(db, auth, "Login - Password Reset")
return
}

View File

@@ -10,10 +10,9 @@ import (
"codeberg.org/shroff/phylum/server/internal/db"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
)
func PerformTokenLogin(db db.TxHandler, encodedToken string) (Auth, string, error) {
func PerformTokenLogin(db db.TxHandler, encodedToken string) (*Auth, string, error) {
if b, err := b32Encoder.DecodeString(encodedToken); err != nil {
return nil, "", ErrCredentialsInvalid
} else if len(b) < 16 {
@@ -53,7 +52,7 @@ func CreateLoginToken(db db.TxHandler, email string) (core.User, string, error)
return user, b32Encoder.EncodeToString(append(tokenID[:], token...)), nil
}
func performTokenLogin(db db.TxHandler, tokenID uuid.UUID, token []byte) (a Auth, apiToken string, err error) {
func performTokenLogin(db db.TxHandler, tokenID uuid.UUID, token []byte) (auth *Auth, apiToken string, err error) {
const q = "DELETE FROM pending_logins WHERE token_id = @token_id AND token_hash = @token_hash AND user_id IS NOT NULL RETURNING user_id, expires"
hash := sha256.Sum256([]byte(token))
args := pgx.NamedArgs{
@@ -75,15 +74,9 @@ func performTokenLogin(db db.TxHandler, tokenID uuid.UUID, token []byte) (a Auth
return
} else if user, err = core.UserByID(db, userID); err != nil {
return
} else if _, _, apiToken, err = GenerateAPIKey(db, userID, pgtype.Timestamptz{}, "Login - Token", []string{"*"}); err != nil {
return
} else {
a = auth{
userID: user.ID,
homeID: user.Home,
userPermissions: user.Permissions,
scopes: []string{"*"},
}
auth = NewSUAuth(user)
_, _, apiToken, err = GenerateAPIKey(db, auth, "Login - Token")
return
}
}

View File

@@ -79,7 +79,9 @@ func setupGenerateCommand() *cobra.Command {
}
if err := db.Get(context.Background()).RunInTx(func(db db.TxHandler) error {
if id, key, token, err := auth.GenerateAPIKey(db, u.ID, expires, description, scopes); err != nil {
if a, err := auth.NewAuth(auth.NewSUAuth(*u), expires, scopes); err != nil {
return err
} else if id, key, token, err := auth.GenerateAPIKey(db, a, description); err != nil {
return err
} else {
fmt.Println("Key ID:", id)