This commit is contained in:
d34dscene
2025-06-18 17:23:41 +02:00
parent 0e8af0e27c
commit d8f8afb2c6
5 changed files with 18 additions and 19 deletions

View File

@@ -106,7 +106,7 @@ func Authentication(app *config.App) connect.UnaryInterceptorFunc {
} }
// Add claims to context // Add claims to context
ctx = context.WithValue(ctx, AuthUserIDKey, claims.ID) ctx = context.WithValue(ctx, AuthUserIDKey, claims.UserID)
return next(ctx, req) return next(ctx, req)
} }
}) })

View File

@@ -2,6 +2,7 @@ package middlewares
import ( import (
"net/http" "net/http"
"time"
connectcors "connectrpc.com/cors" connectcors "connectrpc.com/cors"
"github.com/mizuchilabs/mantrae/internal/config" "github.com/mizuchilabs/mantrae/internal/config"
@@ -25,10 +26,10 @@ func WithCORS(h http.Handler, app *config.App, port string) http.Handler {
} }
return cors.New(cors.Options{ return cors.New(cors.Options{
AllowedOrigins: allowedOrigins, AllowedOrigins: allowedOrigins,
AllowedMethods: connectcors.AllowedMethods(), AllowedMethods: connectcors.AllowedMethods(),
AllowedHeaders: connectcors.AllowedHeaders(), AllowedHeaders: connectcors.AllowedHeaders(),
ExposedHeaders: connectcors.ExposedHeaders(), ExposedHeaders: connectcors.ExposedHeaders(),
AllowCredentials: true, MaxAge: int(2 * time.Hour / time.Second),
}).Handler(h) }).Handler(h)
} }

View File

@@ -3,7 +3,6 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"io/fs"
"log" "log"
"log/slog" "log/slog"
"net/http" "net/http"
@@ -19,7 +18,6 @@ import (
"github.com/mizuchilabs/mantrae/internal/api/service" "github.com/mizuchilabs/mantrae/internal/api/service"
"github.com/mizuchilabs/mantrae/internal/config" "github.com/mizuchilabs/mantrae/internal/config"
"github.com/mizuchilabs/mantrae/proto/gen/mantrae/v1/mantraev1connect" "github.com/mizuchilabs/mantrae/proto/gen/mantrae/v1/mantraev1connect"
"github.com/mizuchilabs/mantrae/web"
) )
const elementsHTML = `<!DOCTYPE html> const elementsHTML = `<!DOCTYPE html>
@@ -127,11 +125,11 @@ func (s *Server) registerServices() {
} }
// Static files // Static files
staticContent, err := fs.Sub(web.StaticFS, "build") // staticContent, err := fs.Sub(web.StaticFS, "build")
if err != nil { // if err != nil {
log.Fatal(err) // log.Fatal(err)
} // }
s.mux.Handle("/", http.FileServer(http.FS(staticContent))) // s.mux.Handle("/", http.FileServer(http.FS(staticContent)))
serviceNames := []string{ serviceNames := []string{
mantraev1connect.ProfileServiceName, mantraev1connect.ProfileServiceName,

View File

@@ -51,7 +51,7 @@ func (s *UserService) LoginUser(
if req.Msg.Remember { if req.Msg.Remember {
expirationTime = time.Now().Add(30 * 24 * time.Hour) expirationTime = time.Now().Add(30 * 24 * time.Hour)
} }
token, err := util.EncodeUserJWT(user.Username, s.app.Secret, expirationTime) token, err := util.EncodeUserJWT(user.ID, s.app.Secret, expirationTime)
if err != nil { if err != nil {
return nil, connect.NewError(connect.CodeInternal, err) return nil, connect.NewError(connect.CodeInternal, err)
} }
@@ -109,7 +109,7 @@ func (s *UserService) VerifyOTP(
} }
expirationTime := time.Now().Add(1 * time.Hour) expirationTime := time.Now().Add(1 * time.Hour)
token, err := util.EncodeUserJWT(user.Username, s.app.Secret, expirationTime) token, err := util.EncodeUserJWT(user.ID, s.app.Secret, expirationTime)
if err != nil { if err != nil {
return nil, connect.NewError(connect.CodeInternal, err) return nil, connect.NewError(connect.CodeInternal, err)
} }

View File

@@ -10,20 +10,20 @@ import (
const CookieName = "auth_token" const CookieName = "auth_token"
type UserClaims struct { type UserClaims struct {
Username string `json:"username,omitempty"` UserID string `json:"user_id,omitempty"`
jwt.RegisteredClaims jwt.RegisteredClaims
} }
// EncodeUserJWT generates a JWT for user login // EncodeUserJWT generates a JWT for user login
func EncodeUserJWT(username, secret string, expirationTime time.Time) (string, error) { func EncodeUserJWT(userID, secret string, expirationTime time.Time) (string, error) {
if username == "" { if userID == "" {
return "", errors.New("username cannot be empty") return "", errors.New("username cannot be empty")
} }
if expirationTime.IsZero() { if expirationTime.IsZero() {
expirationTime = time.Now().Add(24 * time.Hour) expirationTime = time.Now().Add(24 * time.Hour)
} }
claims := &UserClaims{ claims := &UserClaims{
Username: username, UserID: userID,
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expirationTime), ExpiresAt: jwt.NewNumericDate(expirationTime),
IssuedAt: jwt.NewNumericDate(time.Now()), IssuedAt: jwt.NewNumericDate(time.Now()),