feat: Adds SAML SSO auth using boxyHQ jackson for self-hosters (#4799)

Co-authored-by: Dhruwang <dhruwangjariwala18@gmail.com>
Co-authored-by: Matti Nannt <mail@matthiasnannt.com>
This commit is contained in:
Piyush Gupta
2025-02-28 17:48:59 +05:30
committed by GitHub
parent 1eb8049d04
commit 803a73afb6
89 changed files with 2740 additions and 105 deletions
+3
View File
@@ -48,3 +48,6 @@ uploads/
# Sentry Config File
.sentryclirc
# SAML Preloaded Connections
saml-connection/
+5
View File
@@ -107,6 +107,11 @@ ENV HOSTNAME "0.0.0.0"
RUN mkdir -p /home/nextjs/apps/web/uploads/
VOLUME /home/nextjs/apps/web/uploads/
# Prepare volume for SAML preloaded connection
RUN mkdir -p /home/nextjs/apps/web/saml-connection
VOLUME /home/nextjs/apps/web/saml-connection
CMD supercronic -quiet /app/docker/cronjobs & \
(cd packages/database && npm run db:migrate:deploy) && \
(cd packages/database && npm run db:create-saml-database:deploy) && \
exec node apps/web/server.js
@@ -0,0 +1,3 @@
import { GET } from "@/modules/ee/auth/saml/api/authorize/route";
export { GET };
@@ -0,0 +1,3 @@
import { POST } from "@/modules/ee/auth/saml/api/callback/route";
export { POST };
@@ -0,0 +1,3 @@
import { POST } from "@/modules/ee/auth/saml/api/token/route";
export { POST };
@@ -0,0 +1,3 @@
import { GET } from "@/modules/ee/auth/saml/api/userinfo/route";
export { GET };
@@ -39,7 +39,10 @@ interface LoginFormProps {
oidcOAuthEnabled: boolean;
oidcDisplayName?: string;
isMultiOrgEnabled: boolean;
isSSOEnabled: boolean;
isSsoEnabled: boolean;
samlSsoEnabled: boolean;
samlTenant: string;
samlProduct: string;
}
export const LoginForm = ({
@@ -52,7 +55,10 @@ export const LoginForm = ({
oidcOAuthEnabled,
oidcDisplayName,
isMultiOrgEnabled,
isSSOEnabled,
isSsoEnabled,
samlSsoEnabled,
samlTenant,
samlProduct,
}: LoginFormProps) => {
const router = useRouter();
const searchParams = useSearchParams();
@@ -239,13 +245,16 @@ export const LoginForm = ({
</Button>
)}
</form>
{isSSOEnabled && (
{isSsoEnabled && (
<SSOOptions
googleOAuthEnabled={googleOAuthEnabled}
githubOAuthEnabled={githubOAuthEnabled}
azureOAuthEnabled={azureOAuthEnabled}
oidcOAuthEnabled={oidcOAuthEnabled}
oidcDisplayName={oidcDisplayName}
samlSsoEnabled={samlSsoEnabled}
samlTenant={samlTenant}
samlProduct={samlProduct}
callbackUrl={callbackUrl}
/>
)}
+19 -3
View File
@@ -1,6 +1,10 @@
import { FormWrapper } from "@/modules/auth/components/form-wrapper";
import { Testimonial } from "@/modules/auth/components/testimonial";
import { getIsMultiOrgEnabled, getIsSSOEnabled } from "@/modules/ee/license-check/lib/utils";
import {
getIsMultiOrgEnabled,
getIsSamlSsoEnabled,
getisSsoEnabled,
} from "@/modules/ee/license-check/lib/utils";
import { Metadata } from "next";
import {
AZURE_OAUTH_ENABLED,
@@ -10,6 +14,9 @@ import {
OIDC_DISPLAY_NAME,
OIDC_OAUTH_ENABLED,
PASSWORD_RESET_DISABLED,
SAML_OAUTH_ENABLED,
SAML_PRODUCT,
SAML_TENANT,
SIGNUP_ENABLED,
} from "@formbricks/lib/constants";
import { LoginForm } from "./components/login-form";
@@ -20,7 +27,13 @@ export const metadata: Metadata = {
};
export const LoginPage = async () => {
const [isMultiOrgEnabled, isSSOEnabled] = await Promise.all([getIsMultiOrgEnabled(), getIsSSOEnabled()]);
const [isMultiOrgEnabled, isSsoEnabled, isSamlSsoEnabled] = await Promise.all([
getIsMultiOrgEnabled(),
getisSsoEnabled(),
getIsSamlSsoEnabled(),
]);
const samlSsoEnabled = isSamlSsoEnabled && SAML_OAUTH_ENABLED;
return (
<div className="grid min-h-screen w-full bg-gradient-to-tr from-slate-100 to-slate-50 lg:grid-cols-5">
<div className="col-span-2 hidden lg:flex">
@@ -38,7 +51,10 @@ export const LoginPage = async () => {
oidcOAuthEnabled={OIDC_OAUTH_ENABLED}
oidcDisplayName={OIDC_DISPLAY_NAME}
isMultiOrgEnabled={isMultiOrgEnabled}
isSSOEnabled={isSSOEnabled}
isSsoEnabled={isSsoEnabled}
samlSsoEnabled={samlSsoEnabled}
samlTenant={SAML_TENANT}
samlProduct={SAML_PRODUCT}
/>
</FormWrapper>
</div>
@@ -53,8 +53,11 @@ interface SignupFormProps {
emailVerificationDisabled: boolean;
defaultOrganizationId?: string;
defaultOrganizationRole?: TOrganizationRole;
isSSOEnabled: boolean;
isSsoEnabled: boolean;
samlSsoEnabled: boolean;
isTurnstileConfigured: boolean;
samlTenant: string;
samlProduct: string;
}
export const SignupForm = ({
@@ -72,8 +75,11 @@ export const SignupForm = ({
emailVerificationDisabled,
defaultOrganizationId,
defaultOrganizationRole,
isSSOEnabled,
isSsoEnabled,
samlSsoEnabled,
isTurnstileConfigured,
samlTenant,
samlProduct,
}: SignupFormProps) => {
const [showLogin, setShowLogin] = useState(false);
const searchParams = useSearchParams();
@@ -266,13 +272,16 @@ export const SignupForm = ({
</form>
</FormProvider>
)}
{isSSOEnabled && (
{isSsoEnabled && (
<SSOOptions
googleOAuthEnabled={googleOAuthEnabled}
githubOAuthEnabled={githubOAuthEnabled}
azureOAuthEnabled={azureOAuthEnabled}
oidcOAuthEnabled={oidcOAuthEnabled}
oidcDisplayName={oidcDisplayName}
samlSsoEnabled={samlSsoEnabled}
samlTenant={samlTenant}
samlProduct={samlProduct}
callbackUrl={callbackUrl}
/>
)}
+20 -3
View File
@@ -1,6 +1,10 @@
import { FormWrapper } from "@/modules/auth/components/form-wrapper";
import { Testimonial } from "@/modules/auth/components/testimonial";
import { getIsMultiOrgEnabled, getIsSSOEnabled } from "@/modules/ee/license-check/lib/utils";
import {
getIsMultiOrgEnabled,
getIsSamlSsoEnabled,
getisSsoEnabled,
} from "@/modules/ee/license-check/lib/utils";
import { notFound } from "next/navigation";
import {
AZURE_OAUTH_ENABLED,
@@ -14,6 +18,9 @@ import {
OIDC_DISPLAY_NAME,
OIDC_OAUTH_ENABLED,
PRIVACY_URL,
SAML_OAUTH_ENABLED,
SAML_PRODUCT,
SAML_TENANT,
SIGNUP_ENABLED,
TERMS_URL,
WEBAPP_URL,
@@ -24,7 +31,14 @@ import { SignupForm } from "./components/signup-form";
export const SignupPage = async ({ searchParams: searchParamsProps }) => {
const searchParams = await searchParamsProps;
const inviteToken = searchParams["inviteToken"] ?? null;
const [isMultOrgEnabled, isSSOEnabled] = await Promise.all([getIsMultiOrgEnabled(), getIsSSOEnabled()]);
const [isMultOrgEnabled, isSsoEnabled, isSamlSsoEnabled] = await Promise.all([
getIsMultiOrgEnabled(),
getisSsoEnabled(),
getIsSamlSsoEnabled(),
]);
const samlSsoEnabled = isSamlSsoEnabled && SAML_OAUTH_ENABLED;
const locale = await findMatchingLocale();
if (!inviteToken && (!SIGNUP_ENABLED || !isMultOrgEnabled)) {
notFound();
@@ -53,8 +67,11 @@ export const SignupPage = async ({ searchParams: searchParamsProps }) => {
emailFromSearchParams={emailFromSearchParams}
defaultOrganizationId={DEFAULT_ORGANIZATION_ID}
defaultOrganizationRole={DEFAULT_ORGANIZATION_ROLE}
isSSOEnabled={isSSOEnabled}
isSsoEnabled={isSsoEnabled}
samlSsoEnabled={samlSsoEnabled}
isTurnstileConfigured={IS_TURNSTILE_CONFIGURED}
samlTenant={SAML_TENANT}
samlProduct={SAML_PRODUCT}
/>
</FormWrapper>
</div>
@@ -0,0 +1,33 @@
import { responses } from "@/app/lib/api/response";
import jackson from "@/modules/ee/auth/saml/lib/jackson";
import { getIsSamlSsoEnabled } from "@/modules/ee/license-check/lib/utils";
import type { OAuthReq } from "@boxyhq/saml-jackson";
import { NextRequest, NextResponse } from "next/server";
export const GET = async (req: NextRequest) => {
const jacksonInstance = await jackson();
if (!jacksonInstance) {
return responses.forbiddenResponse("SAML SSO is not enabled in your Formbricks license");
}
const { oauthController } = jacksonInstance;
const searchParams = Object.fromEntries(req.nextUrl.searchParams);
const isSamlSsoEnabled = await getIsSamlSsoEnabled();
if (!isSamlSsoEnabled) {
return responses.forbiddenResponse("SAML SSO is not enabled in your Formbricks license");
}
try {
const { redirect_url } = await oauthController.authorize(searchParams as OAuthReq);
if (!redirect_url) {
return responses.internalServerErrorResponse("Failed to get redirect URL");
}
return NextResponse.redirect(redirect_url);
} catch (err: unknown) {
const errorMessage = err instanceof Error ? err.message : "An unknown error occurred";
return responses.internalServerErrorResponse(errorMessage);
}
};
@@ -0,0 +1,32 @@
import { responses } from "@/app/lib/api/response";
import jackson from "@/modules/ee/auth/saml/lib/jackson";
import { redirect } from "next/navigation";
interface SAMLCallbackBody {
RelayState: string;
SAMLResponse: string;
}
export const POST = async (req: Request) => {
const jacksonInstance = await jackson();
if (!jacksonInstance) {
return responses.forbiddenResponse("SAML SSO is not enabled in your Formbricks license");
}
const { oauthController } = jacksonInstance;
const formData = await req.formData();
const body = Object.fromEntries(formData.entries());
const { RelayState, SAMLResponse } = body as unknown as SAMLCallbackBody;
const { redirect_url } = await oauthController.samlResponse({
RelayState,
SAMLResponse,
});
if (!redirect_url) {
return responses.internalServerErrorResponse("Failed to get redirect URL");
}
return redirect(redirect_url);
};
@@ -0,0 +1,18 @@
import { responses } from "@/app/lib/api/response";
import jackson from "@/modules/ee/auth/saml/lib/jackson";
import { OAuthTokenReq } from "@boxyhq/saml-jackson";
export const POST = async (req: Request) => {
const jacksonInstance = await jackson();
if (!jacksonInstance) {
return responses.forbiddenResponse("SAML SSO is not enabled in your Formbricks license");
}
const { oauthController } = jacksonInstance;
const body = await req.formData();
const formData = Object.fromEntries(body.entries());
const response = await oauthController.token(formData as unknown as OAuthTokenReq);
return Response.json(response);
};
@@ -0,0 +1,87 @@
import { responses } from "@/app/lib/api/response";
import { describe, expect, test, vi } from "vitest";
import { extractAuthToken } from "./utils";
vi.mock("@/app/lib/api/response", () => ({
responses: {
unauthorizedResponse: vi.fn().mockReturnValue(new Error("Unauthorized")),
},
}));
describe("extractAuthToken", () => {
test("extracts token from Authorization header with Bearer prefix", () => {
const mockRequest = new Request("https://example.com", {
headers: {
authorization: "Bearer token123",
},
});
const token = extractAuthToken(mockRequest);
expect(token).toBe("token123");
});
test("extracts token from Authorization header with other prefix", () => {
const mockRequest = new Request("https://example.com", {
headers: {
authorization: "Custom token123",
},
});
const token = extractAuthToken(mockRequest);
expect(token).toBe("token123");
});
test("extracts token from query parameter", () => {
const mockRequest = new Request("https://example.com?access_token=token123");
const token = extractAuthToken(mockRequest);
expect(token).toBe("token123");
});
test("prioritizes Authorization header over query parameter", () => {
const mockRequest = new Request("https://example.com?access_token=queryToken", {
headers: {
authorization: "Bearer headerToken",
},
});
const token = extractAuthToken(mockRequest);
expect(token).toBe("headerToken");
});
test("throws unauthorized error when no token is found", () => {
const mockRequest = new Request("https://example.com");
expect(() => extractAuthToken(mockRequest)).toThrow("Unauthorized");
expect(responses.unauthorizedResponse).toHaveBeenCalled();
});
test("throws unauthorized error when Authorization header is empty", () => {
const mockRequest = new Request("https://example.com", {
headers: {
authorization: "",
},
});
expect(() => extractAuthToken(mockRequest)).toThrow("Unauthorized");
expect(responses.unauthorizedResponse).toHaveBeenCalled();
});
test("throws unauthorized error when query parameter is empty", () => {
const mockRequest = new Request("https://example.com?access_token=");
expect(() => extractAuthToken(mockRequest)).toThrow("Unauthorized");
expect(responses.unauthorizedResponse).toHaveBeenCalled();
});
test("handles Authorization header with only prefix", () => {
const mockRequest = new Request("https://example.com", {
headers: {
authorization: "Bearer ",
},
});
expect(() => extractAuthToken(mockRequest)).toThrow("Unauthorized");
expect(responses.unauthorizedResponse).toHaveBeenCalled();
});
});
@@ -0,0 +1,14 @@
import { responses } from "@/app/lib/api/response";
export const extractAuthToken = (req: Request) => {
const authHeader = req.headers.get("authorization");
const parts = (authHeader || "").split(" ");
if (parts.length > 1) return parts[1];
// check for query param
const params = new URL(req.url).searchParams;
const accessToken = params.get("access_token");
if (accessToken) return accessToken;
throw responses.unauthorizedResponse();
};
@@ -0,0 +1,16 @@
import { responses } from "@/app/lib/api/response";
import { extractAuthToken } from "@/modules/ee/auth/saml/api/userinfo/lib/utils";
import jackson from "@/modules/ee/auth/saml/lib/jackson";
export const GET = async (req: Request) => {
const jacksonInstance = await jackson();
if (!jacksonInstance) {
return responses.forbiddenResponse("SAML SSO is not enabled in your Formbricks license");
}
const { oauthController } = jacksonInstance;
const token = extractAuthToken(req);
const user = await oauthController.userInfo(token);
return Response.json(user);
};
@@ -0,0 +1,43 @@
"use server";
import { preloadConnection } from "@/modules/ee/auth/saml/lib/preload-connection";
import { getIsSamlSsoEnabled } from "@/modules/ee/license-check/lib/utils";
import type { IConnectionAPIController, IOAuthController, JacksonOption } from "@boxyhq/saml-jackson";
import { SAML_AUDIENCE, SAML_DATABASE_URL, SAML_PATH, WEBAPP_URL } from "@formbricks/lib/constants";
const opts: JacksonOption = {
externalUrl: WEBAPP_URL,
samlAudience: SAML_AUDIENCE,
samlPath: SAML_PATH,
db: {
engine: "sql",
type: "postgres",
url: SAML_DATABASE_URL,
},
};
declare global {
var oauthController: IOAuthController | undefined;
var connectionController: IConnectionAPIController | undefined;
}
const g = global;
export default async function init() {
if (!g.oauthController || !g.connectionController) {
const isSamlSsoEnabled = await getIsSamlSsoEnabled();
if (!isSamlSsoEnabled) return;
const ret = await (await import("@boxyhq/saml-jackson")).controllers(opts);
await preloadConnection(ret.connectionAPIController);
g.oauthController = ret.oauthController;
g.connectionController = ret.connectionAPIController;
}
return {
oauthController: g.oauthController,
connectionController: g.connectionController,
};
}
@@ -0,0 +1,73 @@
import { SAMLSSOConnectionWithEncodedMetadata, SAMLSSORecord } from "@boxyhq/saml-jackson";
import { ConnectionAPIController } from "@boxyhq/saml-jackson/dist/controller/api";
import fs from "fs/promises";
import path from "path";
import { SAML_PRODUCT, SAML_TENANT, SAML_XML_DIR, WEBAPP_URL } from "@formbricks/lib/constants";
const getPreloadedConnectionFile = async () => {
const preloadedConnections = await fs.readdir(path.join(SAML_XML_DIR));
const xmlFiles = preloadedConnections.filter((file) => file.endsWith(".xml"));
if (xmlFiles.length === 0) {
throw new Error("No preloaded connection file found");
}
return xmlFiles[0];
};
const getPreloadedConnectionMetadata = async () => {
const preloadedConnectionFile = await getPreloadedConnectionFile();
const preloadedConnectionMetadata = await fs.readFile(
path.join(SAML_XML_DIR, preloadedConnectionFile),
"utf8"
);
return preloadedConnectionMetadata;
};
const getConnectionPayload = (metadata: string): SAMLSSOConnectionWithEncodedMetadata => {
const encodedRawMetadata = Buffer.from(metadata, "utf8").toString("base64");
return {
name: "SAML SSO",
defaultRedirectUrl: `${WEBAPP_URL}/auth/login`,
redirectUrl: [`${WEBAPP_URL}/*`],
tenant: SAML_TENANT,
product: SAML_PRODUCT,
encodedRawMetadata,
};
};
export const preloadConnection = async (connectionController: ConnectionAPIController) => {
try {
const preloadedConnectionMetadata = await getPreloadedConnectionMetadata();
if (!preloadedConnectionMetadata) {
console.log("No preloaded connection metadata found");
return;
}
const connections = await connectionController.getConnections({
tenant: SAML_TENANT,
product: SAML_PRODUCT,
});
const existingConnection = connections[0];
const connection = getConnectionPayload(preloadedConnectionMetadata);
let newConnection: SAMLSSORecord;
try {
newConnection = await connectionController.createSAMLConnection(connection);
} catch (error) {
throw new Error(`Metadata is not valid\n${error.message}`);
}
if (newConnection && existingConnection && newConnection.clientID !== existingConnection.clientID) {
await connectionController.deleteConnections({
clientID: existingConnection.clientID,
clientSecret: existingConnection.clientSecret,
product: existingConnection.product,
tenant: existingConnection.tenant,
});
}
} catch (error) {
console.error("Error preloading connection:", error.message);
}
};
@@ -0,0 +1,110 @@
import { preloadConnection } from "@/modules/ee/auth/saml/lib/preload-connection";
import { getIsSamlSsoEnabled } from "@/modules/ee/license-check/lib/utils";
import { controllers } from "@boxyhq/saml-jackson";
import { afterEach, beforeEach, describe, expect, test, vi } from "vitest";
import { SAML_AUDIENCE, SAML_DATABASE_URL, SAML_PATH, WEBAPP_URL } from "@formbricks/lib/constants";
import init from "../jackson";
vi.mock("@formbricks/lib/constants", () => ({
SAML_AUDIENCE: "test-audience",
SAML_DATABASE_URL: "test-db-url",
SAML_PATH: "/test-path",
WEBAPP_URL: "https://test-webapp-url.com",
}));
vi.mock("@/modules/ee/license-check/lib/utils", () => ({
getIsSamlSsoEnabled: vi.fn(),
}));
vi.mock("@/modules/ee/auth/saml/lib/preload-connection", () => ({
preloadConnection: vi.fn(),
}));
vi.mock("@boxyhq/saml-jackson", () => ({
controllers: vi.fn(),
}));
describe("SAML Jackson Initialization", () => {
const mockOAuthController = { name: "mockOAuthController" };
const mockConnectionController = { name: "mockConnectionController" };
beforeEach(() => {
vi.clearAllMocks();
global.oauthController = undefined;
global.connectionController = undefined;
vi.mocked(controllers).mockResolvedValue({
oauthController: mockOAuthController,
connectionAPIController: mockConnectionController,
} as any);
});
afterEach(() => {
vi.resetAllMocks();
});
test("initialize controllers when SAML SSO is enabled", async () => {
vi.mocked(getIsSamlSsoEnabled).mockResolvedValue(true);
const result = await init();
expect(getIsSamlSsoEnabled).toHaveBeenCalledTimes(1);
expect(controllers).toHaveBeenCalledWith({
externalUrl: WEBAPP_URL,
samlAudience: SAML_AUDIENCE,
samlPath: SAML_PATH,
db: {
engine: "sql",
type: "postgres",
url: SAML_DATABASE_URL,
},
});
expect(preloadConnection).toHaveBeenCalledWith(mockConnectionController);
expect(global.oauthController).toBe(mockOAuthController);
expect(global.connectionController).toBe(mockConnectionController);
expect(result).toEqual({
oauthController: mockOAuthController,
connectionController: mockConnectionController,
});
});
test("return early when SAML SSO is disabled", async () => {
vi.mocked(getIsSamlSsoEnabled).mockResolvedValue(false);
const result = await init();
expect(getIsSamlSsoEnabled).toHaveBeenCalledTimes(1);
expect(controllers).not.toHaveBeenCalled();
expect(preloadConnection).not.toHaveBeenCalled();
expect(global.oauthController).toBeUndefined();
expect(global.connectionController).toBeUndefined();
expect(result).toBeUndefined();
});
test("reuse existing controllers if already initialized", async () => {
global.oauthController = mockOAuthController as any;
global.connectionController = mockConnectionController as any;
const result = await init();
expect(getIsSamlSsoEnabled).not.toHaveBeenCalled();
expect(controllers).not.toHaveBeenCalled();
expect(preloadConnection).not.toHaveBeenCalled();
expect(result).toEqual({
oauthController: mockOAuthController,
connectionController: mockConnectionController,
});
});
});
@@ -0,0 +1,142 @@
import fs from "fs/promises";
import path from "path";
import { afterEach, beforeEach, describe, expect, test, vi } from "vitest";
import { SAML_PRODUCT, SAML_TENANT, SAML_XML_DIR, WEBAPP_URL } from "@formbricks/lib/constants";
import { preloadConnection } from "../preload-connection";
vi.mock("@formbricks/lib/constants", () => ({
SAML_PRODUCT: "test-product",
SAML_TENANT: "test-tenant",
SAML_XML_DIR: "test-xml-dir",
WEBAPP_URL: "https://test-webapp-url.com",
}));
vi.mock("fs/promises", () => ({
default: {
readdir: vi.fn(),
readFile: vi.fn(),
},
}));
vi.mock("path", () => ({
default: {
join: vi.fn(),
},
}));
vi.mock("@boxyhq/saml-jackson", () => ({
SAMLSSOConnectionWithEncodedMetadata: vi.fn(),
}));
vi.mock("@boxyhq/saml-jackson/dist/controller/api", () => ({
ConnectionAPIController: vi.fn(),
}));
describe("SAML Preload Connection", () => {
const mockConnectionController = {
getConnections: vi.fn(),
createSAMLConnection: vi.fn(),
deleteConnections: vi.fn(),
};
const mockMetadata = "<EntityDescriptor>SAML Metadata</EntityDescriptor>";
const mockEncodedMetadata = Buffer.from(mockMetadata, "utf8").toString("base64");
const mockExistingConnection = {
clientID: "existing-client-id",
clientSecret: "existing-client-secret",
product: SAML_PRODUCT,
tenant: SAML_TENANT,
};
const mockNewConnection = {
clientID: "new-client-id",
clientSecret: "new-client-secret",
};
beforeEach(() => {
vi.clearAllMocks();
vi.mocked(path.join).mockImplementation((...args) => args.join("/"));
vi.mocked(fs.readdir).mockResolvedValue(["metadata.xml", "other-file.txt"] as any);
vi.mocked(fs.readFile).mockResolvedValue(mockMetadata as any);
mockConnectionController.getConnections.mockResolvedValue([mockExistingConnection]);
mockConnectionController.createSAMLConnection.mockResolvedValue(mockNewConnection);
});
afterEach(() => {
vi.resetAllMocks();
});
test("preload connection from XML file", async () => {
await preloadConnection(mockConnectionController as any);
expect(fs.readdir).toHaveBeenCalledWith(path.join(SAML_XML_DIR));
expect(fs.readFile).toHaveBeenCalledWith(path.join(SAML_XML_DIR, "metadata.xml"), "utf8");
expect(mockConnectionController.getConnections).toHaveBeenCalledWith({
tenant: SAML_TENANT,
product: SAML_PRODUCT,
});
expect(mockConnectionController.createSAMLConnection).toHaveBeenCalledWith({
name: "SAML SSO",
defaultRedirectUrl: `${WEBAPP_URL}/auth/login`,
redirectUrl: [`${WEBAPP_URL}/*`],
tenant: SAML_TENANT,
product: SAML_PRODUCT,
encodedRawMetadata: mockEncodedMetadata,
});
expect(mockConnectionController.deleteConnections).toHaveBeenCalledWith({
clientID: mockExistingConnection.clientID,
clientSecret: mockExistingConnection.clientSecret,
product: mockExistingConnection.product,
tenant: mockExistingConnection.tenant,
});
});
test("not delete existing connection if client IDs match", async () => {
mockConnectionController.createSAMLConnection.mockResolvedValue({
clientID: mockExistingConnection.clientID,
});
await preloadConnection(mockConnectionController as any);
expect(mockConnectionController.deleteConnections).not.toHaveBeenCalled();
});
test("handle case when no XML files are found", async () => {
vi.mocked(fs.readdir).mockResolvedValue(["other-file.txt"] as any);
const consoleErrorSpy = vi.spyOn(console, "error");
await preloadConnection(mockConnectionController as any);
expect(consoleErrorSpy).toHaveBeenCalledWith(
"Error preloading connection:",
expect.stringContaining("No preloaded connection file found")
);
expect(mockConnectionController.createSAMLConnection).not.toHaveBeenCalled();
});
test("handle invalid metadata", async () => {
const errorMessage = "Invalid metadata";
mockConnectionController.createSAMLConnection.mockRejectedValue(new Error(errorMessage));
const consoleErrorSpy = vi.spyOn(console, "error");
await preloadConnection(mockConnectionController as any);
expect(consoleErrorSpy).toHaveBeenCalledWith(
"Error preloading connection:",
expect.stringContaining(errorMessage)
);
});
});
+18 -1
View File
@@ -90,6 +90,7 @@ const fetchLicenseForE2ETesting = async (): Promise<{
whitelabel: true,
removeBranding: true,
ai: true,
saml: true,
},
lastChecked: currentTime,
};
@@ -156,6 +157,7 @@ export const getEnterpriseLicense = async (): Promise<{
removeBranding: false,
contacts: false,
ai: false,
saml: false,
},
lastChecked: new Date(),
};
@@ -361,7 +363,7 @@ export const getIsTwoFactorAuthEnabled = async (): Promise<boolean> => {
return licenseFeatures.twoFactorAuth;
};
export const getIsSSOEnabled = async (): Promise<boolean> => {
export const getisSsoEnabled = async (): Promise<boolean> => {
if (E2E_TESTING) {
const previousResult = await fetchLicenseForE2ETesting();
return previousResult && previousResult.features ? previousResult.features.sso : false;
@@ -371,6 +373,21 @@ export const getIsSSOEnabled = async (): Promise<boolean> => {
return licenseFeatures.sso;
};
export const getIsSamlSsoEnabled = async (): Promise<boolean> => {
if (E2E_TESTING) {
const previousResult = await fetchLicenseForE2ETesting();
return previousResult && previousResult.features
? previousResult.features.sso && previousResult.features.saml
: false;
}
if (IS_FORMBRICKS_CLOUD) {
return false;
}
const licenseFeatures = await getLicenseFeatures();
if (!licenseFeatures) return false;
return licenseFeatures.sso && licenseFeatures.saml;
};
export const getIsOrganizationAIReady = async (billingPlan: Organization["billing"]["plan"]) => {
if (!IS_AI_CONFIGURED) return false;
if (E2E_TESTING) {
@@ -12,6 +12,7 @@ const ZEnterpriseLicenseFeatures = z.object({
removeBranding: z.boolean(),
twoFactorAuth: z.boolean(),
sso: z.boolean(),
saml: z.boolean(),
ai: z.boolean(),
});
+21
View File
@@ -0,0 +1,21 @@
"use server";
import { actionClient } from "@/lib/utils/action-client";
import jackson from "@/modules/ee/auth/saml/lib/jackson";
import { SAML_PRODUCT, SAML_TENANT } from "@formbricks/lib/constants";
export const doesSamlConnectionExistAction = actionClient.action(async () => {
const jacksonInstance = await jackson();
if (!jacksonInstance) {
return false;
}
const { connectionController } = jacksonInstance;
const connection = await connectionController.getConnections({
product: SAML_PRODUCT,
tenant: SAML_TENANT,
});
return connection.length === 1;
});
@@ -8,7 +8,7 @@ import { useCallback, useEffect } from "react";
import { FORMBRICKS_LOGGED_IN_WITH_LS } from "@formbricks/lib/localStorage";
interface AzureButtonProps {
inviteUrl?: string | null;
inviteUrl?: string;
directRedirect?: boolean;
lastUsed?: boolean;
}
@@ -7,7 +7,7 @@ import { signIn } from "next-auth/react";
import { FORMBRICKS_LOGGED_IN_WITH_LS } from "@formbricks/lib/localStorage";
interface GithubButtonProps {
inviteUrl?: string | null;
inviteUrl?: string;
lastUsed?: boolean;
}
@@ -7,7 +7,7 @@ import { signIn } from "next-auth/react";
import { FORMBRICKS_LOGGED_IN_WITH_LS } from "@formbricks/lib/localStorage";
interface GoogleButtonProps {
inviteUrl?: string | null;
inviteUrl?: string;
lastUsed?: boolean;
}
@@ -7,7 +7,7 @@ import { useCallback, useEffect } from "react";
import { FORMBRICKS_LOGGED_IN_WITH_LS } from "@formbricks/lib/localStorage";
interface OpenIdButtonProps {
inviteUrl?: string | null;
inviteUrl?: string;
lastUsed?: boolean;
directRedirect?: boolean;
text?: string;
@@ -0,0 +1,61 @@
"use client";
import { doesSamlConnectionExistAction } from "@/modules/ee/sso/actions";
import { Button } from "@/modules/ui/components/button";
import { useTranslate } from "@tolgee/react";
import { LockIcon } from "lucide-react";
import { signIn } from "next-auth/react";
import { useState } from "react";
import toast from "react-hot-toast";
import { FORMBRICKS_LOGGED_IN_WITH_LS } from "@formbricks/lib/localStorage";
interface SamlButtonProps {
inviteUrl?: string;
lastUsed?: boolean;
samlTenant: string;
samlProduct: string;
}
export const SamlButton = ({ inviteUrl, lastUsed, samlTenant, samlProduct }: SamlButtonProps) => {
const { t } = useTranslate();
const [isLoading, setIsLoading] = useState(false);
const handleLogin = async () => {
if (typeof window !== "undefined") {
localStorage.setItem(FORMBRICKS_LOGGED_IN_WITH_LS, "Saml");
}
setIsLoading(true);
const doesSamlConnectionExist = await doesSamlConnectionExistAction();
if (!doesSamlConnectionExist?.data) {
toast.error(t("auth.saml_connection_error"));
setIsLoading(false);
return;
}
signIn(
"saml",
{
redirect: true,
callbackUrl: inviteUrl ? inviteUrl : "/", // redirect after login to /
},
{
tenant: samlTenant,
product: samlProduct,
}
);
};
return (
<Button
type="button"
onClick={handleLogin}
variant="secondary"
className="relative w-full justify-center"
loading={isLoading}>
{t("auth.continue_with_saml")}
<LockIcon />
{lastUsed && <span className="absolute right-3 text-xs opacity-50">{t("auth.last_used")}</span>}
</Button>
);
};
@@ -1,10 +1,13 @@
"use client";
import { useTranslate } from "@tolgee/react";
import { useEffect, useState } from "react";
import { FORMBRICKS_LOGGED_IN_WITH_LS } from "@formbricks/lib/localStorage";
import { AzureButton } from "./azure-button";
import { GithubButton } from "./github-button";
import { GoogleButton } from "./google-button";
import { OpenIdButton } from "./open-id-button";
import { SamlButton } from "./saml-button";
interface SSOOptionsProps {
googleOAuthEnabled: boolean;
@@ -13,6 +16,9 @@ interface SSOOptionsProps {
oidcOAuthEnabled: boolean;
oidcDisplayName?: string;
callbackUrl: string;
samlSsoEnabled: boolean;
samlTenant: string;
samlProduct: string;
}
export const SSOOptions = ({
@@ -22,16 +28,42 @@ export const SSOOptions = ({
oidcOAuthEnabled,
oidcDisplayName,
callbackUrl,
samlSsoEnabled,
samlTenant,
samlProduct,
}: SSOOptionsProps) => {
const { t } = useTranslate();
const [lastLoggedInWith, setLastLoggedInWith] = useState("");
useEffect(() => {
if (typeof window !== "undefined") {
setLastLoggedInWith(localStorage.getItem(FORMBRICKS_LOGGED_IN_WITH_LS) || "");
}
}, []);
return (
<div className="space-y-2">
{googleOAuthEnabled && <GoogleButton inviteUrl={callbackUrl} />}
{githubOAuthEnabled && <GithubButton inviteUrl={callbackUrl} />}
{azureOAuthEnabled && <AzureButton inviteUrl={callbackUrl} />}
{googleOAuthEnabled && (
<GoogleButton inviteUrl={callbackUrl} lastUsed={lastLoggedInWith === "Google"} />
)}
{githubOAuthEnabled && (
<GithubButton inviteUrl={callbackUrl} lastUsed={lastLoggedInWith === "Github"} />
)}
{azureOAuthEnabled && <AzureButton inviteUrl={callbackUrl} lastUsed={lastLoggedInWith === "Azure"} />}
{oidcOAuthEnabled && (
<OpenIdButton inviteUrl={callbackUrl} text={t("auth.continue_with_oidc", { oidcDisplayName })} />
<OpenIdButton
inviteUrl={callbackUrl}
lastUsed={lastLoggedInWith === "OpenID"}
text={t("auth.continue_with_oidc", { oidcDisplayName })}
/>
)}
{samlSsoEnabled && (
<SamlButton
inviteUrl={callbackUrl}
lastUsed={lastLoggedInWith === "Saml"}
samlTenant={samlTenant}
samlProduct={samlProduct}
/>
)}
</div>
);
+31
View File
@@ -15,6 +15,7 @@ import {
OIDC_DISPLAY_NAME,
OIDC_ISSUER,
OIDC_SIGNING_ALGORITHM,
WEBAPP_URL,
} from "@formbricks/lib/constants";
export const getSSOProviders = () => [
@@ -54,6 +55,36 @@ export const getSSOProviders = () => [
};
},
},
{
id: "saml",
name: "BoxyHQ SAML",
type: "oauth" as const,
version: "2.0",
checks: ["pkce" as const, "state" as const],
authorization: {
url: `${WEBAPP_URL}/api/auth/saml/authorize`,
params: {
scope: "",
response_type: "code",
provider: "saml",
},
},
token: `${WEBAPP_URL}/api/auth/saml/token`,
userinfo: `${WEBAPP_URL}/api/auth/saml/userinfo`,
profile(profile) {
return {
id: profile.id,
email: profile.email,
name: [profile.firstName, profile.lastName].filter(Boolean).join(" "),
image: null,
};
},
options: {
clientId: "dummy",
clientSecret: "dummy",
},
allowDangerousEmailAccountLinking: true,
},
];
export type { IdentityProvider };
+15 -1
View File
@@ -1,6 +1,7 @@
import { createBrevoCustomer } from "@/modules/auth/lib/brevo";
import { getUserByEmail, updateUser } from "@/modules/auth/lib/user";
import { createUser } from "@/modules/auth/lib/user";
import { getIsSamlSsoEnabled, getisSsoEnabled } from "@/modules/ee/license-check/lib/utils";
import type { IdentityProvider } from "@prisma/client";
import type { Account } from "next-auth";
import { prisma } from "@formbricks/database";
@@ -12,12 +13,25 @@ import { findMatchingLocale } from "@formbricks/lib/utils/locale";
import type { TUser, TUserNotificationSettings } from "@formbricks/types/user";
export const handleSSOCallback = async ({ user, account }: { user: TUser; account: Account }) => {
const isSsoEnabled = await getisSsoEnabled();
if (!isSsoEnabled) {
return false;
}
if (!user.email || account.type !== "oauth") {
return false;
}
let provider = account.provider.toLowerCase().replace("-", "") as IdentityProvider;
if (provider === "saml") {
const isSamlSsoEnabled = await getIsSamlSsoEnabled();
if (!isSamlSsoEnabled) {
return false;
}
}
if (account.provider) {
const provider = account.provider.toLowerCase().replace("-", "") as IdentityProvider;
// check if accounts for this provider / account Id already exists
const existingUserWithAccount = await prisma.user.findFirst({
include: {
@@ -1,5 +1,5 @@
import { SignupForm } from "@/modules/auth/signup/components/signup-form";
import { getIsSSOEnabled } from "@/modules/ee/license-check/lib/utils";
import { getIsSamlSsoEnabled, getisSsoEnabled } from "@/modules/ee/license-check/lib/utils";
import { getTranslate } from "@/tolgee/server";
import { Metadata } from "next";
import {
@@ -14,6 +14,9 @@ import {
OIDC_DISPLAY_NAME,
OIDC_OAUTH_ENABLED,
PRIVACY_URL,
SAML_OAUTH_ENABLED,
SAML_PRODUCT,
SAML_TENANT,
TERMS_URL,
WEBAPP_URL,
} from "@formbricks/lib/constants";
@@ -26,7 +29,11 @@ export const metadata: Metadata = {
export const SignupPage = async () => {
const locale = await findMatchingLocale();
const isSSOEnabled = await getIsSSOEnabled();
const [isSsoEnabled, isSamlSsoEnabled] = await Promise.all([getisSsoEnabled(), getIsSamlSsoEnabled()]);
const samlSsoEnabled = isSamlSsoEnabled && SAML_OAUTH_ENABLED;
const t = await getTranslate();
return (
<div className="flex flex-col items-center">
@@ -47,8 +54,11 @@ export const SignupPage = async () => {
userLocale={locale}
defaultOrganizationId={DEFAULT_ORGANIZATION_ID}
defaultOrganizationRole={DEFAULT_ORGANIZATION_ROLE}
isSSOEnabled={isSSOEnabled}
isSsoEnabled={isSsoEnabled}
samlSsoEnabled={samlSsoEnabled}
isTurnstileConfigured={IS_TURNSTILE_CONFIGURED}
samlTenant={SAML_TENANT}
samlProduct={SAML_PRODUCT}
/>
</div>
);
+1
View File
@@ -15,6 +15,7 @@
},
"dependencies": {
"@ai-sdk/azure": "1.1.9",
"@boxyhq/saml-jackson": "1.37.1",
"@dnd-kit/core": "6.3.1",
"@dnd-kit/modifiers": "9.0.0",
"@dnd-kit/sortable": "10.0.0",