chore: block signin with SSO when user is not found (#5233)

Co-authored-by: pandeymangg <anshuman.pandey9999@gmail.com>
This commit is contained in:
Piyush Gupta
2025-04-06 09:52:53 +05:30
committed by GitHub
parent ec314c14ea
commit c653841037
27 changed files with 1450 additions and 53 deletions

View File

@@ -9,7 +9,11 @@ export function pickCommonFilter<T extends TGetFilter>(params: T) {
return { limit, skip, sortBy, order, startDate, endDate };
}
type HasFindMany = Prisma.WebhookFindManyArgs | Prisma.ResponseFindManyArgs | Prisma.TeamFindManyArgs | Prisma.ProjectTeamFindManyArgs;
type HasFindMany =
| Prisma.WebhookFindManyArgs
| Prisma.ResponseFindManyArgs
| Prisma.TeamFindManyArgs
| Prisma.ProjectTeamFindManyArgs;
export function buildCommonFilterQuery<T extends HasFindMany>(query: T, params: TGetFilter): T {
const { limit, skip, sortBy, order, startDate, endDate } = params || {};

View File

@@ -4,6 +4,7 @@ import { getSSOProviders } from "@/modules/ee/sso/lib/providers";
import { handleSsoCallback } from "@/modules/ee/sso/lib/sso-handlers";
import type { Account, NextAuthOptions } from "next-auth";
import CredentialsProvider from "next-auth/providers/credentials";
import { cookies } from "next/headers";
import { prisma } from "@formbricks/database";
import {
EMAIL_VERIFICATION_DISABLED,
@@ -208,6 +209,10 @@ export const authOptions: NextAuthOptions = {
return session;
},
async signIn({ user, account }: { user: TUser; account: Account }) {
const cookieStore = await cookies();
const callbackUrl = cookieStore.get("next-auth.callback-url")?.value || "";
if (account?.provider === "credentials" || account?.provider === "token") {
// check if user's email is verified or not
if (!user.emailVerified && !EMAIL_VERIFICATION_DISABLED) {
@@ -217,7 +222,7 @@ export const authOptions: NextAuthOptions = {
return true;
}
if (ENTERPRISE_LICENSE_KEY) {
const result = await handleSsoCallback({ user, account });
const result = await handleSsoCallback({ user, account, callbackUrl });
if (result) {
await updateUserLastLoginAt(user.email);
}

View File

@@ -256,6 +256,7 @@ export const LoginForm = ({
samlTenant={samlTenant}
samlProduct={samlProduct}
callbackUrl={callbackUrl}
source="signin"
/>
)}
</div>

View File

@@ -280,6 +280,7 @@ export const SignupForm = ({
samlTenant={samlTenant}
samlProduct={samlProduct}
callbackUrl={callbackUrl}
source="signup"
/>
)}
<TermsPrivacyLinks termsUrl={termsUrl} privacyUrl={privacyUrl} />

View File

@@ -0,0 +1,246 @@
import { inviteCache } from "@/lib/cache/invite";
import { Prisma } from "@prisma/client";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { prisma } from "@formbricks/database";
import { PrismaErrorType } from "@formbricks/database/types/error";
import { logger } from "@formbricks/logger";
import { DatabaseError } from "@formbricks/types/errors";
import { deleteInvite, getInvite, getIsValidInviteToken } from "./invite";
// Mock data
const mockInviteId = "test-invite-id";
const mockOrganizationId = "test-org-id";
const mockCreatorId = "test-creator-id";
const mockInvite = {
id: mockInviteId,
email: "test@test.com",
name: "Test Name",
organizationId: mockOrganizationId,
creatorId: mockCreatorId,
acceptorId: null,
createdAt: new Date(),
expiresAt: new Date(Date.now() + 24 * 60 * 60 * 1000), // 24 hours from now
deprecatedRole: null,
role: "member" as const,
teamIds: ["team-1"],
creator: {
name: "Test Creator",
email: "creator@test.com",
locale: "en",
},
};
// Mock prisma methods
vi.mock("@formbricks/database", () => ({
prisma: {
invite: {
delete: vi.fn(),
findUnique: vi.fn(),
},
},
}));
// Mock cache
vi.mock("@/lib/cache/invite", () => ({
inviteCache: {
revalidate: vi.fn(),
tag: {
byId: (id: string) => `invite-${id}`,
},
},
}));
// Mock logger
vi.mock("@formbricks/logger", () => ({
logger: {
error: vi.fn(),
},
}));
describe("Invite Management", () => {
beforeEach(() => {
vi.clearAllMocks();
});
describe("deleteInvite", () => {
it("deletes an invite successfully and invalidates cache", async () => {
vi.mocked(prisma.invite.delete).mockResolvedValue(mockInvite);
const result = await deleteInvite(mockInviteId);
expect(result).toBe(true);
expect(prisma.invite.delete).toHaveBeenCalledWith({
where: { id: mockInviteId },
select: { id: true, organizationId: true },
});
expect(inviteCache.revalidate).toHaveBeenCalledWith({
id: mockInviteId,
organizationId: mockOrganizationId,
});
});
it("throws DatabaseError when invite doesn't exist", async () => {
const errToThrow = new Prisma.PrismaClientKnownRequestError("Record not found", {
code: PrismaErrorType.RecordDoesNotExist,
clientVersion: "0.0.1",
});
vi.mocked(prisma.invite.delete).mockRejectedValue(errToThrow);
await expect(deleteInvite(mockInviteId)).rejects.toThrow(DatabaseError);
});
it("throws DatabaseError for other Prisma errors", async () => {
const errToThrow = new Prisma.PrismaClientKnownRequestError("Database error", {
code: "P2002",
clientVersion: "0.0.1",
});
vi.mocked(prisma.invite.delete).mockRejectedValue(errToThrow);
await expect(deleteInvite(mockInviteId)).rejects.toThrow(DatabaseError);
});
it("throws DatabaseError for generic errors", async () => {
vi.mocked(prisma.invite.delete).mockRejectedValue(new Error("Generic error"));
await expect(deleteInvite(mockInviteId)).rejects.toThrow(DatabaseError);
});
});
describe("getInvite", () => {
it("retrieves an invite with creator details successfully", async () => {
vi.mocked(prisma.invite.findUnique).mockResolvedValue(mockInvite);
const result = await getInvite(mockInviteId);
expect(result).toEqual(mockInvite);
expect(prisma.invite.findUnique).toHaveBeenCalledWith({
where: { id: mockInviteId },
select: {
id: true,
organizationId: true,
role: true,
teamIds: true,
creator: {
select: {
name: true,
email: true,
locale: true,
},
},
},
});
});
it("returns null when invite doesn't exist", async () => {
vi.mocked(prisma.invite.findUnique).mockResolvedValue(null);
const result = await getInvite(mockInviteId);
expect(result).toBeNull();
});
it("throws DatabaseError on prisma error", async () => {
const errToThrow = new Prisma.PrismaClientKnownRequestError("Database error", {
code: "P2002",
clientVersion: "0.0.1",
});
vi.mocked(prisma.invite.findUnique).mockRejectedValue(errToThrow);
await expect(getInvite(mockInviteId)).rejects.toThrow(DatabaseError);
});
it("throws DatabaseError for generic errors", async () => {
vi.mocked(prisma.invite.findUnique).mockRejectedValue(new Error("Generic error"));
await expect(getInvite(mockInviteId)).rejects.toThrow(DatabaseError);
});
});
describe("getIsValidInviteToken", () => {
it("returns true for valid invite", async () => {
vi.mocked(prisma.invite.findUnique).mockResolvedValue(mockInvite);
const result = await getIsValidInviteToken(mockInviteId);
expect(result).toBe(true);
expect(prisma.invite.findUnique).toHaveBeenCalledWith({
where: { id: mockInviteId },
});
});
it("returns false when invite doesn't exist", async () => {
vi.mocked(prisma.invite.findUnique).mockResolvedValue(null);
const result = await getIsValidInviteToken(mockInviteId);
expect(result).toBe(false);
});
it("returns false for expired invite", async () => {
const expiredInvite = {
...mockInvite,
expiresAt: new Date(Date.now() - 24 * 60 * 60 * 1000), // 24 hours ago
};
vi.mocked(prisma.invite.findUnique).mockResolvedValue(expiredInvite);
const result = await getIsValidInviteToken(mockInviteId);
expect(result).toBe(false);
expect(logger.error).toHaveBeenCalledWith(
{
inviteId: mockInviteId,
expiresAt: expiredInvite.expiresAt,
},
"SSO: Invite token expired"
);
});
it("returns false and logs error when database error occurs", async () => {
const error = new Error("Database error");
vi.mocked(prisma.invite.findUnique).mockRejectedValue(error);
const result = await getIsValidInviteToken(mockInviteId);
expect(result).toBe(false);
expect(logger.error).toHaveBeenCalledWith(error, "Error getting invite");
});
it("returns false for invite with null expiresAt", async () => {
const invalidInvite = {
...mockInvite,
expiresAt: null,
};
vi.mocked(prisma.invite.findUnique).mockResolvedValue(invalidInvite);
const result = await getIsValidInviteToken(mockInviteId);
expect(result).toBe(false);
expect(logger.error).toHaveBeenCalledWith(
{
inviteId: mockInviteId,
expiresAt: null,
},
"SSO: Invite token expired"
);
});
it("returns false for invite with invalid expiresAt", async () => {
const invalidInvite = {
...mockInvite,
expiresAt: new Date("invalid-date"),
};
vi.mocked(prisma.invite.findUnique).mockResolvedValue(invalidInvite);
const result = await getIsValidInviteToken(mockInviteId);
expect(result).toBe(false);
expect(logger.error).toHaveBeenCalledWith(
{
inviteId: mockInviteId,
expiresAt: invalidInvite.expiresAt,
},
"SSO: Invite token expired"
);
});
});
});

View File

@@ -4,6 +4,7 @@ import { Prisma } from "@prisma/client";
import { cache as reactCache } from "react";
import { prisma } from "@formbricks/database";
import { cache } from "@formbricks/lib/cache";
import { logger } from "@formbricks/logger";
import { DatabaseError, ResourceNotFoundError } from "@formbricks/types/errors";
export const deleteInvite = async (inviteId: string): Promise<boolean> => {
@@ -32,8 +33,7 @@ export const deleteInvite = async (inviteId: string): Promise<boolean> => {
if (error instanceof Prisma.PrismaClientKnownRequestError) {
throw new DatabaseError(error.message);
}
throw error;
throw new DatabaseError(error instanceof Error ? error.message : "Unknown error occurred");
}
};
@@ -66,8 +66,7 @@ export const getInvite = reactCache(
if (error instanceof Prisma.PrismaClientKnownRequestError) {
throw new DatabaseError(error.message);
}
throw error;
throw new DatabaseError(error instanceof Error ? error.message : "Unknown error occurred");
}
},
[`signup-getInvite-${inviteId}`],
@@ -76,3 +75,47 @@ export const getInvite = reactCache(
}
)()
);
export const getIsValidInviteToken = reactCache(
async (inviteId: string): Promise<boolean> =>
cache(
async () => {
try {
const invite = await prisma.invite.findUnique({
where: { id: inviteId },
});
if (!invite) {
return false;
}
if (!invite.expiresAt || isNaN(invite.expiresAt.getTime())) {
logger.error(
{
inviteId,
expiresAt: invite.expiresAt,
},
"SSO: Invite token expired"
);
return false;
}
if (invite.expiresAt < new Date()) {
logger.error(
{
inviteId,
expiresAt: invite.expiresAt,
},
"SSO: Invite token expired"
);
return false;
}
return true;
} catch (err) {
logger.error(err, "Error getting invite");
return false;
}
},
[`getIsValidInviteToken-${inviteId}`],
{
tags: [inviteCache.tag.byId(inviteId)],
}
)()
);

View File

@@ -0,0 +1,154 @@
import { getIsValidInviteToken } from "@/modules/auth/signup/lib/invite";
import {
getIsMultiOrgEnabled,
getIsSamlSsoEnabled,
getisSsoEnabled,
} from "@/modules/ee/license-check/lib/utils";
import { cleanup, render, screen } from "@testing-library/react";
import { notFound } from "next/navigation";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { verifyInviteToken } from "@formbricks/lib/jwt";
import { findMatchingLocale } from "@formbricks/lib/utils/locale";
import { SignupPage } from "./page";
// Mock the necessary dependencies
vi.mock("@/modules/auth/components/testimonial", () => ({
Testimonial: () => <div data-testid="testimonial">Testimonial</div>,
}));
vi.mock("@/modules/auth/components/form-wrapper", () => ({
FormWrapper: ({ children }: { children: React.ReactNode }) => (
<div data-testid="form-wrapper">{children}</div>
),
}));
vi.mock("@/modules/auth/signup/components/signup-form", () => ({
SignupForm: () => <div data-testid="signup-form">SignupForm</div>,
}));
vi.mock("@/modules/ee/license-check/lib/utils", () => ({
getIsMultiOrgEnabled: vi.fn(),
getIsSamlSsoEnabled: vi.fn(),
getisSsoEnabled: vi.fn(),
}));
vi.mock("@/modules/auth/signup/lib/invite", () => ({
getIsValidInviteToken: vi.fn(),
}));
vi.mock("@formbricks/lib/jwt", () => ({
verifyInviteToken: vi.fn(),
}));
vi.mock("@formbricks/lib/utils/locale", () => ({
findMatchingLocale: vi.fn(),
}));
vi.mock("next/navigation", () => ({
notFound: vi.fn(),
}));
// Mock environment variables and constants
vi.mock("@formbricks/lib/constants", () => ({
SIGNUP_ENABLED: true,
EMAIL_AUTH_ENABLED: true,
EMAIL_VERIFICATION_DISABLED: false,
GOOGLE_OAUTH_ENABLED: true,
GITHUB_OAUTH_ENABLED: true,
AZURE_OAUTH_ENABLED: true,
OIDC_OAUTH_ENABLED: true,
OIDC_DISPLAY_NAME: "OpenID",
SAML_OAUTH_ENABLED: true,
SAML_TENANT: "test-tenant",
SAML_PRODUCT: "test-product",
IS_TURNSTILE_CONFIGURED: true,
WEBAPP_URL: "http://localhost:3000",
TERMS_URL: "http://localhost:3000/terms",
PRIVACY_URL: "http://localhost:3000/privacy",
DEFAULT_ORGANIZATION_ID: "test-org-id",
DEFAULT_ORGANIZATION_ROLE: "admin",
}));
describe("SignupPage", () => {
const mockSearchParams = {
inviteToken: "test-token",
email: "test@example.com",
};
beforeEach(() => {
vi.clearAllMocks();
});
afterEach(() => {
cleanup();
});
it("renders the signup page with all components when signup is enabled", async () => {
// Mock the license check functions to return true
vi.mocked(getIsMultiOrgEnabled).mockResolvedValue(true);
vi.mocked(getisSsoEnabled).mockResolvedValue(true);
vi.mocked(getIsSamlSsoEnabled).mockResolvedValue(true);
vi.mocked(findMatchingLocale).mockResolvedValue("en");
vi.mocked(verifyInviteToken).mockReturnValue({ inviteId: "test-invite-id" });
vi.mocked(getIsValidInviteToken).mockResolvedValue(true);
const result = await SignupPage({ searchParams: mockSearchParams });
render(result);
// Verify that all components are rendered
expect(screen.getByTestId("testimonial")).toBeInTheDocument();
expect(screen.getByTestId("form-wrapper")).toBeInTheDocument();
expect(screen.getByTestId("signup-form")).toBeInTheDocument();
});
it("calls notFound when signup is disabled and no valid invite token is provided", async () => {
// Mock the license check functions to return false
vi.mocked(getIsMultiOrgEnabled).mockResolvedValue(false);
vi.mocked(verifyInviteToken).mockImplementation(() => {
throw new Error("Invalid token");
});
await SignupPage({ searchParams: {} });
expect(notFound).toHaveBeenCalled();
});
it("calls notFound when invite token is invalid", async () => {
// Mock the license check functions to return false
vi.mocked(getIsMultiOrgEnabled).mockResolvedValue(false);
vi.mocked(verifyInviteToken).mockImplementation(() => {
throw new Error("Invalid token");
});
await SignupPage({ searchParams: { inviteToken: "invalid-token" } });
expect(notFound).toHaveBeenCalled();
});
it("calls notFound when invite token is valid but invite is not found", async () => {
// Mock the license check functions to return false
vi.mocked(getIsMultiOrgEnabled).mockResolvedValue(false);
vi.mocked(verifyInviteToken).mockReturnValue({ inviteId: "test-invite-id" });
vi.mocked(getIsValidInviteToken).mockResolvedValue(false);
await SignupPage({ searchParams: { inviteToken: "test-token" } });
expect(notFound).toHaveBeenCalled();
});
it("renders the page with email from search params", async () => {
// Mock the license check functions to return true
vi.mocked(getIsMultiOrgEnabled).mockResolvedValue(true);
vi.mocked(getisSsoEnabled).mockResolvedValue(true);
vi.mocked(getIsSamlSsoEnabled).mockResolvedValue(true);
vi.mocked(findMatchingLocale).mockResolvedValue("en");
vi.mocked(verifyInviteToken).mockReturnValue({ inviteId: "test-invite-id" });
vi.mocked(getIsValidInviteToken).mockResolvedValue(true);
const result = await SignupPage({ searchParams: { email: "test@example.com" } });
render(result);
// Verify that the form is rendered with the email from search params
expect(screen.getByTestId("signup-form")).toBeInTheDocument();
});
});

View File

@@ -1,5 +1,6 @@
import { FormWrapper } from "@/modules/auth/components/form-wrapper";
import { Testimonial } from "@/modules/auth/components/testimonial";
import { getIsValidInviteToken } from "@/modules/auth/signup/lib/invite";
import {
getIsMultiOrgEnabled,
getIsSamlSsoEnabled,
@@ -25,6 +26,7 @@ import {
TERMS_URL,
WEBAPP_URL,
} from "@formbricks/lib/constants";
import { verifyInviteToken } from "@formbricks/lib/jwt";
import { findMatchingLocale } from "@formbricks/lib/utils/locale";
import { SignupForm } from "./components/signup-form";
@@ -38,11 +40,20 @@ export const SignupPage = async ({ searchParams: searchParamsProps }) => {
]);
const samlSsoEnabled = isSamlSsoEnabled && SAML_OAUTH_ENABLED;
const locale = await findMatchingLocale();
if (!inviteToken && (!SIGNUP_ENABLED || !isMultOrgEnabled)) {
notFound();
if (!SIGNUP_ENABLED || !isMultOrgEnabled) {
if (!inviteToken) notFound();
try {
const { inviteId } = verifyInviteToken(inviteToken);
const isValidInviteToken = await getIsValidInviteToken(inviteId);
if (!isValidInviteToken) notFound();
} catch {
notFound();
}
}
const emailFromSearchParams = searchParams["email"];
return (

View File

@@ -0,0 +1,84 @@
import { cleanup, fireEvent, render, screen } from "@testing-library/react";
import { signIn } from "next-auth/react";
import { afterEach, describe, expect, it, vi } from "vitest";
import { FORMBRICKS_LOGGED_IN_WITH_LS } from "@formbricks/lib/localStorage";
import { AzureButton } from "./azure-button";
// Mock next-auth/react
vi.mock("next-auth/react", () => ({
signIn: vi.fn(),
}));
// Mock localStorage
const mockLocalStorage = {
setItem: vi.fn(),
};
Object.defineProperty(window, "localStorage", {
value: mockLocalStorage,
writable: true,
});
describe("AzureButton", () => {
const defaultProps = {
source: "signin" as const,
};
afterEach(() => {
cleanup();
vi.clearAllMocks();
});
it("renders correctly with default props", () => {
render(<AzureButton {...defaultProps} />);
const button = screen.getByRole("button", { name: "auth.continue_with_azure" });
expect(button).toBeInTheDocument();
});
it("renders with last used indicator when lastUsed is true", () => {
render(<AzureButton {...defaultProps} lastUsed={true} />);
expect(screen.getByText("auth.last_used")).toBeInTheDocument();
});
it("sets localStorage item and calls signIn on click", async () => {
render(<AzureButton {...defaultProps} />);
const button = screen.getByRole("button", { name: "auth.continue_with_azure" });
fireEvent.click(button);
expect(mockLocalStorage.setItem).toHaveBeenCalledWith(FORMBRICKS_LOGGED_IN_WITH_LS, "Azure");
expect(signIn).toHaveBeenCalledWith("azure-ad", {
redirect: true,
callbackUrl: "/?source=signin",
});
});
it("uses inviteUrl in callbackUrl when provided", async () => {
const inviteUrl = "https://example.com/invite";
render(<AzureButton {...defaultProps} inviteUrl={inviteUrl} />);
const button = screen.getByRole("button", { name: "auth.continue_with_azure" });
fireEvent.click(button);
expect(signIn).toHaveBeenCalledWith("azure-ad", {
redirect: true,
callbackUrl: "https://example.com/invite?source=signin",
});
});
it("handles signup source correctly", async () => {
render(<AzureButton {...defaultProps} source="signup" />);
const button = screen.getByRole("button", { name: "auth.continue_with_azure" });
fireEvent.click(button);
expect(signIn).toHaveBeenCalledWith("azure-ad", {
redirect: true,
callbackUrl: "/?source=signup",
});
});
it("triggers direct redirect when directRedirect is true", () => {
render(<AzureButton {...defaultProps} directRedirect={true} />);
expect(signIn).toHaveBeenCalledWith("azure-ad", {
redirect: true,
callbackUrl: "/?source=signin",
});
});
});

View File

@@ -1,5 +1,6 @@
"use client";
import { getCallbackUrl } from "@/modules/ee/sso/lib/utils";
import { Button } from "@/modules/ui/components/button";
import { MicrosoftIcon } from "@/modules/ui/components/icons";
import { useTranslate } from "@tolgee/react";
@@ -11,20 +12,22 @@ interface AzureButtonProps {
inviteUrl?: string;
directRedirect?: boolean;
lastUsed?: boolean;
source: "signin" | "signup";
}
export const AzureButton = ({ inviteUrl, directRedirect = false, lastUsed }: AzureButtonProps) => {
export const AzureButton = ({ inviteUrl, directRedirect = false, lastUsed, source }: AzureButtonProps) => {
const { t } = useTranslate();
const handleLogin = useCallback(async () => {
if (typeof window !== "undefined") {
localStorage.setItem(FORMBRICKS_LOGGED_IN_WITH_LS, "Azure");
}
const callbackUrlWithSource = getCallbackUrl(inviteUrl, source);
await signIn("azure-ad", {
redirect: true,
callbackUrl: inviteUrl ? inviteUrl : "/",
callbackUrl: callbackUrlWithSource,
});
}, [inviteUrl]);
}, [inviteUrl, source]);
useEffect(() => {
if (directRedirect) {

View File

@@ -0,0 +1,76 @@
import { cleanup, fireEvent, render, screen } from "@testing-library/react";
import { signIn } from "next-auth/react";
import { afterEach, describe, expect, it, vi } from "vitest";
import { FORMBRICKS_LOGGED_IN_WITH_LS } from "@formbricks/lib/localStorage";
import { GithubButton } from "./github-button";
// Mock next-auth/react
vi.mock("next-auth/react", () => ({
signIn: vi.fn(),
}));
// Mock localStorage
const mockLocalStorage = {
setItem: vi.fn(),
};
Object.defineProperty(window, "localStorage", {
value: mockLocalStorage,
writable: true,
});
describe("GithubButton", () => {
const defaultProps = {
source: "signin" as const,
};
afterEach(() => {
cleanup();
vi.clearAllMocks();
});
it("renders correctly with default props", () => {
render(<GithubButton {...defaultProps} />);
const button = screen.getByRole("button", { name: "auth.continue_with_github" });
expect(button).toBeInTheDocument();
});
it("renders with last used indicator when lastUsed is true", () => {
render(<GithubButton {...defaultProps} lastUsed={true} />);
expect(screen.getByText("auth.last_used")).toBeInTheDocument();
});
it("sets localStorage item and calls signIn on click", async () => {
render(<GithubButton {...defaultProps} />);
const button = screen.getByRole("button", { name: "auth.continue_with_github" });
fireEvent.click(button);
expect(mockLocalStorage.setItem).toHaveBeenCalledWith(FORMBRICKS_LOGGED_IN_WITH_LS, "Github");
expect(signIn).toHaveBeenCalledWith("github", {
redirect: true,
callbackUrl: "/?source=signin",
});
});
it("uses inviteUrl in callbackUrl when provided", async () => {
const inviteUrl = "https://example.com/invite";
render(<GithubButton {...defaultProps} inviteUrl={inviteUrl} />);
const button = screen.getByRole("button", { name: "auth.continue_with_github" });
fireEvent.click(button);
expect(signIn).toHaveBeenCalledWith("github", {
redirect: true,
callbackUrl: "https://example.com/invite?source=signin",
});
});
it("handles signup source correctly", async () => {
render(<GithubButton {...defaultProps} source="signup" />);
const button = screen.getByRole("button", { name: "auth.continue_with_github" });
fireEvent.click(button);
expect(signIn).toHaveBeenCalledWith("github", {
redirect: true,
callbackUrl: "/?source=signup",
});
});
});

View File

@@ -1,5 +1,6 @@
"use client";
import { getCallbackUrl } from "@/modules/ee/sso/lib/utils";
import { Button } from "@/modules/ui/components/button";
import { GithubIcon } from "@/modules/ui/components/icons";
import { useTranslate } from "@tolgee/react";
@@ -9,17 +10,20 @@ import { FORMBRICKS_LOGGED_IN_WITH_LS } from "@formbricks/lib/localStorage";
interface GithubButtonProps {
inviteUrl?: string;
lastUsed?: boolean;
source: "signin" | "signup";
}
export const GithubButton = ({ inviteUrl, lastUsed }: GithubButtonProps) => {
export const GithubButton = ({ inviteUrl, lastUsed, source }: GithubButtonProps) => {
const { t } = useTranslate();
const handleLogin = async () => {
if (typeof window !== "undefined") {
localStorage.setItem(FORMBRICKS_LOGGED_IN_WITH_LS, "Github");
}
const callbackUrlWithSource = getCallbackUrl(inviteUrl, source);
await signIn("github", {
redirect: true,
callbackUrl: inviteUrl ? inviteUrl : "/", // redirect after login to /
callbackUrl: callbackUrlWithSource,
});
};

View File

@@ -0,0 +1,76 @@
import { cleanup, fireEvent, render, screen } from "@testing-library/react";
import { signIn } from "next-auth/react";
import { afterEach, describe, expect, it, vi } from "vitest";
import { FORMBRICKS_LOGGED_IN_WITH_LS } from "@formbricks/lib/localStorage";
import { GoogleButton } from "./google-button";
// Mock next-auth/react
vi.mock("next-auth/react", () => ({
signIn: vi.fn(),
}));
// Mock localStorage
const mockLocalStorage = {
setItem: vi.fn(),
};
Object.defineProperty(window, "localStorage", {
value: mockLocalStorage,
writable: true,
});
describe("GoogleButton", () => {
const defaultProps = {
source: "signin" as const,
};
afterEach(() => {
cleanup();
vi.clearAllMocks();
});
it("renders correctly with default props", () => {
render(<GoogleButton {...defaultProps} />);
const button = screen.getByRole("button", { name: "auth.continue_with_google" });
expect(button).toBeInTheDocument();
});
it("renders with last used indicator when lastUsed is true", () => {
render(<GoogleButton {...defaultProps} lastUsed={true} />);
expect(screen.getByText("auth.last_used")).toBeInTheDocument();
});
it("sets localStorage item and calls signIn on click", async () => {
render(<GoogleButton {...defaultProps} />);
const button = screen.getByRole("button", { name: "auth.continue_with_google" });
fireEvent.click(button);
expect(mockLocalStorage.setItem).toHaveBeenCalledWith(FORMBRICKS_LOGGED_IN_WITH_LS, "Google");
expect(signIn).toHaveBeenCalledWith("google", {
redirect: true,
callbackUrl: "/?source=signin",
});
});
it("uses inviteUrl in callbackUrl when provided", async () => {
const inviteUrl = "https://example.com/invite";
render(<GoogleButton {...defaultProps} inviteUrl={inviteUrl} />);
const button = screen.getByRole("button", { name: "auth.continue_with_google" });
fireEvent.click(button);
expect(signIn).toHaveBeenCalledWith("google", {
redirect: true,
callbackUrl: "https://example.com/invite?source=signin",
});
});
it("handles signup source correctly", async () => {
render(<GoogleButton {...defaultProps} source="signup" />);
const button = screen.getByRole("button", { name: "auth.continue_with_google" });
fireEvent.click(button);
expect(signIn).toHaveBeenCalledWith("google", {
redirect: true,
callbackUrl: "/?source=signup",
});
});
});

View File

@@ -1,5 +1,6 @@
"use client";
import { getCallbackUrl } from "@/modules/ee/sso/lib/utils";
import { Button } from "@/modules/ui/components/button";
import { GoogleIcon } from "@/modules/ui/components/icons";
import { useTranslate } from "@tolgee/react";
@@ -9,17 +10,20 @@ import { FORMBRICKS_LOGGED_IN_WITH_LS } from "@formbricks/lib/localStorage";
interface GoogleButtonProps {
inviteUrl?: string;
lastUsed?: boolean;
source: "signin" | "signup";
}
export const GoogleButton = ({ inviteUrl, lastUsed }: GoogleButtonProps) => {
export const GoogleButton = ({ inviteUrl, lastUsed, source }: GoogleButtonProps) => {
const { t } = useTranslate();
const handleLogin = async () => {
if (typeof window !== "undefined") {
localStorage.setItem(FORMBRICKS_LOGGED_IN_WITH_LS, "Google");
}
const callbackUrlWithSource = getCallbackUrl(inviteUrl, source);
await signIn("google", {
redirect: true,
callbackUrl: inviteUrl ? inviteUrl : "/", // redirect after login to /
callbackUrl: callbackUrlWithSource,
});
};

View File

@@ -0,0 +1,91 @@
import { cleanup, fireEvent, render, screen } from "@testing-library/react";
import { signIn } from "next-auth/react";
import { afterEach, describe, expect, it, vi } from "vitest";
import { FORMBRICKS_LOGGED_IN_WITH_LS } from "@formbricks/lib/localStorage";
import { OpenIdButton } from "./open-id-button";
// Mock next-auth/react
vi.mock("next-auth/react", () => ({
signIn: vi.fn(),
}));
// Mock localStorage
const mockLocalStorage = {
setItem: vi.fn(),
};
Object.defineProperty(window, "localStorage", {
value: mockLocalStorage,
writable: true,
});
describe("OpenIdButton", () => {
const defaultProps = {
source: "signin" as const,
};
afterEach(() => {
cleanup();
vi.clearAllMocks();
});
it("renders correctly with default props", () => {
render(<OpenIdButton {...defaultProps} />);
const button = screen.getByRole("button", { name: "auth.continue_with_openid" });
expect(button).toBeInTheDocument();
});
it("renders with custom text when provided", () => {
const customText = "Custom OpenID Text";
render(<OpenIdButton {...defaultProps} text={customText} />);
const button = screen.getByRole("button", { name: customText });
expect(button).toBeInTheDocument();
});
it("renders with last used indicator when lastUsed is true", () => {
render(<OpenIdButton {...defaultProps} lastUsed={true} />);
expect(screen.getByText("auth.last_used")).toBeInTheDocument();
});
it("sets localStorage item and calls signIn on click", async () => {
render(<OpenIdButton {...defaultProps} />);
const button = screen.getByRole("button", { name: "auth.continue_with_openid" });
fireEvent.click(button);
expect(mockLocalStorage.setItem).toHaveBeenCalledWith(FORMBRICKS_LOGGED_IN_WITH_LS, "OpenID");
expect(signIn).toHaveBeenCalledWith("openid", {
redirect: true,
callbackUrl: "/?source=signin",
});
});
it("uses inviteUrl in callbackUrl when provided", async () => {
const inviteUrl = "https://example.com/invite";
render(<OpenIdButton {...defaultProps} inviteUrl={inviteUrl} />);
const button = screen.getByRole("button", { name: "auth.continue_with_openid" });
fireEvent.click(button);
expect(signIn).toHaveBeenCalledWith("openid", {
redirect: true,
callbackUrl: "https://example.com/invite?source=signin",
});
});
it("handles signup source correctly", async () => {
render(<OpenIdButton {...defaultProps} source="signup" />);
const button = screen.getByRole("button", { name: "auth.continue_with_openid" });
fireEvent.click(button);
expect(signIn).toHaveBeenCalledWith("openid", {
redirect: true,
callbackUrl: "/?source=signup",
});
});
it("triggers direct redirect when directRedirect is true", () => {
render(<OpenIdButton {...defaultProps} directRedirect={true} />);
expect(signIn).toHaveBeenCalledWith("openid", {
redirect: true,
callbackUrl: "/?source=signin",
});
});
});

View File

@@ -1,5 +1,6 @@
"use client";
import { getCallbackUrl } from "@/modules/ee/sso/lib/utils";
import { Button } from "@/modules/ui/components/button";
import { useTranslate } from "@tolgee/react";
import { signIn } from "next-auth/react";
@@ -11,19 +12,28 @@ interface OpenIdButtonProps {
lastUsed?: boolean;
directRedirect?: boolean;
text?: string;
source: "signin" | "signup";
}
export const OpenIdButton = ({ inviteUrl, lastUsed, directRedirect = false, text }: OpenIdButtonProps) => {
export const OpenIdButton = ({
inviteUrl,
lastUsed,
directRedirect = false,
text,
source,
}: OpenIdButtonProps) => {
const { t } = useTranslate();
const handleLogin = useCallback(async () => {
if (typeof window !== "undefined") {
localStorage.setItem(FORMBRICKS_LOGGED_IN_WITH_LS, "OpenID");
}
const callbackUrlWithSource = getCallbackUrl(inviteUrl, source);
await signIn("openid", {
redirect: true,
callbackUrl: inviteUrl ? inviteUrl : "/",
callbackUrl: callbackUrlWithSource,
});
}, [inviteUrl]);
}, [inviteUrl, source]);
useEffect(() => {
if (directRedirect) {

View File

@@ -0,0 +1,130 @@
import { doesSamlConnectionExistAction } from "@/modules/ee/sso/actions";
import { cleanup, fireEvent, render, screen } from "@testing-library/react";
import { signIn } from "next-auth/react";
import toast from "react-hot-toast";
import { afterEach, describe, expect, it, vi } from "vitest";
import { FORMBRICKS_LOGGED_IN_WITH_LS } from "@formbricks/lib/localStorage";
import { SamlButton } from "./saml-button";
// Mock next-auth/react
vi.mock("next-auth/react", () => ({
signIn: vi.fn().mockResolvedValue(undefined),
}));
// Mock localStorage
const mockLocalStorage = {
setItem: vi.fn(),
};
Object.defineProperty(window, "localStorage", {
value: mockLocalStorage,
writable: true,
});
// Mock actions
vi.mock("@/modules/ee/sso/actions", () => ({
doesSamlConnectionExistAction: vi.fn(),
}));
// Mock toast
vi.mock("react-hot-toast", () => ({
default: {
error: vi.fn(),
},
}));
describe("SamlButton", () => {
const defaultProps = {
source: "signin" as const,
samlTenant: "test-tenant",
samlProduct: "test-product",
};
afterEach(() => {
cleanup();
vi.clearAllMocks();
});
it("renders correctly with default props", () => {
render(<SamlButton {...defaultProps} />);
const button = screen.getByRole("button", { name: "auth.continue_with_saml" });
expect(button).toBeInTheDocument();
});
it("renders with last used indicator when lastUsed is true", () => {
render(<SamlButton {...defaultProps} lastUsed={true} />);
expect(screen.getByText("auth.last_used")).toBeInTheDocument();
});
it("sets localStorage item and calls signIn on click when SAML connection exists", async () => {
vi.mocked(doesSamlConnectionExistAction).mockResolvedValue({ data: true });
render(<SamlButton {...defaultProps} />);
const button = screen.getByRole("button", { name: "auth.continue_with_saml" });
await fireEvent.click(button);
expect(mockLocalStorage.setItem).toHaveBeenCalledWith(FORMBRICKS_LOGGED_IN_WITH_LS, "Saml");
expect(signIn).toHaveBeenCalledWith(
"saml",
{
redirect: true,
callbackUrl: "/?source=signin",
},
{
tenant: "test-tenant",
product: "test-product",
}
);
});
it("shows error toast when SAML connection does not exist", async () => {
vi.mocked(doesSamlConnectionExistAction).mockResolvedValue({ data: false });
render(<SamlButton {...defaultProps} />);
const button = screen.getByRole("button", { name: "auth.continue_with_saml" });
await fireEvent.click(button);
expect(toast.error).toHaveBeenCalledWith("auth.saml_connection_error");
expect(signIn).not.toHaveBeenCalled();
});
it("uses inviteUrl in callbackUrl when provided", async () => {
vi.mocked(doesSamlConnectionExistAction).mockResolvedValue({ data: true });
const inviteUrl = "https://example.com/invite";
render(<SamlButton {...defaultProps} inviteUrl={inviteUrl} />);
const button = screen.getByRole("button", { name: "auth.continue_with_saml" });
await fireEvent.click(button);
expect(signIn).toHaveBeenCalledWith(
"saml",
{
redirect: true,
callbackUrl: "https://example.com/invite?source=signin",
},
{
tenant: "test-tenant",
product: "test-product",
}
);
});
it("handles signup source correctly", async () => {
vi.mocked(doesSamlConnectionExistAction).mockResolvedValue({ data: true });
render(<SamlButton {...defaultProps} source="signup" />);
const button = screen.getByRole("button", { name: "auth.continue_with_saml" });
await fireEvent.click(button);
expect(signIn).toHaveBeenCalledWith(
"saml",
{
redirect: true,
callbackUrl: "/?source=signup",
},
{
tenant: "test-tenant",
product: "test-product",
}
);
});
});

View File

@@ -1,6 +1,7 @@
"use client";
import { doesSamlConnectionExistAction } from "@/modules/ee/sso/actions";
import { getCallbackUrl } from "@/modules/ee/sso/lib/utils";
import { Button } from "@/modules/ui/components/button";
import { useTranslate } from "@tolgee/react";
import { LockIcon } from "lucide-react";
@@ -14,9 +15,10 @@ interface SamlButtonProps {
lastUsed?: boolean;
samlTenant: string;
samlProduct: string;
source: "signin" | "signup";
}
export const SamlButton = ({ inviteUrl, lastUsed, samlTenant, samlProduct }: SamlButtonProps) => {
export const SamlButton = ({ inviteUrl, lastUsed, samlTenant, samlProduct, source }: SamlButtonProps) => {
const { t } = useTranslate();
const [isLoading, setIsLoading] = useState(false);
@@ -32,11 +34,13 @@ export const SamlButton = ({ inviteUrl, lastUsed, samlTenant, samlProduct }: Sam
return;
}
const callbackUrlWithSource = getCallbackUrl(inviteUrl, source);
signIn(
"saml",
{
redirect: true,
callbackUrl: inviteUrl ? inviteUrl : "/", // redirect after login to /
callbackUrl: callbackUrlWithSource,
},
{
tenant: samlTenant,

View File

@@ -0,0 +1,137 @@
import { cleanup, render, screen } from "@testing-library/react";
import { afterEach, describe, expect, it, vi } from "vitest";
import { SSOOptions } from "./sso-options";
// Mock environment variables
vi.mock("@formbricks/lib/env", () => ({
env: {
IS_FORMBRICKS_CLOUD: "0",
},
}));
// Mock the translation hook
vi.mock("@tolgee/react", () => ({
useTranslate: () => ({
t: (key: string) => key,
}),
}));
// Mock the individual SSO buttons
vi.mock("./google-button", () => ({
GoogleButton: ({ lastUsed, source }: any) => (
<div data-testid="google-button" data-last-used={lastUsed} data-source={source}>
Google Button
</div>
),
}));
vi.mock("./github-button", () => ({
GithubButton: ({ lastUsed, source }: any) => (
<div data-testid="github-button" data-last-used={lastUsed} data-source={source}>
Github Button
</div>
),
}));
vi.mock("./azure-button", () => ({
AzureButton: ({ lastUsed, source }: any) => (
<div data-testid="azure-button" data-last-used={lastUsed} data-source={source}>
Azure Button
</div>
),
}));
vi.mock("./open-id-button", () => ({
OpenIdButton: ({ lastUsed, source, text }: any) => (
<div data-testid="openid-button" data-last-used={lastUsed} data-source={source}>
{text}
</div>
),
}));
vi.mock("./saml-button", () => ({
SamlButton: ({ lastUsed, source, samlTenant, samlProduct }: any) => (
<div
data-testid="saml-button"
data-last-used={lastUsed}
data-source={source}
data-tenant={samlTenant}
data-product={samlProduct}>
Saml Button
</div>
),
}));
describe("SSOOptions Component", () => {
afterEach(() => {
cleanup();
vi.clearAllMocks();
});
const defaultProps = {
googleOAuthEnabled: true,
githubOAuthEnabled: true,
azureOAuthEnabled: true,
oidcOAuthEnabled: true,
oidcDisplayName: "OpenID",
callbackUrl: "http://localhost:3000",
samlSsoEnabled: true,
samlTenant: "test-tenant",
samlProduct: "test-product",
source: "signin" as const,
};
it("renders all SSO options when all are enabled", () => {
render(<SSOOptions {...defaultProps} />);
expect(screen.getByTestId("google-button")).toBeInTheDocument();
expect(screen.getByTestId("github-button")).toBeInTheDocument();
expect(screen.getByTestId("azure-button")).toBeInTheDocument();
expect(screen.getByTestId("openid-button")).toBeInTheDocument();
expect(screen.getByTestId("saml-button")).toBeInTheDocument();
});
it("only renders enabled SSO options", () => {
render(
<SSOOptions
{...defaultProps}
googleOAuthEnabled={false}
githubOAuthEnabled={false}
azureOAuthEnabled={false}
/>
);
expect(screen.queryByTestId("google-button")).not.toBeInTheDocument();
expect(screen.queryByTestId("github-button")).not.toBeInTheDocument();
expect(screen.queryByTestId("azure-button")).not.toBeInTheDocument();
expect(screen.getByTestId("openid-button")).toBeInTheDocument();
expect(screen.getByTestId("saml-button")).toBeInTheDocument();
});
it("passes correct props to OpenID button", () => {
render(<SSOOptions {...defaultProps} />);
const openIdButton = screen.getByTestId("openid-button");
expect(openIdButton).toHaveAttribute("data-source", "signin");
expect(openIdButton).toHaveTextContent("auth.continue_with_oidc");
});
it("passes correct props to SAML button", () => {
render(<SSOOptions {...defaultProps} />);
const samlButton = screen.getByTestId("saml-button");
expect(samlButton).toHaveAttribute("data-source", "signin");
expect(samlButton).toHaveAttribute("data-tenant", "test-tenant");
expect(samlButton).toHaveAttribute("data-product", "test-product");
});
it("passes correct source prop to all buttons", () => {
render(<SSOOptions {...defaultProps} source="signup" />);
expect(screen.getByTestId("google-button")).toHaveAttribute("data-source", "signup");
expect(screen.getByTestId("github-button")).toHaveAttribute("data-source", "signup");
expect(screen.getByTestId("azure-button")).toHaveAttribute("data-source", "signup");
expect(screen.getByTestId("openid-button")).toHaveAttribute("data-source", "signup");
expect(screen.getByTestId("saml-button")).toHaveAttribute("data-source", "signup");
});
});

View File

@@ -19,6 +19,7 @@ interface SSOOptionsProps {
samlSsoEnabled: boolean;
samlTenant: string;
samlProduct: string;
source: "signin" | "signup";
}
export const SSOOptions = ({
@@ -31,6 +32,7 @@ export const SSOOptions = ({
samlSsoEnabled,
samlTenant,
samlProduct,
source,
}: SSOOptionsProps) => {
const { t } = useTranslate();
const [lastLoggedInWith, setLastLoggedInWith] = useState("");
@@ -44,17 +46,20 @@ export const SSOOptions = ({
return (
<div className="space-y-2">
{googleOAuthEnabled && (
<GoogleButton inviteUrl={callbackUrl} lastUsed={lastLoggedInWith === "Google"} />
<GoogleButton inviteUrl={callbackUrl} lastUsed={lastLoggedInWith === "Google"} source={source} />
)}
{githubOAuthEnabled && (
<GithubButton inviteUrl={callbackUrl} lastUsed={lastLoggedInWith === "Github"} />
<GithubButton inviteUrl={callbackUrl} lastUsed={lastLoggedInWith === "Github"} source={source} />
)}
{azureOAuthEnabled && (
<AzureButton inviteUrl={callbackUrl} lastUsed={lastLoggedInWith === "Azure"} source={source} />
)}
{azureOAuthEnabled && <AzureButton inviteUrl={callbackUrl} lastUsed={lastLoggedInWith === "Azure"} />}
{oidcOAuthEnabled && (
<OpenIdButton
inviteUrl={callbackUrl}
lastUsed={lastLoggedInWith === "OpenID"}
text={t("auth.continue_with_oidc", { oidcDisplayName })}
source={source}
/>
)}
{samlSsoEnabled && (
@@ -63,6 +68,7 @@ export const SSOOptions = ({
lastUsed={lastLoggedInWith === "Saml"}
samlTenant={samlTenant}
samlProduct={samlProduct}
source={source}
/>
)}
</div>

View File

@@ -1,19 +1,34 @@
import { createBrevoCustomer } from "@/modules/auth/lib/brevo";
import { getUserByEmail, updateUser } from "@/modules/auth/lib/user";
import { createUser } from "@/modules/auth/lib/user";
import { getIsValidInviteToken } from "@/modules/auth/signup/lib/invite";
import { TOidcNameFields, TSamlNameFields } from "@/modules/auth/types/auth";
import { getIsSamlSsoEnabled, getisSsoEnabled } from "@/modules/ee/license-check/lib/utils";
import {
getIsMultiOrgEnabled,
getIsSamlSsoEnabled,
getisSsoEnabled,
} from "@/modules/ee/license-check/lib/utils";
import type { IdentityProvider } from "@prisma/client";
import type { Account } from "next-auth";
import { prisma } from "@formbricks/database";
import { createAccount } from "@formbricks/lib/account/service";
import { DEFAULT_ORGANIZATION_ID, DEFAULT_ORGANIZATION_ROLE } from "@formbricks/lib/constants";
import { verifyInviteToken } from "@formbricks/lib/jwt";
import { createMembership } from "@formbricks/lib/membership/service";
import { createOrganization, getOrganization } from "@formbricks/lib/organization/service";
import { findMatchingLocale } from "@formbricks/lib/utils/locale";
import { logger } from "@formbricks/logger";
import type { TUser, TUserNotificationSettings } from "@formbricks/types/user";
export const handleSsoCallback = async ({ user, account }: { user: TUser; account: Account }) => {
export const handleSsoCallback = async ({
user,
account,
callbackUrl,
}: {
user: TUser;
account: Account;
callbackUrl: string;
}) => {
const isSsoEnabled = await getisSsoEnabled();
if (!isSsoEnabled) {
return false;
@@ -102,6 +117,46 @@ export const handleSsoCallback = async ({ user, account }: { user: TUser; accoun
}
}
// Get multi-org license status
const isMultiOrgEnabled = await getIsMultiOrgEnabled();
// Reject if no callback URL and no default org in self-hosted environment
if (!callbackUrl && !DEFAULT_ORGANIZATION_ID && !isMultiOrgEnabled) {
return false;
}
// Additional security checks for self-hosted instances without default org
if (!DEFAULT_ORGANIZATION_ID && !isMultiOrgEnabled) {
try {
// Parse and validate the callback URL
const isValidCallbackUrl = new URL(callbackUrl);
// Extract invite token and source from URL parameters
const inviteToken = isValidCallbackUrl.searchParams.get("token") || "";
const source = isValidCallbackUrl.searchParams.get("source") || "";
// Allow sign-in if multi-org is enabled, otherwise check for invite token
if (source === "signin" && !inviteToken) {
return false;
}
// If multi-org is enabled, skip invite token validation
// Verify invite token and check email match
const { email, inviteId } = verifyInviteToken(inviteToken);
if (email !== user.email) {
return false;
}
// Check if invite token is still valid
const isValidInviteToken = await getIsValidInviteToken(inviteId);
if (!isValidInviteToken) {
return false;
}
} catch (err) {
// Log and reject on any validation errors
logger.error(err, "Invalid callbackUrl");
return false;
}
}
const userProfile = await createUser({
name:
userName ||

View File

@@ -1,5 +1,6 @@
import { createBrevoCustomer } from "@/modules/auth/lib/brevo";
import { createUser, getUserByEmail, updateUser } from "@/modules/auth/lib/user";
import type { TSamlNameFields } from "@/modules/auth/types/auth";
import { getIsSamlSsoEnabled, getisSsoEnabled } from "@/modules/ee/license-check/lib/utils";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { prisma } from "@formbricks/database";
@@ -7,6 +8,7 @@ import { createAccount } from "@formbricks/lib/account/service";
import { createMembership } from "@formbricks/lib/membership/service";
import { createOrganization, getOrganization } from "@formbricks/lib/organization/service";
import { findMatchingLocale } from "@formbricks/lib/utils/locale";
import type { TUser } from "@formbricks/types/user";
import { handleSsoCallback } from "../sso-handlers";
import {
mockAccount,
@@ -32,6 +34,7 @@ vi.mock("@/modules/auth/lib/user", () => ({
vi.mock("@/modules/ee/license-check/lib/utils", () => ({
getIsSamlSsoEnabled: vi.fn(),
getisSsoEnabled: vi.fn(),
getIsMultiOrgEnabled: vi.fn().mockResolvedValue(true),
}));
vi.mock("@formbricks/database", () => ({
@@ -63,6 +66,7 @@ vi.mock("@formbricks/lib/utils/locale", () => ({
vi.mock("@formbricks/lib/constants", () => ({
DEFAULT_ORGANIZATION_ID: "org-123",
DEFAULT_ORGANIZATION_ROLE: "member",
ENCRYPTION_KEY: "test-encryption-key-32-chars-long",
}));
describe("handleSsoCallback", () => {
@@ -90,24 +94,31 @@ describe("handleSsoCallback", () => {
it("should return false if SSO is not enabled", async () => {
vi.mocked(getisSsoEnabled).mockResolvedValue(false);
const result = await handleSsoCallback({ user: mockUser, account: mockAccount });
const result = await handleSsoCallback({
user: mockUser,
account: mockAccount,
callbackUrl: "http://localhost:3000",
});
expect(result).toBe(false);
expect(getisSsoEnabled).toHaveBeenCalled();
});
it("should return false if user email is missing", async () => {
const userWithoutEmail = { ...mockUser, email: "" };
const result = await handleSsoCallback({ user: userWithoutEmail, account: mockAccount });
const result = await handleSsoCallback({
user: { ...mockUser, email: "" },
account: mockAccount,
callbackUrl: "http://localhost:3000",
});
expect(result).toBe(false);
});
it("should return false if account type is not oauth", async () => {
const nonOauthAccount = { ...mockAccount, type: "credentials" as const };
const result = await handleSsoCallback({ user: mockUser, account: nonOauthAccount });
const result = await handleSsoCallback({
user: mockUser,
account: { ...mockAccount, type: "credentials" },
callbackUrl: "http://localhost:3000",
});
expect(result).toBe(false);
});
@@ -115,10 +126,13 @@ describe("handleSsoCallback", () => {
it("should return false if provider is SAML and SAML SSO is not enabled", async () => {
vi.mocked(getIsSamlSsoEnabled).mockResolvedValue(false);
const result = await handleSsoCallback({ user: mockUser, account: mockSamlAccount });
const result = await handleSsoCallback({
user: mockUser,
account: mockSamlAccount,
callbackUrl: "http://localhost:3000",
});
expect(result).toBe(false);
expect(getIsSamlSsoEnabled).toHaveBeenCalled();
});
});
@@ -130,7 +144,11 @@ describe("handleSsoCallback", () => {
accounts: [{ provider: mockAccount.provider }],
});
const result = await handleSsoCallback({ user: mockUser, account: mockAccount });
const result = await handleSsoCallback({
user: mockUser,
account: mockAccount,
callbackUrl: "http://localhost:3000",
});
expect(result).toBe(true);
expect(prisma.user.findFirst).toHaveBeenCalledWith({
@@ -160,7 +178,11 @@ describe("handleSsoCallback", () => {
vi.mocked(getUserByEmail).mockResolvedValue(null);
vi.mocked(updateUser).mockResolvedValue({ ...existingUser, email: mockUser.email });
const result = await handleSsoCallback({ user: mockUser, account: mockAccount });
const result = await handleSsoCallback({
user: mockUser,
account: mockAccount,
callbackUrl: "http://localhost:3000",
});
expect(result).toBe(true);
expect(updateUser).toHaveBeenCalledWith(existingUser.id, { email: mockUser.email });
@@ -180,9 +202,16 @@ describe("handleSsoCallback", () => {
email: mockUser.email,
emailVerified: mockUser.emailVerified,
locale: mockUser.locale,
isActive: true,
});
await expect(handleSsoCallback({ user: mockUser, account: mockAccount })).rejects.toThrow(
await expect(
handleSsoCallback({
user: mockUser,
account: mockAccount,
callbackUrl: "http://localhost:3000",
})
).rejects.toThrow(
"Looks like you updated your email somewhere else. A user with this new email exists already."
);
});
@@ -194,9 +223,14 @@ describe("handleSsoCallback", () => {
email: mockUser.email,
emailVerified: mockUser.emailVerified,
locale: mockUser.locale,
isActive: true,
});
const result = await handleSsoCallback({ user: mockUser, account: mockAccount });
const result = await handleSsoCallback({
user: mockUser,
account: mockAccount,
callbackUrl: "http://localhost:3000",
});
expect(result).toBe(true);
});
@@ -208,7 +242,11 @@ describe("handleSsoCallback", () => {
vi.mocked(getUserByEmail).mockResolvedValue(null);
vi.mocked(createUser).mockResolvedValue(mockCreatedUser());
const result = await handleSsoCallback({ user: mockUser, account: mockAccount });
const result = await handleSsoCallback({
user: mockUser,
account: mockAccount,
callbackUrl: "http://localhost:3000",
});
expect(result).toBe(true);
expect(createUser).toHaveBeenCalledWith({
@@ -228,7 +266,11 @@ describe("handleSsoCallback", () => {
vi.mocked(createUser).mockResolvedValue(mockCreatedUser());
vi.mocked(getOrganization).mockResolvedValue(null);
const result = await handleSsoCallback({ user: mockUser, account: mockAccount });
const result = await handleSsoCallback({
user: mockUser,
account: mockAccount,
callbackUrl: "http://localhost:3000",
});
expect(result).toBe(true);
expect(createOrganization).toHaveBeenCalledWith({
@@ -255,7 +297,11 @@ describe("handleSsoCallback", () => {
vi.mocked(getUserByEmail).mockResolvedValue(null);
vi.mocked(createUser).mockResolvedValue(mockCreatedUser());
const result = await handleSsoCallback({ user: mockUser, account: mockAccount });
const result = await handleSsoCallback({
user: mockUser,
account: mockAccount,
callbackUrl: "http://localhost:3000",
});
expect(result).toBe(true);
expect(createOrganization).not.toHaveBeenCalled();
@@ -276,14 +322,21 @@ describe("handleSsoCallback", () => {
vi.mocked(createUser).mockResolvedValue(mockCreatedUser("Direct Name"));
const result = await handleSsoCallback({ user: openIdUser, account: mockOpenIdAccount });
const result = await handleSsoCallback({
user: openIdUser,
account: mockOpenIdAccount,
callbackUrl: "http://localhost:3000",
});
expect(result).toBe(true);
expect(createUser).toHaveBeenCalledWith(
expect.objectContaining({
name: "Direct Name",
email: openIdUser.email,
emailVerified: expect.any(Date),
identityProvider: "openid",
identityProviderAccountId: mockOpenIdAccount.providerAccountId,
locale: "en-US",
})
);
});
@@ -297,14 +350,21 @@ describe("handleSsoCallback", () => {
vi.mocked(createUser).mockResolvedValue(mockCreatedUser("John Doe"));
const result = await handleSsoCallback({ user: openIdUser, account: mockOpenIdAccount });
const result = await handleSsoCallback({
user: openIdUser,
account: mockOpenIdAccount,
callbackUrl: "http://localhost:3000",
});
expect(result).toBe(true);
expect(createUser).toHaveBeenCalledWith(
expect.objectContaining({
name: "John Doe",
email: openIdUser.email,
emailVerified: expect.any(Date),
identityProvider: "openid",
identityProviderAccountId: mockOpenIdAccount.providerAccountId,
locale: "en-US",
})
);
});
@@ -319,14 +379,21 @@ describe("handleSsoCallback", () => {
vi.mocked(createUser).mockResolvedValue(mockCreatedUser("preferred.user"));
const result = await handleSsoCallback({ user: openIdUser, account: mockOpenIdAccount });
const result = await handleSsoCallback({
user: openIdUser,
account: mockOpenIdAccount,
callbackUrl: "http://localhost:3000",
});
expect(result).toBe(true);
expect(createUser).toHaveBeenCalledWith(
expect.objectContaining({
name: "preferred.user",
email: openIdUser.email,
emailVerified: expect.any(Date),
identityProvider: "openid",
identityProviderAccountId: mockOpenIdAccount.providerAccountId,
locale: "en-US",
})
);
});
@@ -340,18 +407,152 @@ describe("handleSsoCallback", () => {
email: "test.user@example.com",
});
vi.mocked(createUser).mockResolvedValue(mockCreatedUser("test.user"));
vi.mocked(createUser).mockResolvedValue(mockCreatedUser("test user"));
const result = await handleSsoCallback({ user: openIdUser, account: mockOpenIdAccount });
const result = await handleSsoCallback({
user: openIdUser,
account: mockOpenIdAccount,
callbackUrl: "http://localhost:3000",
});
expect(result).toBe(true);
expect(createUser).toHaveBeenCalledWith(
expect.objectContaining({
name: "test user",
email: openIdUser.email,
emailVerified: expect.any(Date),
identityProvider: "openid",
identityProviderAccountId: mockOpenIdAccount.providerAccountId,
locale: "en-US",
})
);
});
});
describe("SAML name handling", () => {
it("should use samlUser.name when available", async () => {
const samlUser = {
...mockUser,
name: "Direct Name",
firstName: "John",
lastName: "Doe",
} as TUser & TSamlNameFields;
vi.mocked(createUser).mockResolvedValue(mockCreatedUser("Direct Name"));
const result = await handleSsoCallback({
user: samlUser,
account: mockSamlAccount,
callbackUrl: "http://localhost:3000",
});
expect(result).toBe(true);
expect(createUser).toHaveBeenCalledWith(
expect.objectContaining({
name: "Direct Name",
email: samlUser.email,
emailVerified: expect.any(Date),
identityProvider: "saml",
identityProviderAccountId: mockSamlAccount.providerAccountId,
locale: "en-US",
})
);
});
it("should use firstName + lastName when name is not available", async () => {
const samlUser = {
...mockUser,
name: "",
firstName: "John",
lastName: "Doe",
} as TUser & TSamlNameFields;
vi.mocked(createUser).mockResolvedValue(mockCreatedUser("John Doe"));
const result = await handleSsoCallback({
user: samlUser,
account: mockSamlAccount,
callbackUrl: "http://localhost:3000",
});
expect(result).toBe(true);
expect(createUser).toHaveBeenCalledWith(
expect.objectContaining({
name: "John Doe",
email: samlUser.email,
emailVerified: expect.any(Date),
identityProvider: "saml",
identityProviderAccountId: mockSamlAccount.providerAccountId,
locale: "en-US",
})
);
});
});
describe("Organization handling", () => {
it("should handle invalid DEFAULT_ORGANIZATION_ID gracefully", async () => {
vi.mocked(prisma.user.findFirst).mockResolvedValue(null);
vi.mocked(getUserByEmail).mockResolvedValue(null);
vi.mocked(createUser).mockResolvedValue(mockCreatedUser());
vi.mocked(getOrganization).mockResolvedValue(null);
vi.mocked(createOrganization).mockRejectedValue(new Error("Invalid organization ID"));
await expect(
handleSsoCallback({
user: mockUser,
account: mockAccount,
callbackUrl: "http://localhost:3000",
})
).rejects.toThrow("Invalid organization ID");
expect(createOrganization).toHaveBeenCalled();
expect(createMembership).not.toHaveBeenCalled();
});
it("should handle membership creation failure gracefully", async () => {
vi.mocked(prisma.user.findFirst).mockResolvedValue(null);
vi.mocked(getUserByEmail).mockResolvedValue(null);
vi.mocked(createUser).mockResolvedValue(mockCreatedUser());
vi.mocked(createMembership).mockRejectedValue(new Error("Failed to create membership"));
await expect(
handleSsoCallback({
user: mockUser,
account: mockAccount,
callbackUrl: "http://localhost:3000",
})
).rejects.toThrow("Failed to create membership");
expect(createMembership).toHaveBeenCalled();
});
});
describe("Error handling", () => {
it("should handle prisma errors gracefully", async () => {
vi.mocked(prisma.user.findFirst).mockRejectedValue(new Error("Database error"));
await expect(
handleSsoCallback({
user: mockUser,
account: mockAccount,
callbackUrl: "http://localhost:3000",
})
).rejects.toThrow("Database error");
});
it("should handle locale finding errors gracefully", async () => {
vi.mocked(findMatchingLocale).mockRejectedValue(new Error("Locale error"));
vi.mocked(prisma.user.findFirst).mockResolvedValue(null);
vi.mocked(getUserByEmail).mockResolvedValue(null);
vi.mocked(createUser).mockResolvedValue(mockCreatedUser());
await expect(
handleSsoCallback({
user: mockUser,
account: mockAccount,
callbackUrl: "http://localhost:3000",
})
).rejects.toThrow("Locale error");
});
});
});

View File

@@ -0,0 +1,29 @@
import { describe, expect, it } from "vitest";
import { getCallbackUrl } from "../utils";
describe("getCallbackUrl", () => {
it("should return base URL with source when no inviteUrl is provided", () => {
const result = getCallbackUrl(undefined, "test-source");
expect(result).toBe("/?source=test-source");
});
it("should append source parameter to inviteUrl with existing query parameters", () => {
const result = getCallbackUrl("https://example.com/invite?param=value", "test-source");
expect(result).toBe("https://example.com/invite?param=value&source=test-source");
});
it("should append source parameter to inviteUrl without existing query parameters", () => {
const result = getCallbackUrl("https://example.com/invite", "test-source");
expect(result).toBe("https://example.com/invite?source=test-source");
});
it("should handle empty source parameter", () => {
const result = getCallbackUrl("https://example.com/invite", "");
expect(result).toBe("https://example.com/invite?source=");
});
it("should handle undefined source parameter", () => {
const result = getCallbackUrl("https://example.com/invite", undefined);
expect(result).toBe("https://example.com/invite?source=undefined");
});
});

View File

@@ -0,0 +1,5 @@
export const getCallbackUrl = (inviteUrl?: string, source?: string) => {
return inviteUrl
? `${inviteUrl}${inviteUrl.includes("?") ? "&" : "?"}source=${source}`
: `/?source=${source}`;
};

View File

@@ -8,7 +8,7 @@ import { useSingleUseId } from "./useSingleUseId";
// Mock external functions
vi.mock("@/modules/survey/list/actions", () => ({
generateSingleUseIdAction: vi.fn(),
generateSingleUseIdAction: vi.fn().mockResolvedValue({ data: "initialId" }),
}));
vi.mock("@/lib/utils/helper", () => ({
@@ -88,8 +88,10 @@ describe("useSingleUseId", () => {
vi.mocked(generateSingleUseIdAction).mockResolvedValueOnce({ data: "initialId" });
const { result } = renderHook(() => useSingleUseId(mockSurvey));
// Wait for initial
await new Promise((r) => setTimeout(r, 0));
// Wait for initial value to be set
await act(async () => {
await new Promise((r) => setTimeout(r, 0));
});
expect(result.current.singleUseId).toBe("initialId");
vi.mocked(generateSingleUseIdAction).mockResolvedValueOnce({ data: "refreshedId" });

View File

@@ -19,7 +19,8 @@ export default defineConfig({
"modules/api/v2/**/*.ts",
"modules/api/v2/**/*.tsx",
"modules/auth/lib/**/*.ts",
"modules/signup/lib/**/*.ts",
"modules/auth/signup/lib/**/*.ts",
"modules/auth/signup/**/*.tsx",
"modules/ee/whitelabel/email-customization/components/*.tsx",
"modules/ee/role-management/components/*.tsx",
"modules/organization/settings/teams/components/edit-memberships/organization-actions.tsx",
@@ -52,6 +53,7 @@ export default defineConfig({
"modules/survey/list/components/survey-card.tsx",
"modules/survey/list/components/survey-dropdown-menu.tsx",
"modules/ee/contacts/api/v2/management/contacts/bulk/lib/contact.ts",
"modules/ee/sso/components/**/*.tsx",
],
exclude: [
"**/.next/**",

View File

@@ -54,6 +54,19 @@ export const createMembership = async (
validateInputs([organizationId, ZString], [userId, ZString], [data, ZMembership.partial()]);
try {
const existingMembership = await prisma.membership.findUnique({
where: {
userId_organizationId: {
userId,
organizationId,
},
},
});
if (existingMembership) {
return existingMembership;
}
const membership = await prisma.membership.create({
data: {
userId,