feat: add provider-aware public rate limit routing

This commit is contained in:
Bhagya Amarasinghe
2026-03-24 14:35:56 +05:30
parent e7ca66ed77
commit f737c8b76f
15 changed files with 500 additions and 88 deletions

View File

@@ -185,6 +185,10 @@ ENTERPRISE_LICENSE_KEY=
# Ignore Rate Limiting across the Formbricks app
# RATE_LIMITING_DISABLED=1
# Public unauthenticated IP-based rate limits can be handled by an edge provider.
# Supported values: none, cloudflare, cloudarmor, envoy
# EDGE_RATE_LIMIT_PROVIDER=none
# OpenTelemetry OTLP endpoint (base URL, exporters append /v1/traces and /v1/metrics)
# OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318
# OTEL_EXPORTER_OTLP_PROTOCOL=http/protobuf

View File

@@ -4,6 +4,11 @@ import { ZDisplayCreateInputV2 } from "@/app/api/v2/client/[environmentId]/displ
import { responses } from "@/app/lib/api/response";
import { transformErrorToDetails } from "@/app/lib/api/validator";
import { getOrganizationIdFromEnvironmentId } from "@/lib/utils/helper";
import {
applyPublicIpRateLimit,
publicEdgeRateLimitPolicies,
} from "@/modules/core/rate-limit/public-edge-rate-limit";
import { rateLimitConfigs } from "@/modules/core/rate-limit/rate-limit-configs";
import { getIsContactsEnabled } from "@/modules/ee/license-check/lib/utils";
import { createDisplay } from "./lib/display";
@@ -24,6 +29,15 @@ export const OPTIONS = async (): Promise<Response> => {
};
export const POST = async (request: Request, context: Context): Promise<Response> => {
try {
await applyPublicIpRateLimit(publicEdgeRateLimitPolicies.v2ClientDisplays, rateLimitConfigs.api.client);
} catch (error) {
return responses.tooManyRequestsResponse(
error instanceof Error ? error.message : "Rate limit exceeded",
true
);
}
const params = await context.params;
const jsonInput = await request.json();
const inputValidation = ZDisplayCreateInputV2.safeParse({

View File

@@ -14,6 +14,11 @@ import { getClientIpFromHeaders } from "@/lib/utils/client-ip";
import { getOrganizationIdFromEnvironmentId } from "@/lib/utils/helper";
import { formatValidationErrorsForV1Api, validateResponseData } from "@/modules/api/lib/validation";
import { validateOtherOptionLengthForMultipleChoice } from "@/modules/api/v2/lib/element";
import {
applyPublicIpRateLimit,
publicEdgeRateLimitPolicies,
} from "@/modules/core/rate-limit/public-edge-rate-limit";
import { rateLimitConfigs } from "@/modules/core/rate-limit/rate-limit-configs";
import { getIsContactsEnabled } from "@/modules/ee/license-check/lib/utils";
import { createQuotaFullObject } from "@/modules/ee/quotas/lib/helpers";
import { createResponseWithQuotaEvaluation } from "./lib/response";
@@ -36,6 +41,15 @@ export const OPTIONS = async (): Promise<Response> => {
};
export const POST = async (request: Request, context: Context): Promise<Response> => {
try {
await applyPublicIpRateLimit(publicEdgeRateLimitPolicies.v2ClientResponses, rateLimitConfigs.api.client);
} catch (error) {
return responses.tooManyRequestsResponse(
error instanceof Error ? error.message : "Rate limit exceeded",
true
);
}
const params = await context.params;
const requestHeaders = await headers();
let responseInput;

View File

@@ -12,6 +12,10 @@ vi.mock("@/modules/ee/audit-logs/lib/handler", () => ({
queueAuditEvent: vi.fn(),
}));
vi.mock("@/modules/ee/audit-logs/types/audit-log", () => ({
UNKNOWN_DATA: "unknown",
}));
vi.mock("@sentry/nextjs", () => ({
captureException: vi.fn(),
withScope: vi.fn((callback) => {
@@ -72,10 +76,13 @@ vi.mock("@/app/middleware/endpoint-validator", async () => {
});
vi.mock("@/modules/core/rate-limit/helpers", () => ({
applyIPRateLimit: vi.fn(),
applyRateLimit: vi.fn(),
}));
vi.mock("@/modules/core/rate-limit/public-edge-rate-limit", () => ({
applyPublicIpRateLimitForRoute: vi.fn(),
}));
vi.mock("@/modules/core/rate-limit/rate-limit-configs", () => ({
rateLimitConfigs: {
api: {
@@ -115,6 +122,7 @@ describe("withV1ApiWrapper", () => {
vi.doMock("@/lib/constants", () => ({
AUDIT_LOG_ENABLED: true,
EDGE_RATE_LIMIT_PROVIDER: "none",
IS_PRODUCTION: true,
SENTRY_DSN: "dsn",
ENCRYPTION_KEY: "test-key",
@@ -131,11 +139,13 @@ describe("withV1ApiWrapper", () => {
});
test("logs and audits on error response with API key authentication", async () => {
const { queueAuditEvent: mockedQueueAuditEvent } =
(await import("@/modules/ee/audit-logs/lib/handler")) as unknown as { queueAuditEvent: Mock };
const { queueAuditEvent: mockedQueueAuditEvent } = (await import(
"@/modules/ee/audit-logs/lib/handler"
)) as unknown as { queueAuditEvent: Mock };
const { authenticateRequest } = await import("@/app/api/v1/auth");
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } =
await import("@/app/middleware/endpoint-validator");
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } = await import(
"@/app/middleware/endpoint-validator"
);
vi.mocked(authenticateRequest).mockResolvedValue(mockApiAuthentication);
vi.mocked(isClientSideApiRoute).mockReturnValue({ isClientSideApi: false, isRateLimited: true });
@@ -183,11 +193,13 @@ describe("withV1ApiWrapper", () => {
});
test("does not log Sentry if not 500", async () => {
const { queueAuditEvent: mockedQueueAuditEvent } =
(await import("@/modules/ee/audit-logs/lib/handler")) as unknown as { queueAuditEvent: Mock };
const { queueAuditEvent: mockedQueueAuditEvent } = (await import(
"@/modules/ee/audit-logs/lib/handler"
)) as unknown as { queueAuditEvent: Mock };
const { authenticateRequest } = await import("@/app/api/v1/auth");
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } =
await import("@/app/middleware/endpoint-validator");
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } = await import(
"@/app/middleware/endpoint-validator"
);
vi.mocked(authenticateRequest).mockResolvedValue(mockApiAuthentication);
vi.mocked(isClientSideApiRoute).mockReturnValue({ isClientSideApi: false, isRateLimited: true });
@@ -229,11 +241,13 @@ describe("withV1ApiWrapper", () => {
});
test("logs and audits on thrown error", async () => {
const { queueAuditEvent: mockedQueueAuditEvent } =
(await import("@/modules/ee/audit-logs/lib/handler")) as unknown as { queueAuditEvent: Mock };
const { queueAuditEvent: mockedQueueAuditEvent } = (await import(
"@/modules/ee/audit-logs/lib/handler"
)) as unknown as { queueAuditEvent: Mock };
const { authenticateRequest } = await import("@/app/api/v1/auth");
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } =
await import("@/app/middleware/endpoint-validator");
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } = await import(
"@/app/middleware/endpoint-validator"
);
vi.mocked(authenticateRequest).mockResolvedValue(mockApiAuthentication);
vi.mocked(isClientSideApiRoute).mockReturnValue({ isClientSideApi: false, isRateLimited: true });
@@ -285,11 +299,13 @@ describe("withV1ApiWrapper", () => {
});
test("does not log on success response but still audits", async () => {
const { queueAuditEvent: mockedQueueAuditEvent } =
(await import("@/modules/ee/audit-logs/lib/handler")) as unknown as { queueAuditEvent: Mock };
const { queueAuditEvent: mockedQueueAuditEvent } = (await import(
"@/modules/ee/audit-logs/lib/handler"
)) as unknown as { queueAuditEvent: Mock };
const { authenticateRequest } = await import("@/app/api/v1/auth");
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } =
await import("@/app/middleware/endpoint-validator");
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } = await import(
"@/app/middleware/endpoint-validator"
);
vi.mocked(authenticateRequest).mockResolvedValue(mockApiAuthentication);
vi.mocked(isClientSideApiRoute).mockReturnValue({ isClientSideApi: false, isRateLimited: true });
@@ -333,17 +349,20 @@ describe("withV1ApiWrapper", () => {
test("does not call audit if AUDIT_LOG_ENABLED is false", async () => {
vi.doMock("@/lib/constants", () => ({
AUDIT_LOG_ENABLED: false,
EDGE_RATE_LIMIT_PROVIDER: "none",
IS_PRODUCTION: true,
SENTRY_DSN: "dsn",
ENCRYPTION_KEY: "test-key",
REDIS_URL: "redis://localhost:6379",
}));
const { queueAuditEvent: mockedQueueAuditEvent } =
(await import("@/modules/ee/audit-logs/lib/handler")) as unknown as { queueAuditEvent: Mock };
const { queueAuditEvent: mockedQueueAuditEvent } = (await import(
"@/modules/ee/audit-logs/lib/handler"
)) as unknown as { queueAuditEvent: Mock };
const { authenticateRequest } = await import("@/app/api/v1/auth");
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } =
await import("@/app/middleware/endpoint-validator");
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } = await import(
"@/app/middleware/endpoint-validator"
);
const { withV1ApiWrapper } = await import("./with-api-logging");
vi.mocked(authenticateRequest).mockResolvedValue(mockApiAuthentication);
@@ -366,10 +385,13 @@ describe("withV1ApiWrapper", () => {
});
test("handles client-side API routes without authentication", async () => {
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } =
await import("@/app/middleware/endpoint-validator");
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } = await import(
"@/app/middleware/endpoint-validator"
);
const { authenticateRequest } = await import("@/app/api/v1/auth");
const { applyIPRateLimit } = await import("@/modules/core/rate-limit/helpers");
const { applyPublicIpRateLimitForRoute } = await import(
"@/modules/core/rate-limit/public-edge-rate-limit"
);
vi.mocked(isClientSideApiRoute).mockReturnValue({ isClientSideApi: true, isRateLimited: true });
vi.mocked(isManagementApiRoute).mockReturnValue({
@@ -378,7 +400,7 @@ describe("withV1ApiWrapper", () => {
});
vi.mocked(isIntegrationRoute).mockReturnValue(false);
vi.mocked(authenticateRequest).mockResolvedValue(null);
vi.mocked(applyIPRateLimit).mockResolvedValue(undefined);
vi.mocked(applyPublicIpRateLimitForRoute).mockResolvedValue("app");
const handler = vi.fn().mockResolvedValue({
response: responses.successResponse({ data: "test" }),
@@ -396,11 +418,17 @@ describe("withV1ApiWrapper", () => {
auditLog: undefined,
authentication: null,
});
expect(applyPublicIpRateLimitForRoute).toHaveBeenCalledWith(
"/api/v1/client/displays",
"GET",
expect.objectContaining({ max: 100 })
);
});
test("returns authentication error for non-client routes without auth", async () => {
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } =
await import("@/app/middleware/endpoint-validator");
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } = await import(
"@/app/middleware/endpoint-validator"
);
const { authenticateRequest } = await import("@/app/api/v1/auth");
vi.mocked(isClientSideApiRoute).mockReturnValue({ isClientSideApi: false, isRateLimited: true });
@@ -422,8 +450,9 @@ describe("withV1ApiWrapper", () => {
});
test("uses unauthenticatedResponse when provided instead of default 401", async () => {
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } =
await import("@/app/middleware/endpoint-validator");
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } = await import(
"@/app/middleware/endpoint-validator"
);
const { getServerSession } = await import("next-auth");
vi.mocked(isClientSideApiRoute).mockReturnValue({ isClientSideApi: false, isRateLimited: true });
@@ -455,8 +484,9 @@ describe("withV1ApiWrapper", () => {
test("handles rate limiting errors", async () => {
const { applyRateLimit } = await import("@/modules/core/rate-limit/helpers");
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } =
await import("@/app/middleware/endpoint-validator");
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } = await import(
"@/app/middleware/endpoint-validator"
);
const { authenticateRequest } = await import("@/app/api/v1/auth");
vi.mocked(authenticateRequest).mockResolvedValue(mockApiAuthentication);
@@ -481,11 +511,13 @@ describe("withV1ApiWrapper", () => {
});
test("skips audit log creation when no action/targetType provided", async () => {
const { queueAuditEvent: mockedQueueAuditEvent } =
(await import("@/modules/ee/audit-logs/lib/handler")) as unknown as { queueAuditEvent: Mock };
const { queueAuditEvent: mockedQueueAuditEvent } = (await import(
"@/modules/ee/audit-logs/lib/handler"
)) as unknown as { queueAuditEvent: Mock };
const { authenticateRequest } = await import("@/app/api/v1/auth");
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } =
await import("@/app/middleware/endpoint-validator");
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } = await import(
"@/app/middleware/endpoint-validator"
);
vi.mocked(authenticateRequest).mockResolvedValue(mockApiAuthentication);
vi.mocked(isClientSideApiRoute).mockReturnValue({ isClientSideApi: false, isRateLimited: true });

View File

@@ -13,7 +13,8 @@ import {
} from "@/app/middleware/endpoint-validator";
import { AUDIT_LOG_ENABLED, IS_PRODUCTION, SENTRY_DSN } from "@/lib/constants";
import { authOptions } from "@/modules/auth/lib/authOptions";
import { applyIPRateLimit, applyRateLimit } from "@/modules/core/rate-limit/helpers";
import { applyRateLimit } from "@/modules/core/rate-limit/helpers";
import { applyPublicIpRateLimitForRoute } from "@/modules/core/rate-limit/public-edge-rate-limit";
import { rateLimitConfigs } from "@/modules/core/rate-limit/rate-limit-configs";
import { TRateLimitConfig } from "@/modules/core/rate-limit/types/rate-limit";
import { queueAuditEvent } from "@/modules/ee/audit-logs/lib/handler";
@@ -54,14 +55,22 @@ enum ApiV1RouteTypeEnum {
/**
* Apply client-side API rate limiting (IP-based)
*/
const applyClientRateLimit = async (customRateLimitConfig?: TRateLimitConfig): Promise<void> => {
await applyIPRateLimit(customRateLimitConfig ?? rateLimitConfigs.api.client);
const applyClientRateLimit = async (
req: NextRequest,
customRateLimitConfig?: TRateLimitConfig
): Promise<void> => {
await applyPublicIpRateLimitForRoute(
req.nextUrl.pathname,
req.method,
customRateLimitConfig ?? rateLimitConfigs.api.client
);
};
/**
* Handle rate limiting based on authentication and API type
*/
const handleRateLimiting = async (
req: NextRequest,
authentication: TApiV1Authentication,
routeType: ApiV1RouteTypeEnum,
customRateLimitConfig?: TRateLimitConfig
@@ -81,7 +90,7 @@ const handleRateLimiting = async (
}
if (routeType === ApiV1RouteTypeEnum.Client) {
await applyClientRateLimit(customRateLimitConfig);
await applyClientRateLimit(req, customRateLimitConfig);
}
} catch (error) {
return responses.tooManyRequestsResponse(error instanceof Error ? error.message : "Rate limit exceeded");
@@ -305,7 +314,12 @@ export const withV1ApiWrapper = <TResult extends { response: Response }, TProps
// === Rate Limiting ===
if (isRateLimited) {
const rateLimitResponse = await handleRateLimiting(authentication, routeType, customRateLimitConfig);
const rateLimitResponse = await handleRateLimiting(
req,
authentication,
routeType,
customRateLimitConfig
);
if (rateLimitResponse) return rateLimitResponse;
}

View File

@@ -48,6 +48,10 @@ describe("endpoint-validator", () => {
isClientSideApi: true,
isRateLimited: false,
});
expect(isClientSideApiRoute("/api/v1/client/og-image")).toEqual({
isClientSideApi: true,
isRateLimited: true,
});
});
test("should return false for non-client-side API routes", () => {

View File

@@ -13,7 +13,7 @@ export enum AuthenticationMethod {
export const isClientSideApiRoute = (url: string): { isClientSideApi: boolean; isRateLimited: boolean } => {
// Open Graph image generation route is a client side API route but it should not be rate limited
if (url.includes("/api/v1/client/og")) return { isClientSideApi: true, isRateLimited: false };
if (/^\/api\/v1\/client\/og(?:\/.*)?$/.test(url)) return { isClientSideApi: true, isRateLimited: false };
const regex = /^\/api\/v\d+\/client\//;
return { isClientSideApi: regex.test(url), isRateLimited: true };

View File

@@ -3,6 +3,7 @@ import { TUserLocale } from "@formbricks/types/user";
import { env } from "./env";
export const IS_FORMBRICKS_CLOUD = env.IS_FORMBRICKS_CLOUD === "1";
export const EDGE_RATE_LIMIT_PROVIDER = env.EDGE_RATE_LIMIT_PROVIDER ?? "none";
export const IS_PRODUCTION = env.NODE_ENV === "production";

View File

@@ -21,6 +21,7 @@ export const env = createEnv({
E2E_TESTING: z.enum(["1", "0"]).optional(),
EMAIL_AUTH_DISABLED: z.enum(["1", "0"]).optional(),
EMAIL_VERIFICATION_DISABLED: z.enum(["1", "0"]).optional(),
EDGE_RATE_LIMIT_PROVIDER: z.enum(["none", "cloudflare", "cloudarmor", "envoy"]).optional(),
ENCRYPTION_KEY: z.string(),
ENTERPRISE_LICENSE_KEY: z.string().optional(),
ENVIRONMENT: z.enum(["production", "staging"]).prefault("production"),
@@ -147,6 +148,7 @@ export const env = createEnv({
E2E_TESTING: process.env.E2E_TESTING,
EMAIL_AUTH_DISABLED: process.env.EMAIL_AUTH_DISABLED,
EMAIL_VERIFICATION_DISABLED: process.env.EMAIL_VERIFICATION_DISABLED,
EDGE_RATE_LIMIT_PROVIDER: process.env.EDGE_RATE_LIMIT_PROVIDER,
ENCRYPTION_KEY: process.env.ENCRYPTION_KEY,
ENTERPRISE_LICENSE_KEY: process.env.ENTERPRISE_LICENSE_KEY,
ENVIRONMENT: process.env.ENVIRONMENT,

View File

@@ -3,11 +3,14 @@ import { Provider } from "next-auth/providers/index";
import { afterEach, describe, expect, test, vi } from "vitest";
import { prisma } from "@formbricks/database";
import { EMAIL_VERIFICATION_DISABLED } from "@/lib/constants";
// Import mocked rate limiting functions
import { applyIPRateLimit } from "@/modules/core/rate-limit/helpers";
import {
applyPublicIpRateLimit,
publicEdgeRateLimitPolicies,
} from "@/modules/core/rate-limit/public-edge-rate-limit";
import { rateLimitConfigs } from "@/modules/core/rate-limit/rate-limit-configs";
import { authOptions } from "./authOptions";
import { mockUser } from "./mock-data";
import { getUserByEmail } from "./user";
import { hashPassword } from "./utils";
// Mock encryption utilities
@@ -19,11 +22,48 @@ vi.mock("@/lib/encryption", () => ({
// Mock JWT
vi.mock("@/lib/jwt");
// Mock rate limiting dependencies
vi.mock("@/modules/core/rate-limit/helpers", () => ({
applyIPRateLimit: vi.fn(),
vi.mock("@/modules/core/rate-limit/public-edge-rate-limit", () => ({
applyPublicIpRateLimit: vi.fn(),
publicEdgeRateLimitPolicies: {
authLogin: "auth.login",
authVerifyEmail: "auth.verify_email",
},
}));
vi.mock("./user", () => ({
getUserByEmail: vi.fn(),
updateUser: vi.fn(),
updateUserLastLoginAt: vi.fn(),
}));
vi.mock("./brevo", () => ({
createBrevoCustomer: vi.fn(),
}));
vi.mock("@/modules/ee/sso/lib/providers", () => ({
getSSOProviders: vi.fn(() => []),
}));
vi.mock("@/modules/ee/sso/lib/sso-handlers", () => ({
handleSsoCallback: vi.fn(),
}));
vi.mock("@/modules/ee/audit-logs/lib/handler", () => ({
queueAuditEventBackground: vi.fn(),
}));
vi.mock("@/modules/ee/audit-logs/types/audit-log", () => ({
UNKNOWN_DATA: "unknown",
}));
vi.mock("./utils", async (importOriginal) => {
const actual = await importOriginal<typeof import("./utils")>();
return {
...actual,
shouldLogAuthFailure: vi.fn().mockResolvedValue(false),
};
});
vi.mock("@/modules/core/rate-limit/rate-limit-configs", () => ({
rateLimitConfigs: {
auth: {
@@ -33,26 +73,22 @@ vi.mock("@/modules/core/rate-limit/rate-limit-configs", () => ({
},
}));
// Mock constants that this test needs while preserving untouched exports.
vi.mock("@/lib/constants", async (importOriginal) => {
const actual = await importOriginal<typeof import("@/lib/constants")>();
return {
...actual,
EMAIL_VERIFICATION_DISABLED: false,
SESSION_MAX_AGE: 86400,
NEXTAUTH_SECRET: "test-secret",
WEBAPP_URL: "http://localhost:3000",
ENCRYPTION_KEY: "12345678901234567890123456789012", // 32 bytes for AES-256
REDIS_URL: undefined,
AUDIT_LOG_ENABLED: false,
AUDIT_LOG_GET_USER_IP: false,
ENTERPRISE_LICENSE_KEY: undefined,
SENTRY_DSN: undefined,
BREVO_API_KEY: undefined,
RATE_LIMITING_DISABLED: false,
CONTROL_HASH: "$2b$12$fzHf9le13Ss9UJ04xzmsjODXpFJxz6vsnupoepF5FiqDECkX2BH5q",
};
});
vi.mock("@/lib/constants", () => ({
EMAIL_VERIFICATION_DISABLED: false,
EDGE_RATE_LIMIT_PROVIDER: "none",
SESSION_MAX_AGE: 86400,
NEXTAUTH_SECRET: "test-secret",
WEBAPP_URL: "http://localhost:3000",
ENCRYPTION_KEY: "12345678901234567890123456789012", // 32 bytes for AES-256
REDIS_URL: undefined,
AUDIT_LOG_ENABLED: false,
AUDIT_LOG_GET_USER_IP: false,
ENTERPRISE_LICENSE_KEY: undefined,
SENTRY_DSN: undefined,
BREVO_API_KEY: undefined,
RATE_LIMITING_DISABLED: false,
CONTROL_HASH: "$2b$12$fzHf9le13Ss9UJ04xzmsjODXpFJxz6vsnupoepF5FiqDECkX2BH5q",
}));
// Mock next/headers
vi.mock("next/headers", () => ({
@@ -114,7 +150,7 @@ describe("authOptions", () => {
});
test("should throw error if user not found", async () => {
vi.mocked(applyIPRateLimit).mockResolvedValue(); // Rate limiting passes
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
vi.spyOn(prisma.user, "findUnique").mockResolvedValue(null);
const credentials = { email: mockUser.email, password: mockPassword };
@@ -125,7 +161,7 @@ describe("authOptions", () => {
});
test("should throw error if user has no password stored", async () => {
vi.mocked(applyIPRateLimit).mockResolvedValue(); // Rate limiting passes
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
vi.spyOn(prisma.user, "findUnique").mockResolvedValue({
id: mockUser.id,
email: mockUser.email,
@@ -140,7 +176,7 @@ describe("authOptions", () => {
});
test("should throw error if password verification fails", async () => {
vi.mocked(applyIPRateLimit).mockResolvedValue(); // Rate limiting passes
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
vi.spyOn(prisma.user, "findUnique").mockResolvedValue({
id: mockUserId,
email: mockUser.email,
@@ -155,7 +191,7 @@ describe("authOptions", () => {
});
test("should successfully login when credentials are valid", async () => {
vi.mocked(applyIPRateLimit).mockResolvedValue(); // Rate limiting passes
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
const fakeUser = {
id: mockUserId,
email: mockUser.email,
@@ -178,7 +214,7 @@ describe("authOptions", () => {
describe("Rate Limiting", () => {
test("should apply rate limiting before credential validation", async () => {
vi.mocked(applyIPRateLimit).mockResolvedValue();
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
vi.spyOn(prisma.user, "findUnique").mockResolvedValue({
id: mockUserId,
email: mockUser.email,
@@ -191,12 +227,15 @@ describe("authOptions", () => {
await credentialsProvider.options.authorize(credentials, {});
expect(applyIPRateLimit).toHaveBeenCalledWith(rateLimitConfigs.auth.login);
expect(applyIPRateLimit).toHaveBeenCalledBefore(prisma.user.findUnique as any);
expect(applyPublicIpRateLimit).toHaveBeenCalledWith(
publicEdgeRateLimitPolicies.authLogin,
rateLimitConfigs.auth.login
);
expect(applyPublicIpRateLimit).toHaveBeenCalledBefore(prisma.user.findUnique as any);
});
test("should block login when rate limit exceeded", async () => {
vi.mocked(applyIPRateLimit).mockRejectedValue(
vi.mocked(applyPublicIpRateLimit).mockRejectedValue(
new Error("Maximum number of requests reached. Please try again later.")
);
const findUniqueSpy = vi.spyOn(prisma.user, "findUnique");
@@ -211,7 +250,7 @@ describe("authOptions", () => {
});
test("should use correct rate limit configuration", async () => {
vi.mocked(applyIPRateLimit).mockResolvedValue();
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
vi.spyOn(prisma.user, "findUnique").mockResolvedValue({
id: mockUserId,
email: mockUser.email,
@@ -224,7 +263,7 @@ describe("authOptions", () => {
await credentialsProvider.options.authorize(credentials, {});
expect(applyIPRateLimit).toHaveBeenCalledWith({
expect(applyPublicIpRateLimit).toHaveBeenCalledWith(publicEdgeRateLimitPolicies.authLogin, {
interval: 900,
allowedPerInterval: 30,
namespace: "auth:login",
@@ -234,7 +273,7 @@ describe("authOptions", () => {
describe("Two-Factor Backup Code login", () => {
test("should throw error if backup codes are missing", async () => {
vi.mocked(applyIPRateLimit).mockResolvedValue(); // Rate limiting passes
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
const mockUser = {
id: mockUserId,
email: "2fa@example.com",
@@ -263,7 +302,7 @@ describe("authOptions", () => {
});
test("should throw error if token is invalid or user not found", async () => {
vi.mocked(applyIPRateLimit).mockResolvedValue(); // Rate limiting passes
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
const credentials = { token: "badtoken" };
await expect(tokenProvider.options.authorize(credentials, {})).rejects.toThrow(
@@ -273,17 +312,20 @@ describe("authOptions", () => {
describe("Rate Limiting", () => {
test("should apply rate limiting before token verification", async () => {
vi.mocked(applyIPRateLimit).mockResolvedValue();
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
const credentials = { token: "sometoken" };
await expect(tokenProvider.options.authorize(credentials, {})).rejects.toThrow();
expect(applyIPRateLimit).toHaveBeenCalledWith(rateLimitConfigs.auth.verifyEmail);
expect(applyPublicIpRateLimit).toHaveBeenCalledWith(
publicEdgeRateLimitPolicies.authVerifyEmail,
rateLimitConfigs.auth.verifyEmail
);
});
test("should block verification when rate limit exceeded", async () => {
vi.mocked(applyIPRateLimit).mockRejectedValue(
vi.mocked(applyPublicIpRateLimit).mockRejectedValue(
new Error("Maximum number of requests reached. Please try again later.")
);
const findUniqueSpy = vi.spyOn(prisma.user, "findUnique");
@@ -302,7 +344,7 @@ describe("authOptions", () => {
describe("Callbacks", () => {
describe("jwt callback", () => {
test("should add profile information to token if user is found", async () => {
vi.spyOn(prisma.user, "findFirst").mockResolvedValue({
vi.mocked(getUserByEmail).mockResolvedValue({
id: mockUser.id,
locale: mockUser.locale,
email: mockUser.email,
@@ -321,7 +363,7 @@ describe("authOptions", () => {
});
test("should return token unchanged if no existing user is found", async () => {
vi.spyOn(prisma.user, "findFirst").mockResolvedValue(null);
vi.mocked(getUserByEmail).mockResolvedValue(null);
const token = { email: "nonexistent@example.com" };
if (!authOptions.callbacks?.jwt) {
@@ -366,7 +408,7 @@ describe("authOptions", () => {
const credentialsProvider = getProviderById("credentials");
test("should throw error if TOTP code is missing when 2FA is enabled", async () => {
vi.mocked(applyIPRateLimit).mockResolvedValue(); // Rate limiting passes
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
const mockUser = {
id: mockUserId,
email: "2fa@example.com",
@@ -384,7 +426,7 @@ describe("authOptions", () => {
});
test("should throw error if two factor secret is missing", async () => {
vi.mocked(applyIPRateLimit).mockResolvedValue(); // Rate limiting passes
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
const mockUser = {
id: mockUserId,
email: "2fa@example.com",

View File

@@ -23,7 +23,10 @@ import {
shouldLogAuthFailure,
verifyPassword,
} from "@/modules/auth/lib/utils";
import { applyIPRateLimit } from "@/modules/core/rate-limit/helpers";
import {
applyPublicIpRateLimit,
publicEdgeRateLimitPolicies,
} from "@/modules/core/rate-limit/public-edge-rate-limit";
import { rateLimitConfigs } from "@/modules/core/rate-limit/rate-limit-configs";
import { UNKNOWN_DATA } from "@/modules/ee/audit-logs/types/audit-log";
import { getSSOProviders } from "@/modules/ee/sso/lib/providers";
@@ -55,7 +58,7 @@ export const authOptions: NextAuthOptions = {
backupCode: { label: "Backup Code", type: "input", placeholder: "Two-factor backup code" },
},
async authorize(credentials, _req) {
await applyIPRateLimit(rateLimitConfigs.auth.login);
await applyPublicIpRateLimit(publicEdgeRateLimitPolicies.authLogin, rateLimitConfigs.auth.login);
// Use email for rate limiting when available, fall back to "unknown_user" for credential validation
const identifier = credentials?.email || "unknown_user"; // NOSONAR // We want to check for empty strings
@@ -245,7 +248,10 @@ export const authOptions: NextAuthOptions = {
},
},
async authorize(credentials, _req) {
await applyIPRateLimit(rateLimitConfigs.auth.verifyEmail);
await applyPublicIpRateLimit(
publicEdgeRateLimitPolicies.authVerifyEmail,
rateLimitConfigs.auth.verifyEmail
);
// For token verification, we can't rate limit effectively by token (single-use)
// So we use a generic identifier for token abuse attempts

View File

@@ -0,0 +1,142 @@
import { beforeEach, describe, expect, test, vi } from "vitest";
import { applyIPRateLimit } from "./helpers";
import {
applyPublicIpRateLimit,
applyPublicIpRateLimitForRoute,
getEdgeRateLimitProvider,
getPublicEdgeRateLimitPolicyId,
isPublicEdgeRateLimitManaged,
publicEdgeRateLimitPolicies,
} from "./public-edge-rate-limit";
vi.mock("./helpers", () => ({
applyIPRateLimit: vi.fn(),
}));
const mockConfig = {
interval: 60,
allowedPerInterval: 100,
namespace: "api:client",
};
describe("public-edge-rate-limit", () => {
beforeEach(() => {
vi.clearAllMocks();
});
describe("getEdgeRateLimitProvider", () => {
test("falls back to none for unknown providers", () => {
expect(getEdgeRateLimitProvider(undefined)).toBe("none");
expect(getEdgeRateLimitProvider("unknown")).toBe("none");
});
test("accepts configured providers", () => {
expect(getEdgeRateLimitProvider("cloudflare")).toBe("cloudflare");
expect(getEdgeRateLimitProvider("cloudarmor")).toBe("cloudarmor");
expect(getEdgeRateLimitProvider("envoy")).toBe("envoy");
});
});
describe("getPublicEdgeRateLimitPolicyId", () => {
test("classifies auth callback routes", () => {
expect(getPublicEdgeRateLimitPolicyId("/api/auth/callback/credentials", "POST")).toBe(
publicEdgeRateLimitPolicies.authLogin
);
expect(getPublicEdgeRateLimitPolicyId("/api/auth/callback/token", "POST")).toBe(
publicEdgeRateLimitPolicies.authVerifyEmail
);
});
test("classifies v1 client routes", () => {
expect(getPublicEdgeRateLimitPolicyId("/api/v1/client/env_123/environment", "GET")).toBe(
publicEdgeRateLimitPolicies.v1ClientDefault
);
expect(getPublicEdgeRateLimitPolicyId("/api/v1/client/env_123/storage", "POST")).toBe(
publicEdgeRateLimitPolicies.v1ClientStorageUpload
);
expect(getPublicEdgeRateLimitPolicyId("/api/v1/client/og", "GET")).toBeNull();
expect(getPublicEdgeRateLimitPolicyId("/api/v1/client/og/image", "GET")).toBeNull();
expect(getPublicEdgeRateLimitPolicyId("/api/v1/client/og-image", "GET")).toBe(
publicEdgeRateLimitPolicies.v1ClientDefault
);
});
test("classifies v2 public write routes", () => {
expect(getPublicEdgeRateLimitPolicyId("/api/v2/client/env_123/responses", "POST")).toBe(
publicEdgeRateLimitPolicies.v2ClientResponses
);
expect(getPublicEdgeRateLimitPolicyId("/api/v2/client/env_123/responses/resp_123", "PUT")).toBe(
publicEdgeRateLimitPolicies.v2ClientResponses
);
expect(getPublicEdgeRateLimitPolicyId("/api/v2/client/env_123/displays", "POST")).toBe(
publicEdgeRateLimitPolicies.v2ClientDisplays
);
expect(getPublicEdgeRateLimitPolicyId("/api/v2/client/env_123/storage", "POST")).toBe(
publicEdgeRateLimitPolicies.v2ClientStorageUpload
);
});
});
describe("isPublicEdgeRateLimitManaged", () => {
test("manages public policies on cloudflare and cloudarmor only", () => {
expect(isPublicEdgeRateLimitManaged(publicEdgeRateLimitPolicies.authLogin, "cloudflare")).toBe(true);
expect(isPublicEdgeRateLimitManaged(publicEdgeRateLimitPolicies.authLogin, "cloudarmor")).toBe(true);
expect(isPublicEdgeRateLimitManaged(publicEdgeRateLimitPolicies.authLogin, "none")).toBe(false);
expect(isPublicEdgeRateLimitManaged(publicEdgeRateLimitPolicies.authLogin, "envoy")).toBe(false);
});
});
describe("applyPublicIpRateLimit", () => {
test("uses app rate limiting when no edge provider manages the policy", async () => {
vi.mocked(applyIPRateLimit).mockResolvedValue({ allowed: true });
const source = await applyPublicIpRateLimit(
publicEdgeRateLimitPolicies.v2ClientResponses,
mockConfig,
"none"
);
expect(source).toBe("app");
expect(applyIPRateLimit).toHaveBeenCalledWith(mockConfig);
});
test("skips app rate limiting when the edge provider manages the policy", async () => {
const source = await applyPublicIpRateLimit(
publicEdgeRateLimitPolicies.v2ClientResponses,
mockConfig,
"cloudflare"
);
expect(source).toBe("edge");
expect(applyIPRateLimit).not.toHaveBeenCalled();
});
});
describe("applyPublicIpRateLimitForRoute", () => {
test("uses the route classifier for managed public routes", async () => {
const source = await applyPublicIpRateLimitForRoute(
"/api/v2/client/env_123/displays",
"POST",
mockConfig,
"cloudarmor"
);
expect(source).toBe("edge");
expect(applyIPRateLimit).not.toHaveBeenCalled();
});
test("falls back to app rate limiting for unmanaged routes", async () => {
vi.mocked(applyIPRateLimit).mockResolvedValue({ allowed: true });
const source = await applyPublicIpRateLimitForRoute(
"/api/v1/client/env_123/environment",
"GET",
mockConfig,
"envoy"
);
expect(source).toBe("app");
expect(applyIPRateLimit).toHaveBeenCalledWith(mockConfig);
});
});
});

View File

@@ -0,0 +1,135 @@
import { EDGE_RATE_LIMIT_PROVIDER } from "@/lib/constants";
import { applyIPRateLimit } from "./helpers";
import { TRateLimitConfig } from "./types/rate-limit";
export const publicEdgeRateLimitPolicies = {
authLogin: "auth.login",
authVerifyEmail: "auth.verify_email",
v1ClientDefault: "client.v1.default",
v1ClientStorageUpload: "client.storage.upload.v1",
v2ClientResponses: "client.responses.v2",
v2ClientDisplays: "client.displays.v2",
v2ClientStorageUpload: "client.storage.upload.v2",
} as const;
export type TPublicEdgeRateLimitPolicyId =
(typeof publicEdgeRateLimitPolicies)[keyof typeof publicEdgeRateLimitPolicies];
export type TEdgeRateLimitProvider = "none" | "cloudflare" | "cloudarmor" | "envoy";
const managedPublicEdgePolicies = Object.values(
publicEdgeRateLimitPolicies
) as TPublicEdgeRateLimitPolicyId[];
const managedPublicEdgePoliciesByProvider: Record<
TEdgeRateLimitProvider,
readonly TPublicEdgeRateLimitPolicyId[]
> = {
none: [],
cloudflare: managedPublicEdgePolicies,
cloudarmor: managedPublicEdgePolicies,
envoy: [],
};
const normalizeEdgeRateLimitProvider = (provider: string | undefined): TEdgeRateLimitProvider => {
switch (provider) {
case "cloudflare":
case "cloudarmor":
case "envoy":
return provider;
default:
return "none";
}
};
const normalizePathname = (pathname: string): string => {
if (pathname.length > 1 && pathname.endsWith("/")) {
return pathname.slice(0, -1);
}
return pathname;
};
export const getEdgeRateLimitProvider = (
provider: string | undefined = EDGE_RATE_LIMIT_PROVIDER
): TEdgeRateLimitProvider => normalizeEdgeRateLimitProvider(provider);
export const getPublicEdgeRateLimitPolicyId = (
pathname: string,
method: string
): TPublicEdgeRateLimitPolicyId | null => {
const normalizedPathname = normalizePathname(pathname);
const normalizedMethod = method.toUpperCase();
if (normalizedMethod === "POST" && normalizedPathname === "/api/auth/callback/credentials") {
return publicEdgeRateLimitPolicies.authLogin;
}
if (normalizedMethod === "POST" && normalizedPathname === "/api/auth/callback/token") {
return publicEdgeRateLimitPolicies.authVerifyEmail;
}
if (/^\/api\/v1\/client\/og(?:\/.*)?$/.test(normalizedPathname)) {
return null;
}
if (/^\/api\/v1\/client\/[^/]+\/storage$/.test(normalizedPathname) && normalizedMethod === "POST") {
return publicEdgeRateLimitPolicies.v1ClientStorageUpload;
}
if (/^\/api\/v2\/client\/[^/]+\/storage$/.test(normalizedPathname) && normalizedMethod === "POST") {
return publicEdgeRateLimitPolicies.v2ClientStorageUpload;
}
if (
/^\/api\/v2\/client\/[^/]+\/responses(?:\/[^/]+)?$/.test(normalizedPathname) &&
(normalizedMethod === "POST" || normalizedMethod === "PUT")
) {
return publicEdgeRateLimitPolicies.v2ClientResponses;
}
if (/^\/api\/v2\/client\/[^/]+\/displays$/.test(normalizedPathname) && normalizedMethod === "POST") {
return publicEdgeRateLimitPolicies.v2ClientDisplays;
}
if (normalizedPathname.startsWith("/api/v1/client/")) {
return publicEdgeRateLimitPolicies.v1ClientDefault;
}
return null;
};
export const isPublicEdgeRateLimitManaged = (
policyId: TPublicEdgeRateLimitPolicyId,
provider: string | undefined = EDGE_RATE_LIMIT_PROVIDER
): boolean => managedPublicEdgePoliciesByProvider[getEdgeRateLimitProvider(provider)].includes(policyId);
export const applyPublicIpRateLimit = async (
policyId: TPublicEdgeRateLimitPolicyId,
config: TRateLimitConfig,
provider: string | undefined = EDGE_RATE_LIMIT_PROVIDER
): Promise<"app" | "edge"> => {
if (isPublicEdgeRateLimitManaged(policyId, provider)) {
return "edge";
}
await applyIPRateLimit(config);
return "app";
};
export const applyPublicIpRateLimitForRoute = async (
pathname: string,
method: string,
config: TRateLimitConfig,
provider: string | undefined = EDGE_RATE_LIMIT_PROVIDER
): Promise<"app" | "edge"> => {
const policyId = getPublicEdgeRateLimitPolicyId(pathname, method);
if (!policyId) {
await applyIPRateLimit(config);
return "app";
}
return await applyPublicIpRateLimit(policyId, config, provider);
};

View File

@@ -186,6 +186,7 @@ export const testInputValidation = async (service: Function, ...args: any[]): Pr
vi.mock("@/lib/constants", () => ({
IS_FORMBRICKS_CLOUD: false,
EDGE_RATE_LIMIT_PROVIDER: "none",
ENCRYPTION_KEY: "mock-encryption-key",
ENTERPRISE_LICENSE_KEY: "mock-enterprise-license-key",
GITHUB_ID: "mock-github-id",

View File

@@ -146,6 +146,7 @@
"E2E_TESTING",
"EMAIL_AUTH_DISABLED",
"EMAIL_VERIFICATION_DISABLED",
"EDGE_RATE_LIMIT_PROVIDER",
"ENCRYPTION_KEY",
"ENTERPRISE_LICENSE_KEY",
"ENVIRONMENT",