chore: Refactor auth/CSRF middleware (#10113)

* chore: Refactor auth/CSRF middleware

* sp
This commit is contained in:
Tom Moor
2025-09-07 14:36:46 +02:00
committed by GitHub
parent 58a41a6fde
commit 5d5bed8270
4 changed files with 211 additions and 172 deletions

View File

@@ -28,7 +28,7 @@ const router = new Router();
router.post(
"files.create",
rateLimiter(RateLimiterStrategy.TenPerMinute),
auth({ allowMultipart: true }),
auth(),
validate(T.FilesCreateSchema),
multipart({
maximumFileSize: Math.max(

View File

@@ -17,185 +17,29 @@ import {
} from "../errors";
type AuthenticationOptions = {
/** Role requuired to access the route. */
/** Role required to access the route. */
role?: UserRole;
/** Type of authentication required to access the route. */
type?: AuthenticationType | AuthenticationType[];
/** Authentication is parsed, but optional. */
optional?: boolean;
/**
* Allow multipart requests with cookie authentication, otherwise
* the request will fail if the content type is not application/json.
* This is useful for file uploads where the cookie is used to authenticate.
*/
allowMultipart?: boolean;
};
type AuthTransport = "cookie" | "header" | "body" | "query";
type AuthInput = {
/** The authentication token extracted from the request, if any. */
token?: string;
/** The method used to receive the authentication token. */
transport?: AuthTransport;
};
export default function auth(options: AuthenticationOptions = {}) {
return async function authMiddleware(ctx: AppContext, next: Next) {
let token;
const authorizationHeader = ctx.request.get("authorization");
if (authorizationHeader) {
const parts = authorizationHeader.split(" ");
if (parts.length === 2) {
const scheme = parts[0];
const credentials = parts[1];
if (/^Bearer$/i.test(scheme)) {
token = credentials;
}
} else {
throw AuthenticationError(
`Bad Authorization header format. Format is "Authorization: Bearer <token>"`
);
}
} else if (
ctx.request.body &&
typeof ctx.request.body === "object" &&
"token" in ctx.request.body
) {
token = ctx.request.body.token;
} else if (ctx.request.query?.token) {
token = ctx.request.query.token;
} else {
token = ctx.cookies.get("accessToken");
// check if the request is application/json encoded
// TODO: Enable once clients have updated
// if (
// token &&
// !ctx.request.is("application/json") &&
// !options.allowMultipart
// ) {
// throw AuthenticationError(
// "Mismatched content type. Expected application/json"
// );
// }
}
try {
if (!token) {
throw AuthenticationError("Authentication required");
}
const { type, token, user } = await validateAuthentication(ctx, options);
let user: User | null;
let type: AuthenticationType;
if (OAuthAuthentication.match(String(token))) {
if (!authorizationHeader) {
throw AuthenticationError(
"OAuth access token must be passed in the Authorization header"
);
}
type = AuthenticationType.OAUTH;
let authentication;
try {
authentication = await OAuthAuthentication.findByAccessToken(token, {
rejectOnEmpty: true,
});
} catch (_err) {
throw AuthenticationError("Invalid access token");
}
if (!authentication) {
throw AuthenticationError("Invalid access token");
}
if (authentication.accessTokenExpiresAt < new Date()) {
throw AuthenticationError("Access token is expired");
}
if (!authentication.canAccess(ctx.request.url)) {
throw AuthenticationError(
"Access token does not have access to this resource"
);
}
user = await User.findByPk(authentication.userId, {
include: [
{
model: Team,
as: "team",
required: true,
},
],
});
if (!user) {
throw AuthenticationError("Invalid access token");
}
await authentication.updateActiveAt();
} else if (ApiKey.match(String(token))) {
type = AuthenticationType.API;
let apiKey;
try {
apiKey = await ApiKey.findByToken(token);
} catch (_err) {
throw AuthenticationError("Invalid API key");
}
if (!apiKey) {
throw AuthenticationError("Invalid API key");
}
if (apiKey.expiresAt && apiKey.expiresAt < new Date()) {
throw AuthenticationError("API key is expired");
}
if (!apiKey.canAccess(ctx.request.url)) {
throw AuthenticationError(
"API key does not have access to this resource"
);
}
user = await User.findByPk(apiKey.userId, {
include: [
{
model: Team,
as: "team",
required: true,
},
],
});
if (!user) {
throw AuthenticationError("Invalid API key");
}
await apiKey.updateActiveAt();
} else {
type = AuthenticationType.APP;
user = await getUserForJWT(String(token));
}
if (user.isSuspended) {
const suspendingAdmin = await User.findOne({
where: {
id: user.suspendedById!,
},
paranoid: false,
});
throw UserSuspendedError({
adminEmail: suspendingAdmin?.email || undefined,
});
}
if (options.role && UserRoleHelper.isRoleLower(user.role, options.role)) {
throw AuthorizationError(`${capitalize(options.role)} role required`);
}
if (
options.type &&
(Array.isArray(options.type)
? !options.type.includes(type)
: type !== options.type)
) {
throw AuthorizationError(`Invalid authentication type`);
}
// not awaiting the promises here so that the request is not blocked
// We are not awaiting the promises here so that the request is not blocked
user.updateActiveAt(ctx).catch((err) => {
Logger.error("Failed to update user activeAt", err);
});
@@ -205,7 +49,7 @@ export default function auth(options: AuthenticationOptions = {}) {
ctx.state.auth = {
user,
token: String(token),
token,
type,
};
@@ -240,3 +84,196 @@ export default function auth(options: AuthenticationOptions = {}) {
return next();
};
}
/**
* Parses the authentication token from the request context.
*
* @param ctx The application context containing the request information.
* @returns An object containing the token and its transport method.
*/
export function parseAuthentication(ctx: AppContext): AuthInput {
const authorizationHeader = ctx.request.get("authorization");
if (authorizationHeader) {
const parts = authorizationHeader.split(" ");
if (parts.length === 2) {
const scheme = parts[0];
const credentials = parts[1];
if (/^Bearer$/i.test(scheme)) {
return {
token: credentials,
transport: "header",
};
}
} else {
throw AuthenticationError(
`Bad Authorization header format. Format is "Authorization: Bearer <token>"`
);
}
} else if (
ctx.request.body &&
typeof ctx.request.body === "object" &&
"token" in ctx.request.body
) {
return {
token: String(ctx.request.body.token),
transport: "body",
};
} else if (ctx.request.query?.token) {
return {
token: String(ctx.request.query.token),
transport: "query",
};
} else {
const accessToken = ctx.cookies.get("accessToken");
if (accessToken) {
return {
token: accessToken,
transport: "cookie",
};
}
}
return {
token: undefined,
transport: undefined,
};
}
async function validateAuthentication(
ctx: AppContext,
options: AuthenticationOptions
): Promise<{ user: User; token: string; type: AuthenticationType }> {
const { token, transport } = parseAuthentication(ctx);
if (!token) {
throw AuthenticationError("Authentication required");
}
let user: User | null;
let type: AuthenticationType;
if (OAuthAuthentication.match(token)) {
if (transport !== "header") {
throw AuthenticationError(
"OAuth access token must be passed in the Authorization header"
);
}
type = AuthenticationType.OAUTH;
let authentication;
try {
authentication = await OAuthAuthentication.findByAccessToken(token, {
rejectOnEmpty: true,
});
} catch (_err) {
throw AuthenticationError("Invalid access token");
}
if (!authentication) {
throw AuthenticationError("Invalid access token");
}
if (authentication.accessTokenExpiresAt < new Date()) {
throw AuthenticationError("Access token is expired");
}
if (!authentication.canAccess(ctx.request.url)) {
throw AuthenticationError(
"Access token does not have access to this resource"
);
}
user = await User.findByPk(authentication.userId, {
include: [
{
model: Team,
as: "team",
required: true,
},
],
});
if (!user) {
throw AuthenticationError("Invalid access token");
}
await authentication.updateActiveAt();
} else if (ApiKey.match(token)) {
if (transport === "cookie") {
throw AuthenticationError("API key must not be passed in the cookie");
}
type = AuthenticationType.API;
let apiKey;
try {
apiKey = await ApiKey.findByToken(token);
} catch (_err) {
throw AuthenticationError("Invalid API key");
}
if (!apiKey) {
throw AuthenticationError("Invalid API key");
}
if (apiKey.expiresAt && apiKey.expiresAt < new Date()) {
throw AuthenticationError("API key is expired");
}
if (!apiKey.canAccess(ctx.request.url)) {
throw AuthenticationError(
"API key does not have access to this resource"
);
}
user = await User.findByPk(apiKey.userId, {
include: [
{
model: Team,
as: "team",
required: true,
},
],
});
if (!user) {
throw AuthenticationError("Invalid API key");
}
await apiKey.updateActiveAt();
} else {
type = AuthenticationType.APP;
user = await getUserForJWT(token);
}
if (user.isSuspended) {
const suspendingAdmin = await User.findOne({
where: {
id: user.suspendedById!,
},
paranoid: false,
});
throw UserSuspendedError({
adminEmail: suspendingAdmin?.email || undefined,
});
}
if (options.role && UserRoleHelper.isRoleLower(user.role, options.role)) {
throw AuthorizationError(`${capitalize(options.role)} role required`);
}
if (
options.type &&
(Array.isArray(options.type)
? !options.type.includes(type)
: type !== options.type)
) {
throw AuthorizationError(`Invalid authentication type`);
}
return {
user,
type,
token,
};
}

View File

@@ -11,6 +11,7 @@ import {
import { getCookieDomain } from "@shared/utils/domains";
import { CSRF } from "@shared/constants";
import { CSRFError } from "@server/errors";
import { parseAuthentication } from "./authentication";
/**
* Middleware that generates and attaches CSRF tokens for safe methods
@@ -48,7 +49,8 @@ export function verifyCSRFToken() {
}
// If not using cookie-based auth, skip CSRF protection
if (!ctx.cookies.get("accessToken")) {
const { transport } = parseAuthentication(ctx);
if (transport !== "cookie") {
return false;
}

View File

@@ -1575,7 +1575,7 @@ router.post(
router.post(
"documents.import",
auth({ allowMultipart: true }),
auth(),
rateLimiter(RateLimiterStrategy.TwentyFivePerMinute),
validate(T.DocumentsImportSchema),
multipart({ maximumFileSize: env.FILE_STORAGE_IMPORT_MAX_SIZE }),