mirror of
https://github.com/formbricks/formbricks.git
synced 2026-03-24 09:22:16 -05:00
feat: add provider-aware public rate limit routing
This commit is contained in:
@@ -185,6 +185,10 @@ ENTERPRISE_LICENSE_KEY=
|
||||
# Ignore Rate Limiting across the Formbricks app
|
||||
# RATE_LIMITING_DISABLED=1
|
||||
|
||||
# Public unauthenticated IP-based rate limits can be handled by an edge provider.
|
||||
# Supported values: none, cloudflare, cloudarmor, envoy
|
||||
# EDGE_RATE_LIMIT_PROVIDER=none
|
||||
|
||||
# OpenTelemetry OTLP endpoint (base URL, exporters append /v1/traces and /v1/metrics)
|
||||
# OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318
|
||||
# OTEL_EXPORTER_OTLP_PROTOCOL=http/protobuf
|
||||
|
||||
@@ -4,6 +4,11 @@ import { ZDisplayCreateInputV2 } from "@/app/api/v2/client/[environmentId]/displ
|
||||
import { responses } from "@/app/lib/api/response";
|
||||
import { transformErrorToDetails } from "@/app/lib/api/validator";
|
||||
import { getOrganizationIdFromEnvironmentId } from "@/lib/utils/helper";
|
||||
import {
|
||||
applyPublicIpRateLimit,
|
||||
publicEdgeRateLimitPolicies,
|
||||
} from "@/modules/core/rate-limit/public-edge-rate-limit";
|
||||
import { rateLimitConfigs } from "@/modules/core/rate-limit/rate-limit-configs";
|
||||
import { getIsContactsEnabled } from "@/modules/ee/license-check/lib/utils";
|
||||
import { createDisplay } from "./lib/display";
|
||||
|
||||
@@ -24,6 +29,15 @@ export const OPTIONS = async (): Promise<Response> => {
|
||||
};
|
||||
|
||||
export const POST = async (request: Request, context: Context): Promise<Response> => {
|
||||
try {
|
||||
await applyPublicIpRateLimit(publicEdgeRateLimitPolicies.v2ClientDisplays, rateLimitConfigs.api.client);
|
||||
} catch (error) {
|
||||
return responses.tooManyRequestsResponse(
|
||||
error instanceof Error ? error.message : "Rate limit exceeded",
|
||||
true
|
||||
);
|
||||
}
|
||||
|
||||
const params = await context.params;
|
||||
const jsonInput = await request.json();
|
||||
const inputValidation = ZDisplayCreateInputV2.safeParse({
|
||||
|
||||
@@ -14,6 +14,11 @@ import { getClientIpFromHeaders } from "@/lib/utils/client-ip";
|
||||
import { getOrganizationIdFromEnvironmentId } from "@/lib/utils/helper";
|
||||
import { formatValidationErrorsForV1Api, validateResponseData } from "@/modules/api/lib/validation";
|
||||
import { validateOtherOptionLengthForMultipleChoice } from "@/modules/api/v2/lib/element";
|
||||
import {
|
||||
applyPublicIpRateLimit,
|
||||
publicEdgeRateLimitPolicies,
|
||||
} from "@/modules/core/rate-limit/public-edge-rate-limit";
|
||||
import { rateLimitConfigs } from "@/modules/core/rate-limit/rate-limit-configs";
|
||||
import { getIsContactsEnabled } from "@/modules/ee/license-check/lib/utils";
|
||||
import { createQuotaFullObject } from "@/modules/ee/quotas/lib/helpers";
|
||||
import { createResponseWithQuotaEvaluation } from "./lib/response";
|
||||
@@ -36,6 +41,15 @@ export const OPTIONS = async (): Promise<Response> => {
|
||||
};
|
||||
|
||||
export const POST = async (request: Request, context: Context): Promise<Response> => {
|
||||
try {
|
||||
await applyPublicIpRateLimit(publicEdgeRateLimitPolicies.v2ClientResponses, rateLimitConfigs.api.client);
|
||||
} catch (error) {
|
||||
return responses.tooManyRequestsResponse(
|
||||
error instanceof Error ? error.message : "Rate limit exceeded",
|
||||
true
|
||||
);
|
||||
}
|
||||
|
||||
const params = await context.params;
|
||||
const requestHeaders = await headers();
|
||||
let responseInput;
|
||||
|
||||
@@ -12,6 +12,10 @@ vi.mock("@/modules/ee/audit-logs/lib/handler", () => ({
|
||||
queueAuditEvent: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("@/modules/ee/audit-logs/types/audit-log", () => ({
|
||||
UNKNOWN_DATA: "unknown",
|
||||
}));
|
||||
|
||||
vi.mock("@sentry/nextjs", () => ({
|
||||
captureException: vi.fn(),
|
||||
withScope: vi.fn((callback) => {
|
||||
@@ -72,10 +76,13 @@ vi.mock("@/app/middleware/endpoint-validator", async () => {
|
||||
});
|
||||
|
||||
vi.mock("@/modules/core/rate-limit/helpers", () => ({
|
||||
applyIPRateLimit: vi.fn(),
|
||||
applyRateLimit: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("@/modules/core/rate-limit/public-edge-rate-limit", () => ({
|
||||
applyPublicIpRateLimitForRoute: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("@/modules/core/rate-limit/rate-limit-configs", () => ({
|
||||
rateLimitConfigs: {
|
||||
api: {
|
||||
@@ -115,6 +122,7 @@ describe("withV1ApiWrapper", () => {
|
||||
|
||||
vi.doMock("@/lib/constants", () => ({
|
||||
AUDIT_LOG_ENABLED: true,
|
||||
EDGE_RATE_LIMIT_PROVIDER: "none",
|
||||
IS_PRODUCTION: true,
|
||||
SENTRY_DSN: "dsn",
|
||||
ENCRYPTION_KEY: "test-key",
|
||||
@@ -131,11 +139,13 @@ describe("withV1ApiWrapper", () => {
|
||||
});
|
||||
|
||||
test("logs and audits on error response with API key authentication", async () => {
|
||||
const { queueAuditEvent: mockedQueueAuditEvent } =
|
||||
(await import("@/modules/ee/audit-logs/lib/handler")) as unknown as { queueAuditEvent: Mock };
|
||||
const { queueAuditEvent: mockedQueueAuditEvent } = (await import(
|
||||
"@/modules/ee/audit-logs/lib/handler"
|
||||
)) as unknown as { queueAuditEvent: Mock };
|
||||
const { authenticateRequest } = await import("@/app/api/v1/auth");
|
||||
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } =
|
||||
await import("@/app/middleware/endpoint-validator");
|
||||
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } = await import(
|
||||
"@/app/middleware/endpoint-validator"
|
||||
);
|
||||
|
||||
vi.mocked(authenticateRequest).mockResolvedValue(mockApiAuthentication);
|
||||
vi.mocked(isClientSideApiRoute).mockReturnValue({ isClientSideApi: false, isRateLimited: true });
|
||||
@@ -183,11 +193,13 @@ describe("withV1ApiWrapper", () => {
|
||||
});
|
||||
|
||||
test("does not log Sentry if not 500", async () => {
|
||||
const { queueAuditEvent: mockedQueueAuditEvent } =
|
||||
(await import("@/modules/ee/audit-logs/lib/handler")) as unknown as { queueAuditEvent: Mock };
|
||||
const { queueAuditEvent: mockedQueueAuditEvent } = (await import(
|
||||
"@/modules/ee/audit-logs/lib/handler"
|
||||
)) as unknown as { queueAuditEvent: Mock };
|
||||
const { authenticateRequest } = await import("@/app/api/v1/auth");
|
||||
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } =
|
||||
await import("@/app/middleware/endpoint-validator");
|
||||
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } = await import(
|
||||
"@/app/middleware/endpoint-validator"
|
||||
);
|
||||
|
||||
vi.mocked(authenticateRequest).mockResolvedValue(mockApiAuthentication);
|
||||
vi.mocked(isClientSideApiRoute).mockReturnValue({ isClientSideApi: false, isRateLimited: true });
|
||||
@@ -229,11 +241,13 @@ describe("withV1ApiWrapper", () => {
|
||||
});
|
||||
|
||||
test("logs and audits on thrown error", async () => {
|
||||
const { queueAuditEvent: mockedQueueAuditEvent } =
|
||||
(await import("@/modules/ee/audit-logs/lib/handler")) as unknown as { queueAuditEvent: Mock };
|
||||
const { queueAuditEvent: mockedQueueAuditEvent } = (await import(
|
||||
"@/modules/ee/audit-logs/lib/handler"
|
||||
)) as unknown as { queueAuditEvent: Mock };
|
||||
const { authenticateRequest } = await import("@/app/api/v1/auth");
|
||||
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } =
|
||||
await import("@/app/middleware/endpoint-validator");
|
||||
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } = await import(
|
||||
"@/app/middleware/endpoint-validator"
|
||||
);
|
||||
|
||||
vi.mocked(authenticateRequest).mockResolvedValue(mockApiAuthentication);
|
||||
vi.mocked(isClientSideApiRoute).mockReturnValue({ isClientSideApi: false, isRateLimited: true });
|
||||
@@ -285,11 +299,13 @@ describe("withV1ApiWrapper", () => {
|
||||
});
|
||||
|
||||
test("does not log on success response but still audits", async () => {
|
||||
const { queueAuditEvent: mockedQueueAuditEvent } =
|
||||
(await import("@/modules/ee/audit-logs/lib/handler")) as unknown as { queueAuditEvent: Mock };
|
||||
const { queueAuditEvent: mockedQueueAuditEvent } = (await import(
|
||||
"@/modules/ee/audit-logs/lib/handler"
|
||||
)) as unknown as { queueAuditEvent: Mock };
|
||||
const { authenticateRequest } = await import("@/app/api/v1/auth");
|
||||
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } =
|
||||
await import("@/app/middleware/endpoint-validator");
|
||||
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } = await import(
|
||||
"@/app/middleware/endpoint-validator"
|
||||
);
|
||||
|
||||
vi.mocked(authenticateRequest).mockResolvedValue(mockApiAuthentication);
|
||||
vi.mocked(isClientSideApiRoute).mockReturnValue({ isClientSideApi: false, isRateLimited: true });
|
||||
@@ -333,17 +349,20 @@ describe("withV1ApiWrapper", () => {
|
||||
test("does not call audit if AUDIT_LOG_ENABLED is false", async () => {
|
||||
vi.doMock("@/lib/constants", () => ({
|
||||
AUDIT_LOG_ENABLED: false,
|
||||
EDGE_RATE_LIMIT_PROVIDER: "none",
|
||||
IS_PRODUCTION: true,
|
||||
SENTRY_DSN: "dsn",
|
||||
ENCRYPTION_KEY: "test-key",
|
||||
REDIS_URL: "redis://localhost:6379",
|
||||
}));
|
||||
|
||||
const { queueAuditEvent: mockedQueueAuditEvent } =
|
||||
(await import("@/modules/ee/audit-logs/lib/handler")) as unknown as { queueAuditEvent: Mock };
|
||||
const { queueAuditEvent: mockedQueueAuditEvent } = (await import(
|
||||
"@/modules/ee/audit-logs/lib/handler"
|
||||
)) as unknown as { queueAuditEvent: Mock };
|
||||
const { authenticateRequest } = await import("@/app/api/v1/auth");
|
||||
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } =
|
||||
await import("@/app/middleware/endpoint-validator");
|
||||
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } = await import(
|
||||
"@/app/middleware/endpoint-validator"
|
||||
);
|
||||
const { withV1ApiWrapper } = await import("./with-api-logging");
|
||||
|
||||
vi.mocked(authenticateRequest).mockResolvedValue(mockApiAuthentication);
|
||||
@@ -366,10 +385,13 @@ describe("withV1ApiWrapper", () => {
|
||||
});
|
||||
|
||||
test("handles client-side API routes without authentication", async () => {
|
||||
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } =
|
||||
await import("@/app/middleware/endpoint-validator");
|
||||
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } = await import(
|
||||
"@/app/middleware/endpoint-validator"
|
||||
);
|
||||
const { authenticateRequest } = await import("@/app/api/v1/auth");
|
||||
const { applyIPRateLimit } = await import("@/modules/core/rate-limit/helpers");
|
||||
const { applyPublicIpRateLimitForRoute } = await import(
|
||||
"@/modules/core/rate-limit/public-edge-rate-limit"
|
||||
);
|
||||
|
||||
vi.mocked(isClientSideApiRoute).mockReturnValue({ isClientSideApi: true, isRateLimited: true });
|
||||
vi.mocked(isManagementApiRoute).mockReturnValue({
|
||||
@@ -378,7 +400,7 @@ describe("withV1ApiWrapper", () => {
|
||||
});
|
||||
vi.mocked(isIntegrationRoute).mockReturnValue(false);
|
||||
vi.mocked(authenticateRequest).mockResolvedValue(null);
|
||||
vi.mocked(applyIPRateLimit).mockResolvedValue(undefined);
|
||||
vi.mocked(applyPublicIpRateLimitForRoute).mockResolvedValue("app");
|
||||
|
||||
const handler = vi.fn().mockResolvedValue({
|
||||
response: responses.successResponse({ data: "test" }),
|
||||
@@ -396,11 +418,17 @@ describe("withV1ApiWrapper", () => {
|
||||
auditLog: undefined,
|
||||
authentication: null,
|
||||
});
|
||||
expect(applyPublicIpRateLimitForRoute).toHaveBeenCalledWith(
|
||||
"/api/v1/client/displays",
|
||||
"GET",
|
||||
expect.objectContaining({ max: 100 })
|
||||
);
|
||||
});
|
||||
|
||||
test("returns authentication error for non-client routes without auth", async () => {
|
||||
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } =
|
||||
await import("@/app/middleware/endpoint-validator");
|
||||
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } = await import(
|
||||
"@/app/middleware/endpoint-validator"
|
||||
);
|
||||
const { authenticateRequest } = await import("@/app/api/v1/auth");
|
||||
|
||||
vi.mocked(isClientSideApiRoute).mockReturnValue({ isClientSideApi: false, isRateLimited: true });
|
||||
@@ -422,8 +450,9 @@ describe("withV1ApiWrapper", () => {
|
||||
});
|
||||
|
||||
test("uses unauthenticatedResponse when provided instead of default 401", async () => {
|
||||
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } =
|
||||
await import("@/app/middleware/endpoint-validator");
|
||||
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } = await import(
|
||||
"@/app/middleware/endpoint-validator"
|
||||
);
|
||||
const { getServerSession } = await import("next-auth");
|
||||
|
||||
vi.mocked(isClientSideApiRoute).mockReturnValue({ isClientSideApi: false, isRateLimited: true });
|
||||
@@ -455,8 +484,9 @@ describe("withV1ApiWrapper", () => {
|
||||
|
||||
test("handles rate limiting errors", async () => {
|
||||
const { applyRateLimit } = await import("@/modules/core/rate-limit/helpers");
|
||||
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } =
|
||||
await import("@/app/middleware/endpoint-validator");
|
||||
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } = await import(
|
||||
"@/app/middleware/endpoint-validator"
|
||||
);
|
||||
const { authenticateRequest } = await import("@/app/api/v1/auth");
|
||||
|
||||
vi.mocked(authenticateRequest).mockResolvedValue(mockApiAuthentication);
|
||||
@@ -481,11 +511,13 @@ describe("withV1ApiWrapper", () => {
|
||||
});
|
||||
|
||||
test("skips audit log creation when no action/targetType provided", async () => {
|
||||
const { queueAuditEvent: mockedQueueAuditEvent } =
|
||||
(await import("@/modules/ee/audit-logs/lib/handler")) as unknown as { queueAuditEvent: Mock };
|
||||
const { queueAuditEvent: mockedQueueAuditEvent } = (await import(
|
||||
"@/modules/ee/audit-logs/lib/handler"
|
||||
)) as unknown as { queueAuditEvent: Mock };
|
||||
const { authenticateRequest } = await import("@/app/api/v1/auth");
|
||||
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } =
|
||||
await import("@/app/middleware/endpoint-validator");
|
||||
const { isClientSideApiRoute, isManagementApiRoute, isIntegrationRoute } = await import(
|
||||
"@/app/middleware/endpoint-validator"
|
||||
);
|
||||
|
||||
vi.mocked(authenticateRequest).mockResolvedValue(mockApiAuthentication);
|
||||
vi.mocked(isClientSideApiRoute).mockReturnValue({ isClientSideApi: false, isRateLimited: true });
|
||||
|
||||
@@ -13,7 +13,8 @@ import {
|
||||
} from "@/app/middleware/endpoint-validator";
|
||||
import { AUDIT_LOG_ENABLED, IS_PRODUCTION, SENTRY_DSN } from "@/lib/constants";
|
||||
import { authOptions } from "@/modules/auth/lib/authOptions";
|
||||
import { applyIPRateLimit, applyRateLimit } from "@/modules/core/rate-limit/helpers";
|
||||
import { applyRateLimit } from "@/modules/core/rate-limit/helpers";
|
||||
import { applyPublicIpRateLimitForRoute } from "@/modules/core/rate-limit/public-edge-rate-limit";
|
||||
import { rateLimitConfigs } from "@/modules/core/rate-limit/rate-limit-configs";
|
||||
import { TRateLimitConfig } from "@/modules/core/rate-limit/types/rate-limit";
|
||||
import { queueAuditEvent } from "@/modules/ee/audit-logs/lib/handler";
|
||||
@@ -54,14 +55,22 @@ enum ApiV1RouteTypeEnum {
|
||||
/**
|
||||
* Apply client-side API rate limiting (IP-based)
|
||||
*/
|
||||
const applyClientRateLimit = async (customRateLimitConfig?: TRateLimitConfig): Promise<void> => {
|
||||
await applyIPRateLimit(customRateLimitConfig ?? rateLimitConfigs.api.client);
|
||||
const applyClientRateLimit = async (
|
||||
req: NextRequest,
|
||||
customRateLimitConfig?: TRateLimitConfig
|
||||
): Promise<void> => {
|
||||
await applyPublicIpRateLimitForRoute(
|
||||
req.nextUrl.pathname,
|
||||
req.method,
|
||||
customRateLimitConfig ?? rateLimitConfigs.api.client
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Handle rate limiting based on authentication and API type
|
||||
*/
|
||||
const handleRateLimiting = async (
|
||||
req: NextRequest,
|
||||
authentication: TApiV1Authentication,
|
||||
routeType: ApiV1RouteTypeEnum,
|
||||
customRateLimitConfig?: TRateLimitConfig
|
||||
@@ -81,7 +90,7 @@ const handleRateLimiting = async (
|
||||
}
|
||||
|
||||
if (routeType === ApiV1RouteTypeEnum.Client) {
|
||||
await applyClientRateLimit(customRateLimitConfig);
|
||||
await applyClientRateLimit(req, customRateLimitConfig);
|
||||
}
|
||||
} catch (error) {
|
||||
return responses.tooManyRequestsResponse(error instanceof Error ? error.message : "Rate limit exceeded");
|
||||
@@ -305,7 +314,12 @@ export const withV1ApiWrapper = <TResult extends { response: Response }, TProps
|
||||
|
||||
// === Rate Limiting ===
|
||||
if (isRateLimited) {
|
||||
const rateLimitResponse = await handleRateLimiting(authentication, routeType, customRateLimitConfig);
|
||||
const rateLimitResponse = await handleRateLimiting(
|
||||
req,
|
||||
authentication,
|
||||
routeType,
|
||||
customRateLimitConfig
|
||||
);
|
||||
if (rateLimitResponse) return rateLimitResponse;
|
||||
}
|
||||
|
||||
|
||||
@@ -48,6 +48,10 @@ describe("endpoint-validator", () => {
|
||||
isClientSideApi: true,
|
||||
isRateLimited: false,
|
||||
});
|
||||
expect(isClientSideApiRoute("/api/v1/client/og-image")).toEqual({
|
||||
isClientSideApi: true,
|
||||
isRateLimited: true,
|
||||
});
|
||||
});
|
||||
|
||||
test("should return false for non-client-side API routes", () => {
|
||||
|
||||
@@ -13,7 +13,7 @@ export enum AuthenticationMethod {
|
||||
|
||||
export const isClientSideApiRoute = (url: string): { isClientSideApi: boolean; isRateLimited: boolean } => {
|
||||
// Open Graph image generation route is a client side API route but it should not be rate limited
|
||||
if (url.includes("/api/v1/client/og")) return { isClientSideApi: true, isRateLimited: false };
|
||||
if (/^\/api\/v1\/client\/og(?:\/.*)?$/.test(url)) return { isClientSideApi: true, isRateLimited: false };
|
||||
|
||||
const regex = /^\/api\/v\d+\/client\//;
|
||||
return { isClientSideApi: regex.test(url), isRateLimited: true };
|
||||
|
||||
@@ -3,6 +3,7 @@ import { TUserLocale } from "@formbricks/types/user";
|
||||
import { env } from "./env";
|
||||
|
||||
export const IS_FORMBRICKS_CLOUD = env.IS_FORMBRICKS_CLOUD === "1";
|
||||
export const EDGE_RATE_LIMIT_PROVIDER = env.EDGE_RATE_LIMIT_PROVIDER ?? "none";
|
||||
|
||||
export const IS_PRODUCTION = env.NODE_ENV === "production";
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ export const env = createEnv({
|
||||
E2E_TESTING: z.enum(["1", "0"]).optional(),
|
||||
EMAIL_AUTH_DISABLED: z.enum(["1", "0"]).optional(),
|
||||
EMAIL_VERIFICATION_DISABLED: z.enum(["1", "0"]).optional(),
|
||||
EDGE_RATE_LIMIT_PROVIDER: z.enum(["none", "cloudflare", "cloudarmor", "envoy"]).optional(),
|
||||
ENCRYPTION_KEY: z.string(),
|
||||
ENTERPRISE_LICENSE_KEY: z.string().optional(),
|
||||
ENVIRONMENT: z.enum(["production", "staging"]).prefault("production"),
|
||||
@@ -147,6 +148,7 @@ export const env = createEnv({
|
||||
E2E_TESTING: process.env.E2E_TESTING,
|
||||
EMAIL_AUTH_DISABLED: process.env.EMAIL_AUTH_DISABLED,
|
||||
EMAIL_VERIFICATION_DISABLED: process.env.EMAIL_VERIFICATION_DISABLED,
|
||||
EDGE_RATE_LIMIT_PROVIDER: process.env.EDGE_RATE_LIMIT_PROVIDER,
|
||||
ENCRYPTION_KEY: process.env.ENCRYPTION_KEY,
|
||||
ENTERPRISE_LICENSE_KEY: process.env.ENTERPRISE_LICENSE_KEY,
|
||||
ENVIRONMENT: process.env.ENVIRONMENT,
|
||||
|
||||
@@ -3,11 +3,14 @@ import { Provider } from "next-auth/providers/index";
|
||||
import { afterEach, describe, expect, test, vi } from "vitest";
|
||||
import { prisma } from "@formbricks/database";
|
||||
import { EMAIL_VERIFICATION_DISABLED } from "@/lib/constants";
|
||||
// Import mocked rate limiting functions
|
||||
import { applyIPRateLimit } from "@/modules/core/rate-limit/helpers";
|
||||
import {
|
||||
applyPublicIpRateLimit,
|
||||
publicEdgeRateLimitPolicies,
|
||||
} from "@/modules/core/rate-limit/public-edge-rate-limit";
|
||||
import { rateLimitConfigs } from "@/modules/core/rate-limit/rate-limit-configs";
|
||||
import { authOptions } from "./authOptions";
|
||||
import { mockUser } from "./mock-data";
|
||||
import { getUserByEmail } from "./user";
|
||||
import { hashPassword } from "./utils";
|
||||
|
||||
// Mock encryption utilities
|
||||
@@ -19,11 +22,48 @@ vi.mock("@/lib/encryption", () => ({
|
||||
// Mock JWT
|
||||
vi.mock("@/lib/jwt");
|
||||
|
||||
// Mock rate limiting dependencies
|
||||
vi.mock("@/modules/core/rate-limit/helpers", () => ({
|
||||
applyIPRateLimit: vi.fn(),
|
||||
vi.mock("@/modules/core/rate-limit/public-edge-rate-limit", () => ({
|
||||
applyPublicIpRateLimit: vi.fn(),
|
||||
publicEdgeRateLimitPolicies: {
|
||||
authLogin: "auth.login",
|
||||
authVerifyEmail: "auth.verify_email",
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock("./user", () => ({
|
||||
getUserByEmail: vi.fn(),
|
||||
updateUser: vi.fn(),
|
||||
updateUserLastLoginAt: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("./brevo", () => ({
|
||||
createBrevoCustomer: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("@/modules/ee/sso/lib/providers", () => ({
|
||||
getSSOProviders: vi.fn(() => []),
|
||||
}));
|
||||
|
||||
vi.mock("@/modules/ee/sso/lib/sso-handlers", () => ({
|
||||
handleSsoCallback: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("@/modules/ee/audit-logs/lib/handler", () => ({
|
||||
queueAuditEventBackground: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("@/modules/ee/audit-logs/types/audit-log", () => ({
|
||||
UNKNOWN_DATA: "unknown",
|
||||
}));
|
||||
|
||||
vi.mock("./utils", async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import("./utils")>();
|
||||
return {
|
||||
...actual,
|
||||
shouldLogAuthFailure: vi.fn().mockResolvedValue(false),
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock("@/modules/core/rate-limit/rate-limit-configs", () => ({
|
||||
rateLimitConfigs: {
|
||||
auth: {
|
||||
@@ -33,26 +73,22 @@ vi.mock("@/modules/core/rate-limit/rate-limit-configs", () => ({
|
||||
},
|
||||
}));
|
||||
|
||||
// Mock constants that this test needs while preserving untouched exports.
|
||||
vi.mock("@/lib/constants", async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import("@/lib/constants")>();
|
||||
return {
|
||||
...actual,
|
||||
EMAIL_VERIFICATION_DISABLED: false,
|
||||
SESSION_MAX_AGE: 86400,
|
||||
NEXTAUTH_SECRET: "test-secret",
|
||||
WEBAPP_URL: "http://localhost:3000",
|
||||
ENCRYPTION_KEY: "12345678901234567890123456789012", // 32 bytes for AES-256
|
||||
REDIS_URL: undefined,
|
||||
AUDIT_LOG_ENABLED: false,
|
||||
AUDIT_LOG_GET_USER_IP: false,
|
||||
ENTERPRISE_LICENSE_KEY: undefined,
|
||||
SENTRY_DSN: undefined,
|
||||
BREVO_API_KEY: undefined,
|
||||
RATE_LIMITING_DISABLED: false,
|
||||
CONTROL_HASH: "$2b$12$fzHf9le13Ss9UJ04xzmsjODXpFJxz6vsnupoepF5FiqDECkX2BH5q",
|
||||
};
|
||||
});
|
||||
vi.mock("@/lib/constants", () => ({
|
||||
EMAIL_VERIFICATION_DISABLED: false,
|
||||
EDGE_RATE_LIMIT_PROVIDER: "none",
|
||||
SESSION_MAX_AGE: 86400,
|
||||
NEXTAUTH_SECRET: "test-secret",
|
||||
WEBAPP_URL: "http://localhost:3000",
|
||||
ENCRYPTION_KEY: "12345678901234567890123456789012", // 32 bytes for AES-256
|
||||
REDIS_URL: undefined,
|
||||
AUDIT_LOG_ENABLED: false,
|
||||
AUDIT_LOG_GET_USER_IP: false,
|
||||
ENTERPRISE_LICENSE_KEY: undefined,
|
||||
SENTRY_DSN: undefined,
|
||||
BREVO_API_KEY: undefined,
|
||||
RATE_LIMITING_DISABLED: false,
|
||||
CONTROL_HASH: "$2b$12$fzHf9le13Ss9UJ04xzmsjODXpFJxz6vsnupoepF5FiqDECkX2BH5q",
|
||||
}));
|
||||
|
||||
// Mock next/headers
|
||||
vi.mock("next/headers", () => ({
|
||||
@@ -114,7 +150,7 @@ describe("authOptions", () => {
|
||||
});
|
||||
|
||||
test("should throw error if user not found", async () => {
|
||||
vi.mocked(applyIPRateLimit).mockResolvedValue(); // Rate limiting passes
|
||||
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
|
||||
vi.spyOn(prisma.user, "findUnique").mockResolvedValue(null);
|
||||
|
||||
const credentials = { email: mockUser.email, password: mockPassword };
|
||||
@@ -125,7 +161,7 @@ describe("authOptions", () => {
|
||||
});
|
||||
|
||||
test("should throw error if user has no password stored", async () => {
|
||||
vi.mocked(applyIPRateLimit).mockResolvedValue(); // Rate limiting passes
|
||||
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
|
||||
vi.spyOn(prisma.user, "findUnique").mockResolvedValue({
|
||||
id: mockUser.id,
|
||||
email: mockUser.email,
|
||||
@@ -140,7 +176,7 @@ describe("authOptions", () => {
|
||||
});
|
||||
|
||||
test("should throw error if password verification fails", async () => {
|
||||
vi.mocked(applyIPRateLimit).mockResolvedValue(); // Rate limiting passes
|
||||
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
|
||||
vi.spyOn(prisma.user, "findUnique").mockResolvedValue({
|
||||
id: mockUserId,
|
||||
email: mockUser.email,
|
||||
@@ -155,7 +191,7 @@ describe("authOptions", () => {
|
||||
});
|
||||
|
||||
test("should successfully login when credentials are valid", async () => {
|
||||
vi.mocked(applyIPRateLimit).mockResolvedValue(); // Rate limiting passes
|
||||
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
|
||||
const fakeUser = {
|
||||
id: mockUserId,
|
||||
email: mockUser.email,
|
||||
@@ -178,7 +214,7 @@ describe("authOptions", () => {
|
||||
|
||||
describe("Rate Limiting", () => {
|
||||
test("should apply rate limiting before credential validation", async () => {
|
||||
vi.mocked(applyIPRateLimit).mockResolvedValue();
|
||||
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
|
||||
vi.spyOn(prisma.user, "findUnique").mockResolvedValue({
|
||||
id: mockUserId,
|
||||
email: mockUser.email,
|
||||
@@ -191,12 +227,15 @@ describe("authOptions", () => {
|
||||
|
||||
await credentialsProvider.options.authorize(credentials, {});
|
||||
|
||||
expect(applyIPRateLimit).toHaveBeenCalledWith(rateLimitConfigs.auth.login);
|
||||
expect(applyIPRateLimit).toHaveBeenCalledBefore(prisma.user.findUnique as any);
|
||||
expect(applyPublicIpRateLimit).toHaveBeenCalledWith(
|
||||
publicEdgeRateLimitPolicies.authLogin,
|
||||
rateLimitConfigs.auth.login
|
||||
);
|
||||
expect(applyPublicIpRateLimit).toHaveBeenCalledBefore(prisma.user.findUnique as any);
|
||||
});
|
||||
|
||||
test("should block login when rate limit exceeded", async () => {
|
||||
vi.mocked(applyIPRateLimit).mockRejectedValue(
|
||||
vi.mocked(applyPublicIpRateLimit).mockRejectedValue(
|
||||
new Error("Maximum number of requests reached. Please try again later.")
|
||||
);
|
||||
const findUniqueSpy = vi.spyOn(prisma.user, "findUnique");
|
||||
@@ -211,7 +250,7 @@ describe("authOptions", () => {
|
||||
});
|
||||
|
||||
test("should use correct rate limit configuration", async () => {
|
||||
vi.mocked(applyIPRateLimit).mockResolvedValue();
|
||||
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
|
||||
vi.spyOn(prisma.user, "findUnique").mockResolvedValue({
|
||||
id: mockUserId,
|
||||
email: mockUser.email,
|
||||
@@ -224,7 +263,7 @@ describe("authOptions", () => {
|
||||
|
||||
await credentialsProvider.options.authorize(credentials, {});
|
||||
|
||||
expect(applyIPRateLimit).toHaveBeenCalledWith({
|
||||
expect(applyPublicIpRateLimit).toHaveBeenCalledWith(publicEdgeRateLimitPolicies.authLogin, {
|
||||
interval: 900,
|
||||
allowedPerInterval: 30,
|
||||
namespace: "auth:login",
|
||||
@@ -234,7 +273,7 @@ describe("authOptions", () => {
|
||||
|
||||
describe("Two-Factor Backup Code login", () => {
|
||||
test("should throw error if backup codes are missing", async () => {
|
||||
vi.mocked(applyIPRateLimit).mockResolvedValue(); // Rate limiting passes
|
||||
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
|
||||
const mockUser = {
|
||||
id: mockUserId,
|
||||
email: "2fa@example.com",
|
||||
@@ -263,7 +302,7 @@ describe("authOptions", () => {
|
||||
});
|
||||
|
||||
test("should throw error if token is invalid or user not found", async () => {
|
||||
vi.mocked(applyIPRateLimit).mockResolvedValue(); // Rate limiting passes
|
||||
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
|
||||
const credentials = { token: "badtoken" };
|
||||
|
||||
await expect(tokenProvider.options.authorize(credentials, {})).rejects.toThrow(
|
||||
@@ -273,17 +312,20 @@ describe("authOptions", () => {
|
||||
|
||||
describe("Rate Limiting", () => {
|
||||
test("should apply rate limiting before token verification", async () => {
|
||||
vi.mocked(applyIPRateLimit).mockResolvedValue();
|
||||
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
|
||||
|
||||
const credentials = { token: "sometoken" };
|
||||
|
||||
await expect(tokenProvider.options.authorize(credentials, {})).rejects.toThrow();
|
||||
|
||||
expect(applyIPRateLimit).toHaveBeenCalledWith(rateLimitConfigs.auth.verifyEmail);
|
||||
expect(applyPublicIpRateLimit).toHaveBeenCalledWith(
|
||||
publicEdgeRateLimitPolicies.authVerifyEmail,
|
||||
rateLimitConfigs.auth.verifyEmail
|
||||
);
|
||||
});
|
||||
|
||||
test("should block verification when rate limit exceeded", async () => {
|
||||
vi.mocked(applyIPRateLimit).mockRejectedValue(
|
||||
vi.mocked(applyPublicIpRateLimit).mockRejectedValue(
|
||||
new Error("Maximum number of requests reached. Please try again later.")
|
||||
);
|
||||
const findUniqueSpy = vi.spyOn(prisma.user, "findUnique");
|
||||
@@ -302,7 +344,7 @@ describe("authOptions", () => {
|
||||
describe("Callbacks", () => {
|
||||
describe("jwt callback", () => {
|
||||
test("should add profile information to token if user is found", async () => {
|
||||
vi.spyOn(prisma.user, "findFirst").mockResolvedValue({
|
||||
vi.mocked(getUserByEmail).mockResolvedValue({
|
||||
id: mockUser.id,
|
||||
locale: mockUser.locale,
|
||||
email: mockUser.email,
|
||||
@@ -321,7 +363,7 @@ describe("authOptions", () => {
|
||||
});
|
||||
|
||||
test("should return token unchanged if no existing user is found", async () => {
|
||||
vi.spyOn(prisma.user, "findFirst").mockResolvedValue(null);
|
||||
vi.mocked(getUserByEmail).mockResolvedValue(null);
|
||||
|
||||
const token = { email: "nonexistent@example.com" };
|
||||
if (!authOptions.callbacks?.jwt) {
|
||||
@@ -366,7 +408,7 @@ describe("authOptions", () => {
|
||||
const credentialsProvider = getProviderById("credentials");
|
||||
|
||||
test("should throw error if TOTP code is missing when 2FA is enabled", async () => {
|
||||
vi.mocked(applyIPRateLimit).mockResolvedValue(); // Rate limiting passes
|
||||
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
|
||||
const mockUser = {
|
||||
id: mockUserId,
|
||||
email: "2fa@example.com",
|
||||
@@ -384,7 +426,7 @@ describe("authOptions", () => {
|
||||
});
|
||||
|
||||
test("should throw error if two factor secret is missing", async () => {
|
||||
vi.mocked(applyIPRateLimit).mockResolvedValue(); // Rate limiting passes
|
||||
vi.mocked(applyPublicIpRateLimit).mockResolvedValue("app");
|
||||
const mockUser = {
|
||||
id: mockUserId,
|
||||
email: "2fa@example.com",
|
||||
|
||||
@@ -23,7 +23,10 @@ import {
|
||||
shouldLogAuthFailure,
|
||||
verifyPassword,
|
||||
} from "@/modules/auth/lib/utils";
|
||||
import { applyIPRateLimit } from "@/modules/core/rate-limit/helpers";
|
||||
import {
|
||||
applyPublicIpRateLimit,
|
||||
publicEdgeRateLimitPolicies,
|
||||
} from "@/modules/core/rate-limit/public-edge-rate-limit";
|
||||
import { rateLimitConfigs } from "@/modules/core/rate-limit/rate-limit-configs";
|
||||
import { UNKNOWN_DATA } from "@/modules/ee/audit-logs/types/audit-log";
|
||||
import { getSSOProviders } from "@/modules/ee/sso/lib/providers";
|
||||
@@ -55,7 +58,7 @@ export const authOptions: NextAuthOptions = {
|
||||
backupCode: { label: "Backup Code", type: "input", placeholder: "Two-factor backup code" },
|
||||
},
|
||||
async authorize(credentials, _req) {
|
||||
await applyIPRateLimit(rateLimitConfigs.auth.login);
|
||||
await applyPublicIpRateLimit(publicEdgeRateLimitPolicies.authLogin, rateLimitConfigs.auth.login);
|
||||
|
||||
// Use email for rate limiting when available, fall back to "unknown_user" for credential validation
|
||||
const identifier = credentials?.email || "unknown_user"; // NOSONAR // We want to check for empty strings
|
||||
@@ -245,7 +248,10 @@ export const authOptions: NextAuthOptions = {
|
||||
},
|
||||
},
|
||||
async authorize(credentials, _req) {
|
||||
await applyIPRateLimit(rateLimitConfigs.auth.verifyEmail);
|
||||
await applyPublicIpRateLimit(
|
||||
publicEdgeRateLimitPolicies.authVerifyEmail,
|
||||
rateLimitConfigs.auth.verifyEmail
|
||||
);
|
||||
|
||||
// For token verification, we can't rate limit effectively by token (single-use)
|
||||
// So we use a generic identifier for token abuse attempts
|
||||
|
||||
142
apps/web/modules/core/rate-limit/public-edge-rate-limit.test.ts
Normal file
142
apps/web/modules/core/rate-limit/public-edge-rate-limit.test.ts
Normal file
@@ -0,0 +1,142 @@
|
||||
import { beforeEach, describe, expect, test, vi } from "vitest";
|
||||
import { applyIPRateLimit } from "./helpers";
|
||||
import {
|
||||
applyPublicIpRateLimit,
|
||||
applyPublicIpRateLimitForRoute,
|
||||
getEdgeRateLimitProvider,
|
||||
getPublicEdgeRateLimitPolicyId,
|
||||
isPublicEdgeRateLimitManaged,
|
||||
publicEdgeRateLimitPolicies,
|
||||
} from "./public-edge-rate-limit";
|
||||
|
||||
vi.mock("./helpers", () => ({
|
||||
applyIPRateLimit: vi.fn(),
|
||||
}));
|
||||
|
||||
const mockConfig = {
|
||||
interval: 60,
|
||||
allowedPerInterval: 100,
|
||||
namespace: "api:client",
|
||||
};
|
||||
|
||||
describe("public-edge-rate-limit", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("getEdgeRateLimitProvider", () => {
|
||||
test("falls back to none for unknown providers", () => {
|
||||
expect(getEdgeRateLimitProvider(undefined)).toBe("none");
|
||||
expect(getEdgeRateLimitProvider("unknown")).toBe("none");
|
||||
});
|
||||
|
||||
test("accepts configured providers", () => {
|
||||
expect(getEdgeRateLimitProvider("cloudflare")).toBe("cloudflare");
|
||||
expect(getEdgeRateLimitProvider("cloudarmor")).toBe("cloudarmor");
|
||||
expect(getEdgeRateLimitProvider("envoy")).toBe("envoy");
|
||||
});
|
||||
});
|
||||
|
||||
describe("getPublicEdgeRateLimitPolicyId", () => {
|
||||
test("classifies auth callback routes", () => {
|
||||
expect(getPublicEdgeRateLimitPolicyId("/api/auth/callback/credentials", "POST")).toBe(
|
||||
publicEdgeRateLimitPolicies.authLogin
|
||||
);
|
||||
expect(getPublicEdgeRateLimitPolicyId("/api/auth/callback/token", "POST")).toBe(
|
||||
publicEdgeRateLimitPolicies.authVerifyEmail
|
||||
);
|
||||
});
|
||||
|
||||
test("classifies v1 client routes", () => {
|
||||
expect(getPublicEdgeRateLimitPolicyId("/api/v1/client/env_123/environment", "GET")).toBe(
|
||||
publicEdgeRateLimitPolicies.v1ClientDefault
|
||||
);
|
||||
expect(getPublicEdgeRateLimitPolicyId("/api/v1/client/env_123/storage", "POST")).toBe(
|
||||
publicEdgeRateLimitPolicies.v1ClientStorageUpload
|
||||
);
|
||||
expect(getPublicEdgeRateLimitPolicyId("/api/v1/client/og", "GET")).toBeNull();
|
||||
expect(getPublicEdgeRateLimitPolicyId("/api/v1/client/og/image", "GET")).toBeNull();
|
||||
expect(getPublicEdgeRateLimitPolicyId("/api/v1/client/og-image", "GET")).toBe(
|
||||
publicEdgeRateLimitPolicies.v1ClientDefault
|
||||
);
|
||||
});
|
||||
|
||||
test("classifies v2 public write routes", () => {
|
||||
expect(getPublicEdgeRateLimitPolicyId("/api/v2/client/env_123/responses", "POST")).toBe(
|
||||
publicEdgeRateLimitPolicies.v2ClientResponses
|
||||
);
|
||||
expect(getPublicEdgeRateLimitPolicyId("/api/v2/client/env_123/responses/resp_123", "PUT")).toBe(
|
||||
publicEdgeRateLimitPolicies.v2ClientResponses
|
||||
);
|
||||
expect(getPublicEdgeRateLimitPolicyId("/api/v2/client/env_123/displays", "POST")).toBe(
|
||||
publicEdgeRateLimitPolicies.v2ClientDisplays
|
||||
);
|
||||
expect(getPublicEdgeRateLimitPolicyId("/api/v2/client/env_123/storage", "POST")).toBe(
|
||||
publicEdgeRateLimitPolicies.v2ClientStorageUpload
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("isPublicEdgeRateLimitManaged", () => {
|
||||
test("manages public policies on cloudflare and cloudarmor only", () => {
|
||||
expect(isPublicEdgeRateLimitManaged(publicEdgeRateLimitPolicies.authLogin, "cloudflare")).toBe(true);
|
||||
expect(isPublicEdgeRateLimitManaged(publicEdgeRateLimitPolicies.authLogin, "cloudarmor")).toBe(true);
|
||||
expect(isPublicEdgeRateLimitManaged(publicEdgeRateLimitPolicies.authLogin, "none")).toBe(false);
|
||||
expect(isPublicEdgeRateLimitManaged(publicEdgeRateLimitPolicies.authLogin, "envoy")).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("applyPublicIpRateLimit", () => {
|
||||
test("uses app rate limiting when no edge provider manages the policy", async () => {
|
||||
vi.mocked(applyIPRateLimit).mockResolvedValue({ allowed: true });
|
||||
|
||||
const source = await applyPublicIpRateLimit(
|
||||
publicEdgeRateLimitPolicies.v2ClientResponses,
|
||||
mockConfig,
|
||||
"none"
|
||||
);
|
||||
|
||||
expect(source).toBe("app");
|
||||
expect(applyIPRateLimit).toHaveBeenCalledWith(mockConfig);
|
||||
});
|
||||
|
||||
test("skips app rate limiting when the edge provider manages the policy", async () => {
|
||||
const source = await applyPublicIpRateLimit(
|
||||
publicEdgeRateLimitPolicies.v2ClientResponses,
|
||||
mockConfig,
|
||||
"cloudflare"
|
||||
);
|
||||
|
||||
expect(source).toBe("edge");
|
||||
expect(applyIPRateLimit).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("applyPublicIpRateLimitForRoute", () => {
|
||||
test("uses the route classifier for managed public routes", async () => {
|
||||
const source = await applyPublicIpRateLimitForRoute(
|
||||
"/api/v2/client/env_123/displays",
|
||||
"POST",
|
||||
mockConfig,
|
||||
"cloudarmor"
|
||||
);
|
||||
|
||||
expect(source).toBe("edge");
|
||||
expect(applyIPRateLimit).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("falls back to app rate limiting for unmanaged routes", async () => {
|
||||
vi.mocked(applyIPRateLimit).mockResolvedValue({ allowed: true });
|
||||
|
||||
const source = await applyPublicIpRateLimitForRoute(
|
||||
"/api/v1/client/env_123/environment",
|
||||
"GET",
|
||||
mockConfig,
|
||||
"envoy"
|
||||
);
|
||||
|
||||
expect(source).toBe("app");
|
||||
expect(applyIPRateLimit).toHaveBeenCalledWith(mockConfig);
|
||||
});
|
||||
});
|
||||
});
|
||||
135
apps/web/modules/core/rate-limit/public-edge-rate-limit.ts
Normal file
135
apps/web/modules/core/rate-limit/public-edge-rate-limit.ts
Normal file
@@ -0,0 +1,135 @@
|
||||
import { EDGE_RATE_LIMIT_PROVIDER } from "@/lib/constants";
|
||||
import { applyIPRateLimit } from "./helpers";
|
||||
import { TRateLimitConfig } from "./types/rate-limit";
|
||||
|
||||
export const publicEdgeRateLimitPolicies = {
|
||||
authLogin: "auth.login",
|
||||
authVerifyEmail: "auth.verify_email",
|
||||
v1ClientDefault: "client.v1.default",
|
||||
v1ClientStorageUpload: "client.storage.upload.v1",
|
||||
v2ClientResponses: "client.responses.v2",
|
||||
v2ClientDisplays: "client.displays.v2",
|
||||
v2ClientStorageUpload: "client.storage.upload.v2",
|
||||
} as const;
|
||||
|
||||
export type TPublicEdgeRateLimitPolicyId =
|
||||
(typeof publicEdgeRateLimitPolicies)[keyof typeof publicEdgeRateLimitPolicies];
|
||||
|
||||
export type TEdgeRateLimitProvider = "none" | "cloudflare" | "cloudarmor" | "envoy";
|
||||
|
||||
const managedPublicEdgePolicies = Object.values(
|
||||
publicEdgeRateLimitPolicies
|
||||
) as TPublicEdgeRateLimitPolicyId[];
|
||||
|
||||
const managedPublicEdgePoliciesByProvider: Record<
|
||||
TEdgeRateLimitProvider,
|
||||
readonly TPublicEdgeRateLimitPolicyId[]
|
||||
> = {
|
||||
none: [],
|
||||
cloudflare: managedPublicEdgePolicies,
|
||||
cloudarmor: managedPublicEdgePolicies,
|
||||
envoy: [],
|
||||
};
|
||||
|
||||
const normalizeEdgeRateLimitProvider = (provider: string | undefined): TEdgeRateLimitProvider => {
|
||||
switch (provider) {
|
||||
case "cloudflare":
|
||||
case "cloudarmor":
|
||||
case "envoy":
|
||||
return provider;
|
||||
default:
|
||||
return "none";
|
||||
}
|
||||
};
|
||||
|
||||
const normalizePathname = (pathname: string): string => {
|
||||
if (pathname.length > 1 && pathname.endsWith("/")) {
|
||||
return pathname.slice(0, -1);
|
||||
}
|
||||
|
||||
return pathname;
|
||||
};
|
||||
|
||||
export const getEdgeRateLimitProvider = (
|
||||
provider: string | undefined = EDGE_RATE_LIMIT_PROVIDER
|
||||
): TEdgeRateLimitProvider => normalizeEdgeRateLimitProvider(provider);
|
||||
|
||||
export const getPublicEdgeRateLimitPolicyId = (
|
||||
pathname: string,
|
||||
method: string
|
||||
): TPublicEdgeRateLimitPolicyId | null => {
|
||||
const normalizedPathname = normalizePathname(pathname);
|
||||
const normalizedMethod = method.toUpperCase();
|
||||
|
||||
if (normalizedMethod === "POST" && normalizedPathname === "/api/auth/callback/credentials") {
|
||||
return publicEdgeRateLimitPolicies.authLogin;
|
||||
}
|
||||
|
||||
if (normalizedMethod === "POST" && normalizedPathname === "/api/auth/callback/token") {
|
||||
return publicEdgeRateLimitPolicies.authVerifyEmail;
|
||||
}
|
||||
|
||||
if (/^\/api\/v1\/client\/og(?:\/.*)?$/.test(normalizedPathname)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (/^\/api\/v1\/client\/[^/]+\/storage$/.test(normalizedPathname) && normalizedMethod === "POST") {
|
||||
return publicEdgeRateLimitPolicies.v1ClientStorageUpload;
|
||||
}
|
||||
|
||||
if (/^\/api\/v2\/client\/[^/]+\/storage$/.test(normalizedPathname) && normalizedMethod === "POST") {
|
||||
return publicEdgeRateLimitPolicies.v2ClientStorageUpload;
|
||||
}
|
||||
|
||||
if (
|
||||
/^\/api\/v2\/client\/[^/]+\/responses(?:\/[^/]+)?$/.test(normalizedPathname) &&
|
||||
(normalizedMethod === "POST" || normalizedMethod === "PUT")
|
||||
) {
|
||||
return publicEdgeRateLimitPolicies.v2ClientResponses;
|
||||
}
|
||||
|
||||
if (/^\/api\/v2\/client\/[^/]+\/displays$/.test(normalizedPathname) && normalizedMethod === "POST") {
|
||||
return publicEdgeRateLimitPolicies.v2ClientDisplays;
|
||||
}
|
||||
|
||||
if (normalizedPathname.startsWith("/api/v1/client/")) {
|
||||
return publicEdgeRateLimitPolicies.v1ClientDefault;
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
export const isPublicEdgeRateLimitManaged = (
|
||||
policyId: TPublicEdgeRateLimitPolicyId,
|
||||
provider: string | undefined = EDGE_RATE_LIMIT_PROVIDER
|
||||
): boolean => managedPublicEdgePoliciesByProvider[getEdgeRateLimitProvider(provider)].includes(policyId);
|
||||
|
||||
export const applyPublicIpRateLimit = async (
|
||||
policyId: TPublicEdgeRateLimitPolicyId,
|
||||
config: TRateLimitConfig,
|
||||
provider: string | undefined = EDGE_RATE_LIMIT_PROVIDER
|
||||
): Promise<"app" | "edge"> => {
|
||||
if (isPublicEdgeRateLimitManaged(policyId, provider)) {
|
||||
return "edge";
|
||||
}
|
||||
|
||||
await applyIPRateLimit(config);
|
||||
|
||||
return "app";
|
||||
};
|
||||
|
||||
export const applyPublicIpRateLimitForRoute = async (
|
||||
pathname: string,
|
||||
method: string,
|
||||
config: TRateLimitConfig,
|
||||
provider: string | undefined = EDGE_RATE_LIMIT_PROVIDER
|
||||
): Promise<"app" | "edge"> => {
|
||||
const policyId = getPublicEdgeRateLimitPolicyId(pathname, method);
|
||||
|
||||
if (!policyId) {
|
||||
await applyIPRateLimit(config);
|
||||
return "app";
|
||||
}
|
||||
|
||||
return await applyPublicIpRateLimit(policyId, config, provider);
|
||||
};
|
||||
@@ -186,6 +186,7 @@ export const testInputValidation = async (service: Function, ...args: any[]): Pr
|
||||
|
||||
vi.mock("@/lib/constants", () => ({
|
||||
IS_FORMBRICKS_CLOUD: false,
|
||||
EDGE_RATE_LIMIT_PROVIDER: "none",
|
||||
ENCRYPTION_KEY: "mock-encryption-key",
|
||||
ENTERPRISE_LICENSE_KEY: "mock-enterprise-license-key",
|
||||
GITHUB_ID: "mock-github-id",
|
||||
|
||||
@@ -146,6 +146,7 @@
|
||||
"E2E_TESTING",
|
||||
"EMAIL_AUTH_DISABLED",
|
||||
"EMAIL_VERIFICATION_DISABLED",
|
||||
"EDGE_RATE_LIMIT_PROVIDER",
|
||||
"ENCRYPTION_KEY",
|
||||
"ENTERPRISE_LICENSE_KEY",
|
||||
"ENVIRONMENT",
|
||||
|
||||
Reference in New Issue
Block a user