Compare commits

...

5 Commits

Author SHA1 Message Date
Tiago Farto 1fb59f4b60 chore: improved test coverage 2026-05-18 12:09:01 +00:00
Tiago Farto ebf8fc017c chore: improve test coverage 2026-05-18 11:57:56 +00:00
Tiago Farto 5c4f5eb0d6 chore: increased test coverage 2026-05-18 11:41:30 +00:00
Tiago Farto fe4b7d9962 chore: linting fixes 2026-05-18 11:20:53 +00:00
Tiago Farto a9939c65c4 fix: add CSRF protection to integration OAuth flows 2026-05-18 10:28:38 +00:00
12 changed files with 717 additions and 83 deletions
+92 -38
View File
@@ -10,52 +10,125 @@ import {
WEBAPP_URL,
} from "@/lib/constants";
import { createOrUpdateIntegration, getIntegrationByType } from "@/lib/integration/service";
import {
IntegrationOAuthStateError,
consumeIntegrationOAuthState,
getSafeOAuthCallbackError,
} from "@/lib/oauth/integration-state";
import { capturePostHogEvent } from "@/lib/posthog";
import { getOrganizationIdFromWorkspaceId } from "@/lib/utils/helper";
import { hasUserWorkspaceAccess } from "@/lib/workspace/auth";
import { authOptions } from "@/modules/auth/lib/authOptions";
const getGoogleSheetsRedirectUrl = (workspaceId: string) =>
new URL(`/workspaces/${workspaceId}/integrations/google-sheets`, WEBAPP_URL);
const getGoogleSheetsOAuthState = async (state: string | null, userId: string) => {
try {
return await consumeIntegrationOAuthState({
provider: "googleSheets",
userId,
state,
});
} catch (err) {
if (err instanceof IntegrationOAuthStateError) {
return null;
}
throw err;
}
};
const getGoogleSheetsOAuthClient = () => {
const client_id = GOOGLE_SHEETS_CLIENT_ID;
const client_secret = GOOGLE_SHEETS_CLIENT_SECRET;
const redirect_uri = GOOGLE_SHEETS_REDIRECT_URL;
if (!client_id) {
return { response: responses.internalServerErrorResponse("Google client id is missing") };
}
if (!client_secret) {
return { response: responses.internalServerErrorResponse("Google client secret is missing") };
}
if (!redirect_uri) {
return { response: responses.internalServerErrorResponse("Google redirect url is missing") };
}
return { client: new google.auth.OAuth2(client_id, client_secret, redirect_uri) };
};
const captureGoogleSheetsConnectedEvent = async (userId: string, workspaceId: string) => {
try {
const organizationId = await getOrganizationIdFromWorkspaceId(workspaceId);
capturePostHogEvent(userId, "integration_connected", {
integration_type: "googleSheets",
organization_id: organizationId,
});
capturePostHogEvent(
userId,
"integration_connected",
{
integration_type: "googleSheets",
organization_id: organizationId,
workspace_id: workspaceId,
},
{ organizationId, workspaceId }
);
} catch (err) {
logger.error({ error: err }, "Failed to capture PostHog integration_connected event for googleSheets");
}
};
export const GET = async (req: Request) => {
const url = new URL(req.url);
const workspaceId = url.searchParams.get("state");
const state = url.searchParams.get("state");
const code = url.searchParams.get("code");
if (!workspaceId) {
return responses.badRequestResponse("Invalid workspaceId");
}
const error = url.searchParams.get("error");
const session = await getServerSession(authOptions);
if (!session) {
return responses.notAuthenticatedResponse();
}
const oauthState = await getGoogleSheetsOAuthState(state, session.user.id);
if (!oauthState) {
return responses.badRequestResponse("Invalid OAuth state");
}
const workspaceId = oauthState.workspaceId;
const canUserAccessWorkspace = await hasUserWorkspaceAccess(session.user.id, workspaceId);
if (!canUserAccessWorkspace) {
return responses.unauthorizedResponse();
}
const basePath = `/workspaces/${workspaceId}`;
const redirectUrl = getGoogleSheetsRedirectUrl(workspaceId);
const safeError = getSafeOAuthCallbackError(error);
if (safeError) {
redirectUrl.searchParams.set("error", safeError);
return Response.redirect(redirectUrl);
}
if (code && typeof code !== "string") {
return responses.badRequestResponse("`code` must be a string");
}
const client_id = GOOGLE_SHEETS_CLIENT_ID;
const client_secret = GOOGLE_SHEETS_CLIENT_SECRET;
const redirect_uri = GOOGLE_SHEETS_REDIRECT_URL;
if (!client_id) return responses.internalServerErrorResponse("Google client id is missing");
if (!client_secret) return responses.internalServerErrorResponse("Google client secret is missing");
if (!redirect_uri) return responses.internalServerErrorResponse("Google redirect url is missing");
const oAuth2Client = new google.auth.OAuth2(client_id, client_secret, redirect_uri);
const oAuth2ClientResult = getGoogleSheetsOAuthClient();
if ("response" in oAuth2ClientResult) {
return oAuth2ClientResult.response;
}
const oAuth2Client = oAuth2ClientResult.client;
if (!code) {
return Response.redirect(`${WEBAPP_URL}${basePath}/integrations/google-sheets`);
return Response.redirect(redirectUrl);
}
const token = await oAuth2Client.getToken(code);
const key = token.res?.data;
if (!key) {
return Response.redirect(`${WEBAPP_URL}${basePath}/integrations/google-sheets`);
return Response.redirect(redirectUrl);
}
oAuth2Client.setCredentials({ access_token: key.access_token });
@@ -81,29 +154,10 @@ export const GET = async (req: Request) => {
};
const result = await createOrUpdateIntegration(workspaceId, googleSheetIntegration);
if (result) {
try {
const organizationId = await getOrganizationIdFromWorkspaceId(workspaceId);
capturePostHogEvent(session.user.id, "integration_connected", {
integration_type: "googleSheets",
organization_id: organizationId,
});
capturePostHogEvent(
session.user.id,
"integration_connected",
{
integration_type: "googleSheets",
organization_id: organizationId,
workspace_id: workspaceId,
},
{ organizationId, workspaceId }
);
} catch (err) {
logger.error({ error: err }, "Failed to capture PostHog integration_connected event for googleSheets");
}
return Response.redirect(`${WEBAPP_URL}/${basePath}/integrations/google-sheets`);
if (!result) {
return responses.internalServerErrorResponse("Failed to create or update Google Sheets integration");
}
return responses.internalServerErrorResponse("Failed to create or update Google Sheets integration");
await captureGoogleSheetsConnectedEvent(session.user.id, workspaceId);
return Response.redirect(redirectUrl);
};
+7 -1
View File
@@ -7,6 +7,7 @@ import {
GOOGLE_SHEETS_CLIENT_SECRET,
GOOGLE_SHEETS_REDIRECT_URL,
} from "@/lib/constants";
import { createIntegrationOAuthState } from "@/lib/oauth/integration-state";
import { hasUserWorkspaceAccess } from "@/lib/workspace/auth";
import { authOptions } from "@/modules/auth/lib/authOptions";
@@ -39,12 +40,17 @@ export const GET = async (req: NextRequest) => {
if (!client_secret) return responses.internalServerErrorResponse("Google client secret is missing");
if (!redirect_uri) return responses.internalServerErrorResponse("Google redirect url is missing");
const oAuth2Client = new google.auth.OAuth2(client_id, client_secret, redirect_uri);
const state = await createIntegrationOAuthState({
provider: "googleSheets",
userId: session.user.id,
workspaceId,
});
const authUrl = oAuth2Client.generateAuthUrl({
access_type: "offline",
scope: scopes,
prompt: "consent",
state: workspaceId,
state,
});
return responses.successResponse({ authUrl });
@@ -5,6 +5,11 @@ import { withV1ApiWrapper } from "@/app/lib/api/with-api-logging";
import { fetchAirtableAuthToken } from "@/lib/airtable/service";
import { AIRTABLE_CLIENT_ID, WEBAPP_URL } from "@/lib/constants";
import { createOrUpdateIntegration, getIntegrationByType } from "@/lib/integration/service";
import {
IntegrationOAuthStateError,
consumeIntegrationOAuthState,
getSafeOAuthCallbackError,
} from "@/lib/oauth/integration-state";
import { capturePostHogEvent } from "@/lib/posthog";
import { getOrganizationIdFromWorkspaceId } from "@/lib/utils/helper";
import { hasUserWorkspaceAccess } from "@/lib/workspace/auth";
@@ -29,18 +34,31 @@ export const GET = withV1ApiWrapper({
const url = req.url;
const queryParams = new URLSearchParams(url.split("?")[1]); // Split the URL and get the query parameters
const workspaceId = queryParams.get("state"); // Get the value of the 'state' parameter
const state = queryParams.get("state");
const code = queryParams.get("code");
const error = queryParams.get("error");
if (!workspaceId) {
return {
response: responses.badRequestResponse("Invalid workspaceId"),
};
let oauthState;
try {
oauthState = await consumeIntegrationOAuthState({
provider: "airtable",
userId: authentication.user.id,
state,
});
} catch (err) {
if (err instanceof IntegrationOAuthStateError) {
return {
response: responses.badRequestResponse("Invalid OAuth state"),
};
}
throw err;
}
if (!code) {
const workspaceId = oauthState.workspaceId;
if (!workspaceId || !oauthState.pkceCodeVerifier) {
return {
response: responses.badRequestResponse("`code` is missing"),
response: responses.badRequestResponse("Invalid OAuth state"),
};
}
@@ -52,10 +70,25 @@ export const GET = withV1ApiWrapper({
}
const basePath = `/workspaces/${workspaceId}`;
const redirectUrl = new URL(`${basePath}/integrations/airtable`, WEBAPP_URL);
const safeError = getSafeOAuthCallbackError(error);
if (!code && safeError) {
redirectUrl.searchParams.set("error", safeError);
return {
response: Response.redirect(redirectUrl),
};
}
if (!code) {
return {
response: responses.badRequestResponse("`code` is missing"),
};
}
const client_id = AIRTABLE_CLIENT_ID;
const redirect_uri = WEBAPP_URL + "/api/v1/integrations/airtable/callback";
const code_verifier = Buffer.from(workspaceId + authentication.user.id + workspaceId).toString("base64");
const code_verifier = oauthState.pkceCodeVerifier;
if (!client_id)
return {
@@ -110,10 +143,10 @@ export const GET = withV1ApiWrapper({
}
return {
response: Response.redirect(`${WEBAPP_URL}${basePath}/integrations/airtable`),
response: Response.redirect(redirectUrl),
};
} catch (error) {
logger.error({ error, url: req.url }, "Error in GET /api/v1/integrations/airtable/callback");
logger.error({ error }, "Error in GET /api/v1/integrations/airtable/callback");
return {
response: responses.internalServerErrorResponse(
error instanceof Error ? error.message : String(error)
@@ -1,7 +1,7 @@
import crypto from "crypto";
import { responses } from "@/app/lib/api/response";
import { withV1ApiWrapper } from "@/app/lib/api/with-api-logging";
import { AIRTABLE_CLIENT_ID, WEBAPP_URL } from "@/lib/constants";
import { createIntegrationOAuthState, generatePkcePair } from "@/lib/oauth/integration-state";
import { hasUserWorkspaceAccess } from "@/lib/workspace/auth";
const scope = `data.records:read data.records:write schema.bases:read schema.bases:write user.email:read`;
@@ -33,22 +33,19 @@ export const GET = withV1ApiWrapper({
return {
response: responses.internalServerErrorResponse("Airtable client id is missing"),
};
const codeVerifier = Buffer.from(workspaceId + authentication.user.id + workspaceId).toString("base64");
const codeChallengeMethod = "S256";
const codeChallenge = crypto
.createHash("sha256")
.update(codeVerifier) // hash the code verifier with the sha256 algorithm
.digest("base64") // base64 encode, needs to be transformed to base64url
.replace(/=/g, "") // remove =
.replace(/\+/g, "-") // replace + with -
.replace(/\//g, "_"); // replace / with _ now base64url encoded
const { codeChallenge, codeChallengeMethod, codeVerifier } = generatePkcePair();
const state = await createIntegrationOAuthState({
provider: "airtable",
userId: authentication.user.id,
workspaceId,
pkceCodeVerifier: codeVerifier,
});
const authUrl = new URL("https://airtable.com/oauth2/v1/authorize");
authUrl.searchParams.append("client_id", client_id);
authUrl.searchParams.append("redirect_uri", redirect_uri);
authUrl.searchParams.append("state", workspaceId);
authUrl.searchParams.append("state", state);
authUrl.searchParams.append("scope", scope);
authUrl.searchParams.append("response_type", "code");
authUrl.searchParams.append("code_challenge_method", codeChallengeMethod);
@@ -11,6 +11,11 @@ import {
} from "@/lib/constants";
import { symmetricEncrypt } from "@/lib/crypto";
import { createOrUpdateIntegration, getIntegrationByType } from "@/lib/integration/service";
import {
IntegrationOAuthStateError,
consumeIntegrationOAuthState,
getSafeOAuthCallbackError,
} from "@/lib/oauth/integration-state";
import { capturePostHogEvent } from "@/lib/posthog";
import { getOrganizationIdFromWorkspaceId } from "@/lib/utils/helper";
import { hasUserWorkspaceAccess } from "@/lib/workspace/auth";
@@ -23,10 +28,28 @@ export const GET = withV1ApiWrapper({
const url = req.url;
const queryParams = new URLSearchParams(url.split("?")[1]); // Split the URL and get the query parameters
const workspaceId = queryParams.get("state"); // Get the value of the 'state' parameter
const state = queryParams.get("state");
const code = queryParams.get("code");
const error = queryParams.get("error");
let oauthState;
try {
oauthState = await consumeIntegrationOAuthState({
provider: "notion",
userId: authentication.user.id,
state,
});
} catch (err) {
if (err instanceof IntegrationOAuthStateError) {
return {
response: responses.badRequestResponse("Invalid OAuth state"),
};
}
throw err;
}
const workspaceId = oauthState.workspaceId;
if (!workspaceId) {
return {
response: responses.badRequestResponse("Invalid workspaceId"),
@@ -41,6 +64,8 @@ export const GET = withV1ApiWrapper({
}
const basePath = `/workspaces/${workspaceId}`;
const redirectUrl = new URL(`${basePath}/integrations/notion`, WEBAPP_URL);
const safeError = getSafeOAuthCallbackError(error);
if (code && typeof code !== "string") {
return {
@@ -48,6 +73,13 @@ export const GET = withV1ApiWrapper({
};
}
if (!code && safeError) {
redirectUrl.searchParams.set("error", safeError);
return {
response: Response.redirect(redirectUrl),
};
}
const client_id = NOTION_OAUTH_CLIENT_ID;
const client_secret = NOTION_OAUTH_CLIENT_SECRET;
const redirect_uri = NOTION_REDIRECT_URI;
@@ -118,13 +150,9 @@ export const GET = withV1ApiWrapper({
}
return {
response: Response.redirect(`${WEBAPP_URL}${basePath}/integrations/notion`),
response: Response.redirect(redirectUrl),
};
}
} else if (error) {
return {
response: Response.redirect(`${WEBAPP_URL}${basePath}/integrations/notion?error=${error}`),
};
}
return {
@@ -6,6 +6,7 @@ import {
NOTION_OAUTH_CLIENT_SECRET,
NOTION_REDIRECT_URI,
} from "@/lib/constants";
import { createIntegrationOAuthState } from "@/lib/oauth/integration-state";
import { hasUserWorkspaceAccess } from "@/lib/workspace/auth";
export const GET = withV1ApiWrapper({
@@ -49,9 +50,16 @@ export const GET = withV1ApiWrapper({
return {
response: responses.internalServerErrorResponse("Notion auth url is missing"),
};
const state = await createIntegrationOAuthState({
provider: "notion",
userId: authentication.user.id,
workspaceId,
});
const authUrlWithState = new URL(auth_url);
authUrlWithState.searchParams.set("state", state);
return {
response: responses.successResponse({ authUrl: `${auth_url}&state=${workspaceId}` }),
response: responses.successResponse({ authUrl: authUrlWithState.toString() }),
};
},
});
@@ -8,6 +8,11 @@ import { responses } from "@/app/lib/api/response";
import { withV1ApiWrapper } from "@/app/lib/api/with-api-logging";
import { SLACK_CLIENT_ID, SLACK_CLIENT_SECRET, SLACK_REDIRECT_URI, WEBAPP_URL } from "@/lib/constants";
import { createOrUpdateIntegration, getIntegrationByType } from "@/lib/integration/service";
import {
IntegrationOAuthStateError,
consumeIntegrationOAuthState,
getSafeOAuthCallbackError,
} from "@/lib/oauth/integration-state";
import { capturePostHogEvent } from "@/lib/posthog";
import { getOrganizationIdFromWorkspaceId } from "@/lib/utils/helper";
import { hasUserWorkspaceAccess } from "@/lib/workspace/auth";
@@ -20,10 +25,28 @@ export const GET = withV1ApiWrapper({
const url = req.url;
const queryParams = new URLSearchParams(url.split("?")[1]); // Split the URL and get the query parameters
const workspaceId = queryParams.get("state"); // Get the value of the 'state' parameter
const state = queryParams.get("state");
const code = queryParams.get("code");
const error = queryParams.get("error");
let oauthState;
try {
oauthState = await consumeIntegrationOAuthState({
provider: "slack",
userId: authentication.user.id,
state,
});
} catch (err) {
if (err instanceof IntegrationOAuthStateError) {
return {
response: responses.badRequestResponse("Invalid OAuth state"),
};
}
throw err;
}
const workspaceId = oauthState.workspaceId;
if (!workspaceId) {
return {
response: responses.badRequestResponse("Invalid workspaceId"),
@@ -38,6 +61,8 @@ export const GET = withV1ApiWrapper({
}
const basePath = `/workspaces/${workspaceId}`;
const redirectUrl = new URL(`${basePath}/integrations/slack`, WEBAPP_URL);
const safeError = getSafeOAuthCallbackError(error);
if (code && typeof code !== "string") {
return {
@@ -45,6 +70,13 @@ export const GET = withV1ApiWrapper({
};
}
if (!code && safeError) {
redirectUrl.searchParams.set("error", safeError);
return {
response: Response.redirect(redirectUrl),
};
}
if (!SLACK_CLIENT_ID)
return {
response: responses.internalServerErrorResponse("Slack client id is missing"),
@@ -125,13 +157,9 @@ export const GET = withV1ApiWrapper({
}
return {
response: Response.redirect(`${WEBAPP_URL}${basePath}/integrations/slack`),
response: Response.redirect(redirectUrl),
};
}
} else if (error) {
return {
response: Response.redirect(`${WEBAPP_URL}${basePath}/integrations/slack?error=${error}`),
};
}
return {
@@ -1,6 +1,7 @@
import { responses } from "@/app/lib/api/response";
import { withV1ApiWrapper } from "@/app/lib/api/with-api-logging";
import { SLACK_AUTH_URL, SLACK_CLIENT_ID, SLACK_CLIENT_SECRET } from "@/lib/constants";
import { createIntegrationOAuthState } from "@/lib/oauth/integration-state";
import { hasUserWorkspaceAccess } from "@/lib/workspace/auth";
export const GET = withV1ApiWrapper({
@@ -37,9 +38,16 @@ export const GET = withV1ApiWrapper({
return {
response: responses.internalServerErrorResponse("Slack auth url is missing"),
};
const state = await createIntegrationOAuthState({
provider: "slack",
userId: authentication.user.id,
workspaceId,
});
const authUrl = new URL(SLACK_AUTH_URL);
authUrl.searchParams.set("state", state);
return {
response: responses.successResponse({ authUrl: `${SLACK_AUTH_URL}&state=${workspaceId}` }),
response: responses.successResponse({ authUrl: authUrl.toString() }),
};
},
});
@@ -0,0 +1,254 @@
import { beforeEach, describe, expect, test, vi } from "vitest";
import { ErrorCode } from "@formbricks/cache";
import { logger } from "@formbricks/logger";
import { cache } from "@/lib/cache";
import {
IntegrationOAuthStateError,
consumeIntegrationOAuthState,
createIntegrationOAuthState,
generatePkcePair,
getSafeOAuthCallbackError,
} from "./integration-state";
vi.mock("@formbricks/logger", () => ({
logger: {
error: vi.fn(),
warn: vi.fn(),
},
}));
vi.mock("@/lib/cache", () => ({
cache: {
getRedisClient: vi.fn(),
set: vi.fn(),
},
}));
const mockCache = vi.mocked(cache);
const oauthStatePayload = {
createdAt: Date.now(),
provider: "slack",
userId: "user-1",
workspaceId: "workspace-1",
} as const;
const mockRedisConsume = (value: unknown) => {
const evalMock = vi.fn().mockResolvedValue(value === null ? null : JSON.stringify(value));
mockCache.getRedisClient.mockResolvedValueOnce({ eval: evalMock } as any);
return evalMock;
};
describe("integration OAuth state", () => {
beforeEach(() => {
vi.resetAllMocks();
mockCache.set.mockResolvedValue({ ok: true, data: undefined });
});
test("creates an opaque cached state that does not expose the workspace id", async () => {
const state = await createIntegrationOAuthState({
provider: "slack",
userId: oauthStatePayload.userId,
workspaceId: oauthStatePayload.workspaceId,
});
expect(state).toMatch(/^[A-Za-z0-9_-]{43,128}$/);
expect(state).not.toContain(oauthStatePayload.workspaceId);
expect(mockCache.set).toHaveBeenCalledWith(
"fb:oauth:state:fake-hash",
expect.objectContaining({
provider: oauthStatePayload.provider,
userId: oauthStatePayload.userId,
workspaceId: oauthStatePayload.workspaceId,
}),
10 * 60 * 1000
);
});
test("stores the PKCE verifier with Airtable OAuth state", async () => {
const pkceCodeVerifier = "E".repeat(43);
await createIntegrationOAuthState({
pkceCodeVerifier,
provider: "airtable",
userId: oauthStatePayload.userId,
workspaceId: oauthStatePayload.workspaceId,
});
expect(mockCache.set).toHaveBeenCalledWith(
"fb:oauth:state:fake-hash",
expect.objectContaining({ pkceCodeVerifier }),
10 * 60 * 1000
);
});
test("consumes a valid state atomically and returns the stored workspace", async () => {
const state = await createIntegrationOAuthState({
provider: "slack",
userId: oauthStatePayload.userId,
workspaceId: oauthStatePayload.workspaceId,
});
const redisEval = mockRedisConsume(oauthStatePayload);
const consumedState = await consumeIntegrationOAuthState({
provider: "slack",
userId: oauthStatePayload.userId,
state,
});
expect(consumedState).toEqual(oauthStatePayload);
expect(redisEval).toHaveBeenCalledWith(expect.stringContaining('redis.call("GET", KEYS[1])'), {
arguments: [],
keys: ["fb:oauth:state:fake-hash"],
});
});
test("rejects reused or unknown states", async () => {
mockRedisConsume(null);
await expect(
consumeIntegrationOAuthState({
provider: "slack",
userId: oauthStatePayload.userId,
state: "A".repeat(43),
})
).rejects.toThrow(IntegrationOAuthStateError);
});
test("rejects malformed callback state before reading Redis", async () => {
await expect(
consumeIntegrationOAuthState({
provider: "slack",
userId: oauthStatePayload.userId,
state: "too-short",
})
).rejects.toThrow(IntegrationOAuthStateError);
expect(mockCache.getRedisClient).not.toHaveBeenCalled();
expect(logger.warn).toHaveBeenCalled();
});
test("rejects wrong provider and wrong user states", async () => {
mockRedisConsume(oauthStatePayload);
await expect(
consumeIntegrationOAuthState({
provider: "notion",
userId: oauthStatePayload.userId,
state: "B".repeat(43),
})
).rejects.toThrow(IntegrationOAuthStateError);
mockRedisConsume(oauthStatePayload);
await expect(
consumeIntegrationOAuthState({
provider: "slack",
userId: "user-2",
state: "C".repeat(43),
})
).rejects.toThrow(IntegrationOAuthStateError);
});
test("fails closed when cache storage or Redis is unavailable", async () => {
mockCache.set.mockResolvedValueOnce({ ok: false, error: { code: ErrorCode.RedisConnectionError } });
await expect(
createIntegrationOAuthState({
provider: "slack",
userId: oauthStatePayload.userId,
workspaceId: oauthStatePayload.workspaceId,
})
).rejects.toThrow("Unable to start OAuth flow");
mockCache.getRedisClient.mockResolvedValueOnce(null);
await expect(
consumeIntegrationOAuthState({
provider: "slack",
userId: oauthStatePayload.userId,
state: "D".repeat(43),
})
).rejects.toThrow(IntegrationOAuthStateError);
expect(logger.error).toHaveBeenCalled();
});
test("fails closed when Redis client resolution throws", async () => {
mockCache.getRedisClient.mockRejectedValueOnce(new Error("Redis unavailable"));
await expect(
consumeIntegrationOAuthState({
provider: "slack",
userId: oauthStatePayload.userId,
state: "I".repeat(43),
})
).rejects.toThrow(IntegrationOAuthStateError);
expect(logger.error).toHaveBeenCalled();
});
test("rejects malformed cached state values", async () => {
mockRedisConsume({
createdAt: Date.now(),
provider: "slack",
userId: oauthStatePayload.userId,
});
await expect(
consumeIntegrationOAuthState({
provider: "slack",
userId: oauthStatePayload.userId,
state: "F".repeat(43),
})
).rejects.toThrow(IntegrationOAuthStateError);
expect(logger.error).toHaveBeenCalled();
});
test("rejects unexpected cached value types", async () => {
mockCache.getRedisClient.mockResolvedValueOnce({
eval: vi.fn().mockResolvedValue(42),
} as any);
await expect(
consumeIntegrationOAuthState({
provider: "slack",
userId: oauthStatePayload.userId,
state: "G".repeat(43),
})
).rejects.toThrow(IntegrationOAuthStateError);
expect(logger.error).toHaveBeenCalled();
});
test("fails closed when atomic cache consumption fails", async () => {
mockCache.getRedisClient.mockResolvedValueOnce({
eval: vi.fn().mockRejectedValue(new Error("Redis failed")),
} as any);
await expect(
consumeIntegrationOAuthState({
provider: "slack",
userId: oauthStatePayload.userId,
state: "H".repeat(43),
})
).rejects.toThrow(IntegrationOAuthStateError);
expect(logger.error).toHaveBeenCalled();
});
test("generates an RFC 7636 S256 PKCE pair", () => {
const { codeChallenge, codeChallengeMethod, codeVerifier } = generatePkcePair();
expect(codeVerifier).toMatch(/^[A-Za-z0-9_-]{43,128}$/);
expect(codeChallenge).toBe("fake-hash");
expect(codeChallengeMethod).toBe("S256");
});
test("sanitizes provider callback errors", () => {
expect(getSafeOAuthCallbackError("access_denied")).toBe("access_denied");
expect(getSafeOAuthCallbackError("https://evil.example")).toBe("oauth_error");
expect(getSafeOAuthCallbackError(null)).toBeNull();
});
});
+215
View File
@@ -0,0 +1,215 @@
import "server-only";
import crypto from "node:crypto";
import { createCacheKey } from "@formbricks/cache";
import { logger } from "@formbricks/logger";
import { cache } from "@/lib/cache";
const INTEGRATION_OAUTH_STATE_TTL_MS = 10 * 60 * 1000;
const OAUTH_STATE_ENTROPY_BYTES = 32;
const BASE64URL_TOKEN_REGEX = /^[A-Za-z0-9_-]{43,128}$/;
const SAFE_OAUTH_CALLBACK_ERRORS = new Set([
"access_denied",
"invalid_request",
"invalid_scope",
"server_error",
"temporarily_unavailable",
]);
export type TIntegrationOAuthProvider = "googleSheets" | "slack" | "notion" | "airtable";
type TStoredIntegrationOAuthState = {
provider: TIntegrationOAuthProvider;
userId: string;
workspaceId: string;
pkceCodeVerifier?: string;
createdAt: number;
};
type TCreateIntegrationOAuthStateInput = {
provider: TIntegrationOAuthProvider;
userId: string;
workspaceId: string;
pkceCodeVerifier?: string;
};
type TConsumeIntegrationOAuthStateInput = {
provider: TIntegrationOAuthProvider;
userId: string;
state: string | null;
};
export class IntegrationOAuthStateError extends Error {
constructor(message = "Invalid OAuth state") {
super(message);
this.name = "IntegrationOAuthStateError";
}
}
const toBase64Url = (buffer: Buffer) =>
buffer.toString("base64").replaceAll("=", "").replaceAll("+", "-").replaceAll("/", "_");
const generateRandomToken = () => toBase64Url(crypto.randomBytes(OAUTH_STATE_ENTROPY_BYTES));
const hashState = (state: string) => crypto.createHash("sha256").update(state).digest("hex");
const getIntegrationOAuthStateCacheKey = (stateHash: string) =>
createCacheKey.custom("oauth", "state", stateHash);
const getValidToken = (token: string | undefined, label: string) => {
if (!token || !BASE64URL_TOKEN_REGEX.test(token)) {
throw new IntegrationOAuthStateError(`Invalid OAuth ${label}`);
}
return token;
};
const parseStoredIntegrationOAuthState = (serializedValue: string): TStoredIntegrationOAuthState => {
try {
const parsedValue = JSON.parse(serializedValue) as Partial<TStoredIntegrationOAuthState>;
if (
!parsedValue ||
typeof parsedValue.provider !== "string" ||
typeof parsedValue.userId !== "string" ||
typeof parsedValue.workspaceId !== "string" ||
typeof parsedValue.createdAt !== "number" ||
(parsedValue.pkceCodeVerifier !== undefined && typeof parsedValue.pkceCodeVerifier !== "string")
) {
throw new Error("Invalid stored OAuth state shape");
}
return parsedValue as TStoredIntegrationOAuthState;
} catch (error) {
logger.error({ error }, "Failed to parse stored integration OAuth state");
throw new IntegrationOAuthStateError();
}
};
const consumeCachedIntegrationOAuthState = async (
cacheKey: string,
logContext: Record<string, unknown>
): Promise<TStoredIntegrationOAuthState | null> => {
let redis;
try {
redis = await cache.getRedisClient();
} catch (error) {
logger.error({ ...logContext, error }, "Failed to resolve Redis client for integration OAuth state");
throw new IntegrationOAuthStateError("Unable to validate OAuth state");
}
if (!redis) {
logger.error({ ...logContext }, "Redis is required to validate integration OAuth state");
throw new IntegrationOAuthStateError("Unable to validate OAuth state");
}
try {
const serializedValue = await redis.eval(
`
local value = redis.call("GET", KEYS[1])
if value then
redis.call("DEL", KEYS[1])
end
return value
`,
{
arguments: [],
keys: [cacheKey],
}
);
if (serializedValue === null) {
return null;
}
if (typeof serializedValue !== "string") {
logger.error({ ...logContext }, "Unexpected cached integration OAuth state value");
throw new IntegrationOAuthStateError();
}
return parseStoredIntegrationOAuthState(serializedValue);
} catch (error) {
if (error instanceof IntegrationOAuthStateError) {
throw error;
}
logger.error({ ...logContext, error }, "Failed to consume integration OAuth state");
throw new IntegrationOAuthStateError("Unable to validate OAuth state");
}
};
export const createIntegrationOAuthState = async ({
provider,
userId,
workspaceId,
pkceCodeVerifier,
}: TCreateIntegrationOAuthStateInput): Promise<string> => {
if (pkceCodeVerifier !== undefined) {
getValidToken(pkceCodeVerifier, "PKCE verifier");
}
const state = generateRandomToken();
const stateHash = hashState(state);
const cacheKey = getIntegrationOAuthStateCacheKey(stateHash);
const storedState: TStoredIntegrationOAuthState = {
provider,
userId,
workspaceId,
pkceCodeVerifier,
createdAt: Date.now(),
};
const result = await cache.set(cacheKey, storedState, INTEGRATION_OAUTH_STATE_TTL_MS);
if (!result.ok) {
logger.error({ error: result.error, provider, userId, workspaceId }, "Failed to store OAuth state");
throw new Error("Unable to start OAuth flow");
}
return state;
};
export const consumeIntegrationOAuthState = async ({
provider,
userId,
state,
}: TConsumeIntegrationOAuthStateInput): Promise<TStoredIntegrationOAuthState> => {
let providedState;
try {
providedState = getValidToken(state ?? undefined, "state");
} catch (error) {
logger.warn({ provider, userId }, "Integration OAuth callback rejected due to malformed state");
throw error;
}
const stateHash = hashState(providedState);
const cacheKey = getIntegrationOAuthStateCacheKey(stateHash);
const storedState = await consumeCachedIntegrationOAuthState(cacheKey, { provider, stateHash, userId });
if (storedState?.provider !== provider || storedState?.userId !== userId) {
logger.warn({ provider, stateHash, userId }, "Integration OAuth callback rejected due to invalid state");
throw new IntegrationOAuthStateError();
}
return storedState;
};
export const getSafeOAuthCallbackError = (error: string | null): string | null => {
if (!error) {
return null;
}
return SAFE_OAUTH_CALLBACK_ERRORS.has(error) ? error : "oauth_error";
};
export const generatePkcePair = () => {
const verifier = generateRandomToken();
const challenge = toBase64Url(crypto.createHash("sha256").update(verifier).digest());
return {
codeChallenge: challenge,
codeChallengeMethod: "S256" as const,
codeVerifier: verifier,
};
};
+10 -7
View File
@@ -91,14 +91,17 @@ describe("@formbricks/cache types/keys", () => {
});
describe("CustomCacheNamespace type", () => {
test("should include expected namespaces", () => {
test("should support known custom namespaces in parsed cache keys", () => {
// Type test - this will fail at compile time if types don't match
const accountDeletionNamespace: CustomCacheNamespace = "account_deletion";
const analyticsNamespace: CustomCacheNamespace = "analytics";
const billingNamespace: CustomCacheNamespace = "billing";
expect(accountDeletionNamespace).toBe("account_deletion");
expect(analyticsNamespace).toBe("analytics");
expect(billingNamespace).toBe("billing");
const namespaces: CustomCacheNamespace[] = ["account_deletion", "analytics", "billing", "oauth"];
const cacheKeys = namespaces.map((namespace) => ZCacheKey.parse(`${namespace}:test:123`));
expect(cacheKeys).toEqual([
"account_deletion:test:123",
"analytics:test:123",
"billing:test:123",
"oauth:test:123",
]);
});
test("should be usable in cache key construction", () => {
+1 -1
View File
@@ -16,4 +16,4 @@ export type CacheKey = z.infer<typeof ZCacheKey>;
* Possible namespaces for custom cache keys
* Add new namespaces here as they are introduced
*/
export type CustomCacheNamespace = "account_deletion" | "analytics" | "billing";
export type CustomCacheNamespace = "account_deletion" | "analytics" | "billing" | "oauth";