diff --git a/apps/web/modules/api/v2/management/lib/utils.ts b/apps/web/modules/api/v2/management/lib/utils.ts index 33d5eb5fe8..bc10c929d7 100644 --- a/apps/web/modules/api/v2/management/lib/utils.ts +++ b/apps/web/modules/api/v2/management/lib/utils.ts @@ -9,7 +9,11 @@ export function pickCommonFilter(params: T) { return { limit, skip, sortBy, order, startDate, endDate }; } -type HasFindMany = Prisma.WebhookFindManyArgs | Prisma.ResponseFindManyArgs | Prisma.TeamFindManyArgs | Prisma.ProjectTeamFindManyArgs; +type HasFindMany = + | Prisma.WebhookFindManyArgs + | Prisma.ResponseFindManyArgs + | Prisma.TeamFindManyArgs + | Prisma.ProjectTeamFindManyArgs; export function buildCommonFilterQuery(query: T, params: TGetFilter): T { const { limit, skip, sortBy, order, startDate, endDate } = params || {}; diff --git a/apps/web/modules/auth/lib/authOptions.ts b/apps/web/modules/auth/lib/authOptions.ts index db9d31e98a..8711e00c8e 100644 --- a/apps/web/modules/auth/lib/authOptions.ts +++ b/apps/web/modules/auth/lib/authOptions.ts @@ -4,6 +4,7 @@ import { getSSOProviders } from "@/modules/ee/sso/lib/providers"; import { handleSsoCallback } from "@/modules/ee/sso/lib/sso-handlers"; import type { Account, NextAuthOptions } from "next-auth"; import CredentialsProvider from "next-auth/providers/credentials"; +import { cookies } from "next/headers"; import { prisma } from "@formbricks/database"; import { EMAIL_VERIFICATION_DISABLED, @@ -208,6 +209,10 @@ export const authOptions: NextAuthOptions = { return session; }, async signIn({ user, account }: { user: TUser; account: Account }) { + const cookieStore = await cookies(); + + const callbackUrl = cookieStore.get("next-auth.callback-url")?.value || ""; + if (account?.provider === "credentials" || account?.provider === "token") { // check if user's email is verified or not if (!user.emailVerified && !EMAIL_VERIFICATION_DISABLED) { @@ -217,7 +222,7 @@ export const authOptions: NextAuthOptions = { return true; } if (ENTERPRISE_LICENSE_KEY) { - const result = await handleSsoCallback({ user, account }); + const result = await handleSsoCallback({ user, account, callbackUrl }); if (result) { await updateUserLastLoginAt(user.email); } diff --git a/apps/web/modules/auth/login/components/login-form.tsx b/apps/web/modules/auth/login/components/login-form.tsx index d0eea0c62d..51d5db84ce 100644 --- a/apps/web/modules/auth/login/components/login-form.tsx +++ b/apps/web/modules/auth/login/components/login-form.tsx @@ -256,6 +256,7 @@ export const LoginForm = ({ samlTenant={samlTenant} samlProduct={samlProduct} callbackUrl={callbackUrl} + source="signin" /> )} diff --git a/apps/web/modules/auth/signup/components/signup-form.tsx b/apps/web/modules/auth/signup/components/signup-form.tsx index 76fde04992..b2672dd5ea 100644 --- a/apps/web/modules/auth/signup/components/signup-form.tsx +++ b/apps/web/modules/auth/signup/components/signup-form.tsx @@ -280,6 +280,7 @@ export const SignupForm = ({ samlTenant={samlTenant} samlProduct={samlProduct} callbackUrl={callbackUrl} + source="signup" /> )} diff --git a/apps/web/modules/auth/signup/lib/invite.test.ts b/apps/web/modules/auth/signup/lib/invite.test.ts new file mode 100644 index 0000000000..e2628d8aed --- /dev/null +++ b/apps/web/modules/auth/signup/lib/invite.test.ts @@ -0,0 +1,246 @@ +import { inviteCache } from "@/lib/cache/invite"; +import { Prisma } from "@prisma/client"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { prisma } from "@formbricks/database"; +import { PrismaErrorType } from "@formbricks/database/types/error"; +import { logger } from "@formbricks/logger"; +import { DatabaseError } from "@formbricks/types/errors"; +import { deleteInvite, getInvite, getIsValidInviteToken } from "./invite"; + +// Mock data +const mockInviteId = "test-invite-id"; +const mockOrganizationId = "test-org-id"; +const mockCreatorId = "test-creator-id"; +const mockInvite = { + id: mockInviteId, + email: "test@test.com", + name: "Test Name", + organizationId: mockOrganizationId, + creatorId: mockCreatorId, + acceptorId: null, + createdAt: new Date(), + expiresAt: new Date(Date.now() + 24 * 60 * 60 * 1000), // 24 hours from now + deprecatedRole: null, + role: "member" as const, + teamIds: ["team-1"], + creator: { + name: "Test Creator", + email: "creator@test.com", + locale: "en", + }, +}; + +// Mock prisma methods +vi.mock("@formbricks/database", () => ({ + prisma: { + invite: { + delete: vi.fn(), + findUnique: vi.fn(), + }, + }, +})); + +// Mock cache +vi.mock("@/lib/cache/invite", () => ({ + inviteCache: { + revalidate: vi.fn(), + tag: { + byId: (id: string) => `invite-${id}`, + }, + }, +})); + +// Mock logger +vi.mock("@formbricks/logger", () => ({ + logger: { + error: vi.fn(), + }, +})); + +describe("Invite Management", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + describe("deleteInvite", () => { + it("deletes an invite successfully and invalidates cache", async () => { + vi.mocked(prisma.invite.delete).mockResolvedValue(mockInvite); + + const result = await deleteInvite(mockInviteId); + + expect(result).toBe(true); + expect(prisma.invite.delete).toHaveBeenCalledWith({ + where: { id: mockInviteId }, + select: { id: true, organizationId: true }, + }); + expect(inviteCache.revalidate).toHaveBeenCalledWith({ + id: mockInviteId, + organizationId: mockOrganizationId, + }); + }); + + it("throws DatabaseError when invite doesn't exist", async () => { + const errToThrow = new Prisma.PrismaClientKnownRequestError("Record not found", { + code: PrismaErrorType.RecordDoesNotExist, + clientVersion: "0.0.1", + }); + vi.mocked(prisma.invite.delete).mockRejectedValue(errToThrow); + + await expect(deleteInvite(mockInviteId)).rejects.toThrow(DatabaseError); + }); + + it("throws DatabaseError for other Prisma errors", async () => { + const errToThrow = new Prisma.PrismaClientKnownRequestError("Database error", { + code: "P2002", + clientVersion: "0.0.1", + }); + vi.mocked(prisma.invite.delete).mockRejectedValue(errToThrow); + + await expect(deleteInvite(mockInviteId)).rejects.toThrow(DatabaseError); + }); + + it("throws DatabaseError for generic errors", async () => { + vi.mocked(prisma.invite.delete).mockRejectedValue(new Error("Generic error")); + + await expect(deleteInvite(mockInviteId)).rejects.toThrow(DatabaseError); + }); + }); + + describe("getInvite", () => { + it("retrieves an invite with creator details successfully", async () => { + vi.mocked(prisma.invite.findUnique).mockResolvedValue(mockInvite); + + const result = await getInvite(mockInviteId); + + expect(result).toEqual(mockInvite); + expect(prisma.invite.findUnique).toHaveBeenCalledWith({ + where: { id: mockInviteId }, + select: { + id: true, + organizationId: true, + role: true, + teamIds: true, + creator: { + select: { + name: true, + email: true, + locale: true, + }, + }, + }, + }); + }); + + it("returns null when invite doesn't exist", async () => { + vi.mocked(prisma.invite.findUnique).mockResolvedValue(null); + + const result = await getInvite(mockInviteId); + + expect(result).toBeNull(); + }); + + it("throws DatabaseError on prisma error", async () => { + const errToThrow = new Prisma.PrismaClientKnownRequestError("Database error", { + code: "P2002", + clientVersion: "0.0.1", + }); + vi.mocked(prisma.invite.findUnique).mockRejectedValue(errToThrow); + + await expect(getInvite(mockInviteId)).rejects.toThrow(DatabaseError); + }); + + it("throws DatabaseError for generic errors", async () => { + vi.mocked(prisma.invite.findUnique).mockRejectedValue(new Error("Generic error")); + + await expect(getInvite(mockInviteId)).rejects.toThrow(DatabaseError); + }); + }); + + describe("getIsValidInviteToken", () => { + it("returns true for valid invite", async () => { + vi.mocked(prisma.invite.findUnique).mockResolvedValue(mockInvite); + + const result = await getIsValidInviteToken(mockInviteId); + + expect(result).toBe(true); + expect(prisma.invite.findUnique).toHaveBeenCalledWith({ + where: { id: mockInviteId }, + }); + }); + + it("returns false when invite doesn't exist", async () => { + vi.mocked(prisma.invite.findUnique).mockResolvedValue(null); + + const result = await getIsValidInviteToken(mockInviteId); + + expect(result).toBe(false); + }); + + it("returns false for expired invite", async () => { + const expiredInvite = { + ...mockInvite, + expiresAt: new Date(Date.now() - 24 * 60 * 60 * 1000), // 24 hours ago + }; + vi.mocked(prisma.invite.findUnique).mockResolvedValue(expiredInvite); + + const result = await getIsValidInviteToken(mockInviteId); + + expect(result).toBe(false); + expect(logger.error).toHaveBeenCalledWith( + { + inviteId: mockInviteId, + expiresAt: expiredInvite.expiresAt, + }, + "SSO: Invite token expired" + ); + }); + + it("returns false and logs error when database error occurs", async () => { + const error = new Error("Database error"); + vi.mocked(prisma.invite.findUnique).mockRejectedValue(error); + + const result = await getIsValidInviteToken(mockInviteId); + + expect(result).toBe(false); + expect(logger.error).toHaveBeenCalledWith(error, "Error getting invite"); + }); + + it("returns false for invite with null expiresAt", async () => { + const invalidInvite = { + ...mockInvite, + expiresAt: null, + }; + vi.mocked(prisma.invite.findUnique).mockResolvedValue(invalidInvite); + + const result = await getIsValidInviteToken(mockInviteId); + + expect(result).toBe(false); + expect(logger.error).toHaveBeenCalledWith( + { + inviteId: mockInviteId, + expiresAt: null, + }, + "SSO: Invite token expired" + ); + }); + + it("returns false for invite with invalid expiresAt", async () => { + const invalidInvite = { + ...mockInvite, + expiresAt: new Date("invalid-date"), + }; + vi.mocked(prisma.invite.findUnique).mockResolvedValue(invalidInvite); + + const result = await getIsValidInviteToken(mockInviteId); + + expect(result).toBe(false); + expect(logger.error).toHaveBeenCalledWith( + { + inviteId: mockInviteId, + expiresAt: invalidInvite.expiresAt, + }, + "SSO: Invite token expired" + ); + }); + }); +}); diff --git a/apps/web/modules/auth/signup/lib/invite.ts b/apps/web/modules/auth/signup/lib/invite.ts index 6a2cd7be03..fd879abbef 100644 --- a/apps/web/modules/auth/signup/lib/invite.ts +++ b/apps/web/modules/auth/signup/lib/invite.ts @@ -4,6 +4,7 @@ import { Prisma } from "@prisma/client"; import { cache as reactCache } from "react"; import { prisma } from "@formbricks/database"; import { cache } from "@formbricks/lib/cache"; +import { logger } from "@formbricks/logger"; import { DatabaseError, ResourceNotFoundError } from "@formbricks/types/errors"; export const deleteInvite = async (inviteId: string): Promise => { @@ -32,8 +33,7 @@ export const deleteInvite = async (inviteId: string): Promise => { if (error instanceof Prisma.PrismaClientKnownRequestError) { throw new DatabaseError(error.message); } - - throw error; + throw new DatabaseError(error instanceof Error ? error.message : "Unknown error occurred"); } }; @@ -66,8 +66,7 @@ export const getInvite = reactCache( if (error instanceof Prisma.PrismaClientKnownRequestError) { throw new DatabaseError(error.message); } - - throw error; + throw new DatabaseError(error instanceof Error ? error.message : "Unknown error occurred"); } }, [`signup-getInvite-${inviteId}`], @@ -76,3 +75,47 @@ export const getInvite = reactCache( } )() ); + +export const getIsValidInviteToken = reactCache( + async (inviteId: string): Promise => + cache( + async () => { + try { + const invite = await prisma.invite.findUnique({ + where: { id: inviteId }, + }); + if (!invite) { + return false; + } + if (!invite.expiresAt || isNaN(invite.expiresAt.getTime())) { + logger.error( + { + inviteId, + expiresAt: invite.expiresAt, + }, + "SSO: Invite token expired" + ); + return false; + } + if (invite.expiresAt < new Date()) { + logger.error( + { + inviteId, + expiresAt: invite.expiresAt, + }, + "SSO: Invite token expired" + ); + return false; + } + return true; + } catch (err) { + logger.error(err, "Error getting invite"); + return false; + } + }, + [`getIsValidInviteToken-${inviteId}`], + { + tags: [inviteCache.tag.byId(inviteId)], + } + )() +); diff --git a/apps/web/modules/auth/signup/page.test.tsx b/apps/web/modules/auth/signup/page.test.tsx new file mode 100644 index 0000000000..afe7955e43 --- /dev/null +++ b/apps/web/modules/auth/signup/page.test.tsx @@ -0,0 +1,154 @@ +import { getIsValidInviteToken } from "@/modules/auth/signup/lib/invite"; +import { + getIsMultiOrgEnabled, + getIsSamlSsoEnabled, + getisSsoEnabled, +} from "@/modules/ee/license-check/lib/utils"; +import { cleanup, render, screen } from "@testing-library/react"; +import { notFound } from "next/navigation"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import { verifyInviteToken } from "@formbricks/lib/jwt"; +import { findMatchingLocale } from "@formbricks/lib/utils/locale"; +import { SignupPage } from "./page"; + +// Mock the necessary dependencies +vi.mock("@/modules/auth/components/testimonial", () => ({ + Testimonial: () =>
Testimonial
, +})); + +vi.mock("@/modules/auth/components/form-wrapper", () => ({ + FormWrapper: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), +})); + +vi.mock("@/modules/auth/signup/components/signup-form", () => ({ + SignupForm: () =>
SignupForm
, +})); + +vi.mock("@/modules/ee/license-check/lib/utils", () => ({ + getIsMultiOrgEnabled: vi.fn(), + getIsSamlSsoEnabled: vi.fn(), + getisSsoEnabled: vi.fn(), +})); + +vi.mock("@/modules/auth/signup/lib/invite", () => ({ + getIsValidInviteToken: vi.fn(), +})); + +vi.mock("@formbricks/lib/jwt", () => ({ + verifyInviteToken: vi.fn(), +})); + +vi.mock("@formbricks/lib/utils/locale", () => ({ + findMatchingLocale: vi.fn(), +})); + +vi.mock("next/navigation", () => ({ + notFound: vi.fn(), +})); + +// Mock environment variables and constants +vi.mock("@formbricks/lib/constants", () => ({ + SIGNUP_ENABLED: true, + EMAIL_AUTH_ENABLED: true, + EMAIL_VERIFICATION_DISABLED: false, + GOOGLE_OAUTH_ENABLED: true, + GITHUB_OAUTH_ENABLED: true, + AZURE_OAUTH_ENABLED: true, + OIDC_OAUTH_ENABLED: true, + OIDC_DISPLAY_NAME: "OpenID", + SAML_OAUTH_ENABLED: true, + SAML_TENANT: "test-tenant", + SAML_PRODUCT: "test-product", + IS_TURNSTILE_CONFIGURED: true, + WEBAPP_URL: "http://localhost:3000", + TERMS_URL: "http://localhost:3000/terms", + PRIVACY_URL: "http://localhost:3000/privacy", + DEFAULT_ORGANIZATION_ID: "test-org-id", + DEFAULT_ORGANIZATION_ROLE: "admin", +})); + +describe("SignupPage", () => { + const mockSearchParams = { + inviteToken: "test-token", + email: "test@example.com", + }; + + beforeEach(() => { + vi.clearAllMocks(); + }); + + afterEach(() => { + cleanup(); + }); + + it("renders the signup page with all components when signup is enabled", async () => { + // Mock the license check functions to return true + vi.mocked(getIsMultiOrgEnabled).mockResolvedValue(true); + vi.mocked(getisSsoEnabled).mockResolvedValue(true); + vi.mocked(getIsSamlSsoEnabled).mockResolvedValue(true); + vi.mocked(findMatchingLocale).mockResolvedValue("en"); + vi.mocked(verifyInviteToken).mockReturnValue({ inviteId: "test-invite-id" }); + vi.mocked(getIsValidInviteToken).mockResolvedValue(true); + + const result = await SignupPage({ searchParams: mockSearchParams }); + render(result); + + // Verify that all components are rendered + expect(screen.getByTestId("testimonial")).toBeInTheDocument(); + expect(screen.getByTestId("form-wrapper")).toBeInTheDocument(); + expect(screen.getByTestId("signup-form")).toBeInTheDocument(); + }); + + it("calls notFound when signup is disabled and no valid invite token is provided", async () => { + // Mock the license check functions to return false + vi.mocked(getIsMultiOrgEnabled).mockResolvedValue(false); + vi.mocked(verifyInviteToken).mockImplementation(() => { + throw new Error("Invalid token"); + }); + + await SignupPage({ searchParams: {} }); + + expect(notFound).toHaveBeenCalled(); + }); + + it("calls notFound when invite token is invalid", async () => { + // Mock the license check functions to return false + vi.mocked(getIsMultiOrgEnabled).mockResolvedValue(false); + vi.mocked(verifyInviteToken).mockImplementation(() => { + throw new Error("Invalid token"); + }); + + await SignupPage({ searchParams: { inviteToken: "invalid-token" } }); + + expect(notFound).toHaveBeenCalled(); + }); + + it("calls notFound when invite token is valid but invite is not found", async () => { + // Mock the license check functions to return false + vi.mocked(getIsMultiOrgEnabled).mockResolvedValue(false); + vi.mocked(verifyInviteToken).mockReturnValue({ inviteId: "test-invite-id" }); + vi.mocked(getIsValidInviteToken).mockResolvedValue(false); + + await SignupPage({ searchParams: { inviteToken: "test-token" } }); + + expect(notFound).toHaveBeenCalled(); + }); + + it("renders the page with email from search params", async () => { + // Mock the license check functions to return true + vi.mocked(getIsMultiOrgEnabled).mockResolvedValue(true); + vi.mocked(getisSsoEnabled).mockResolvedValue(true); + vi.mocked(getIsSamlSsoEnabled).mockResolvedValue(true); + vi.mocked(findMatchingLocale).mockResolvedValue("en"); + vi.mocked(verifyInviteToken).mockReturnValue({ inviteId: "test-invite-id" }); + vi.mocked(getIsValidInviteToken).mockResolvedValue(true); + + const result = await SignupPage({ searchParams: { email: "test@example.com" } }); + render(result); + + // Verify that the form is rendered with the email from search params + expect(screen.getByTestId("signup-form")).toBeInTheDocument(); + }); +}); diff --git a/apps/web/modules/auth/signup/page.tsx b/apps/web/modules/auth/signup/page.tsx index d4f164fa5a..17e810ba2e 100644 --- a/apps/web/modules/auth/signup/page.tsx +++ b/apps/web/modules/auth/signup/page.tsx @@ -1,5 +1,6 @@ import { FormWrapper } from "@/modules/auth/components/form-wrapper"; import { Testimonial } from "@/modules/auth/components/testimonial"; +import { getIsValidInviteToken } from "@/modules/auth/signup/lib/invite"; import { getIsMultiOrgEnabled, getIsSamlSsoEnabled, @@ -25,6 +26,7 @@ import { TERMS_URL, WEBAPP_URL, } from "@formbricks/lib/constants"; +import { verifyInviteToken } from "@formbricks/lib/jwt"; import { findMatchingLocale } from "@formbricks/lib/utils/locale"; import { SignupForm } from "./components/signup-form"; @@ -38,11 +40,20 @@ export const SignupPage = async ({ searchParams: searchParamsProps }) => { ]); const samlSsoEnabled = isSamlSsoEnabled && SAML_OAUTH_ENABLED; - const locale = await findMatchingLocale(); - if (!inviteToken && (!SIGNUP_ENABLED || !isMultOrgEnabled)) { - notFound(); + if (!SIGNUP_ENABLED || !isMultOrgEnabled) { + if (!inviteToken) notFound(); + + try { + const { inviteId } = verifyInviteToken(inviteToken); + const isValidInviteToken = await getIsValidInviteToken(inviteId); + + if (!isValidInviteToken) notFound(); + } catch { + notFound(); + } } + const emailFromSearchParams = searchParams["email"]; return ( diff --git a/apps/web/modules/ee/sso/components/azure-button.test.tsx b/apps/web/modules/ee/sso/components/azure-button.test.tsx new file mode 100644 index 0000000000..bef70859f7 --- /dev/null +++ b/apps/web/modules/ee/sso/components/azure-button.test.tsx @@ -0,0 +1,84 @@ +import { cleanup, fireEvent, render, screen } from "@testing-library/react"; +import { signIn } from "next-auth/react"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { FORMBRICKS_LOGGED_IN_WITH_LS } from "@formbricks/lib/localStorage"; +import { AzureButton } from "./azure-button"; + +// Mock next-auth/react +vi.mock("next-auth/react", () => ({ + signIn: vi.fn(), +})); + +// Mock localStorage +const mockLocalStorage = { + setItem: vi.fn(), +}; +Object.defineProperty(window, "localStorage", { + value: mockLocalStorage, + writable: true, +}); + +describe("AzureButton", () => { + const defaultProps = { + source: "signin" as const, + }; + + afterEach(() => { + cleanup(); + vi.clearAllMocks(); + }); + + it("renders correctly with default props", () => { + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_azure" }); + expect(button).toBeInTheDocument(); + }); + + it("renders with last used indicator when lastUsed is true", () => { + render(); + expect(screen.getByText("auth.last_used")).toBeInTheDocument(); + }); + + it("sets localStorage item and calls signIn on click", async () => { + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_azure" }); + fireEvent.click(button); + + expect(mockLocalStorage.setItem).toHaveBeenCalledWith(FORMBRICKS_LOGGED_IN_WITH_LS, "Azure"); + expect(signIn).toHaveBeenCalledWith("azure-ad", { + redirect: true, + callbackUrl: "/?source=signin", + }); + }); + + it("uses inviteUrl in callbackUrl when provided", async () => { + const inviteUrl = "https://example.com/invite"; + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_azure" }); + fireEvent.click(button); + + expect(signIn).toHaveBeenCalledWith("azure-ad", { + redirect: true, + callbackUrl: "https://example.com/invite?source=signin", + }); + }); + + it("handles signup source correctly", async () => { + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_azure" }); + fireEvent.click(button); + + expect(signIn).toHaveBeenCalledWith("azure-ad", { + redirect: true, + callbackUrl: "/?source=signup", + }); + }); + + it("triggers direct redirect when directRedirect is true", () => { + render(); + expect(signIn).toHaveBeenCalledWith("azure-ad", { + redirect: true, + callbackUrl: "/?source=signin", + }); + }); +}); diff --git a/apps/web/modules/ee/sso/components/azure-button.tsx b/apps/web/modules/ee/sso/components/azure-button.tsx index c6676213ae..84109ff94a 100644 --- a/apps/web/modules/ee/sso/components/azure-button.tsx +++ b/apps/web/modules/ee/sso/components/azure-button.tsx @@ -1,5 +1,6 @@ "use client"; +import { getCallbackUrl } from "@/modules/ee/sso/lib/utils"; import { Button } from "@/modules/ui/components/button"; import { MicrosoftIcon } from "@/modules/ui/components/icons"; import { useTranslate } from "@tolgee/react"; @@ -11,20 +12,22 @@ interface AzureButtonProps { inviteUrl?: string; directRedirect?: boolean; lastUsed?: boolean; + source: "signin" | "signup"; } -export const AzureButton = ({ inviteUrl, directRedirect = false, lastUsed }: AzureButtonProps) => { +export const AzureButton = ({ inviteUrl, directRedirect = false, lastUsed, source }: AzureButtonProps) => { const { t } = useTranslate(); const handleLogin = useCallback(async () => { if (typeof window !== "undefined") { localStorage.setItem(FORMBRICKS_LOGGED_IN_WITH_LS, "Azure"); } + const callbackUrlWithSource = getCallbackUrl(inviteUrl, source); await signIn("azure-ad", { redirect: true, - callbackUrl: inviteUrl ? inviteUrl : "/", + callbackUrl: callbackUrlWithSource, }); - }, [inviteUrl]); + }, [inviteUrl, source]); useEffect(() => { if (directRedirect) { diff --git a/apps/web/modules/ee/sso/components/github-button.test.tsx b/apps/web/modules/ee/sso/components/github-button.test.tsx new file mode 100644 index 0000000000..cde77b5ae6 --- /dev/null +++ b/apps/web/modules/ee/sso/components/github-button.test.tsx @@ -0,0 +1,76 @@ +import { cleanup, fireEvent, render, screen } from "@testing-library/react"; +import { signIn } from "next-auth/react"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { FORMBRICKS_LOGGED_IN_WITH_LS } from "@formbricks/lib/localStorage"; +import { GithubButton } from "./github-button"; + +// Mock next-auth/react +vi.mock("next-auth/react", () => ({ + signIn: vi.fn(), +})); + +// Mock localStorage +const mockLocalStorage = { + setItem: vi.fn(), +}; +Object.defineProperty(window, "localStorage", { + value: mockLocalStorage, + writable: true, +}); + +describe("GithubButton", () => { + const defaultProps = { + source: "signin" as const, + }; + + afterEach(() => { + cleanup(); + vi.clearAllMocks(); + }); + + it("renders correctly with default props", () => { + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_github" }); + expect(button).toBeInTheDocument(); + }); + + it("renders with last used indicator when lastUsed is true", () => { + render(); + expect(screen.getByText("auth.last_used")).toBeInTheDocument(); + }); + + it("sets localStorage item and calls signIn on click", async () => { + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_github" }); + fireEvent.click(button); + + expect(mockLocalStorage.setItem).toHaveBeenCalledWith(FORMBRICKS_LOGGED_IN_WITH_LS, "Github"); + expect(signIn).toHaveBeenCalledWith("github", { + redirect: true, + callbackUrl: "/?source=signin", + }); + }); + + it("uses inviteUrl in callbackUrl when provided", async () => { + const inviteUrl = "https://example.com/invite"; + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_github" }); + fireEvent.click(button); + + expect(signIn).toHaveBeenCalledWith("github", { + redirect: true, + callbackUrl: "https://example.com/invite?source=signin", + }); + }); + + it("handles signup source correctly", async () => { + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_github" }); + fireEvent.click(button); + + expect(signIn).toHaveBeenCalledWith("github", { + redirect: true, + callbackUrl: "/?source=signup", + }); + }); +}); diff --git a/apps/web/modules/ee/sso/components/github-button.tsx b/apps/web/modules/ee/sso/components/github-button.tsx index be497b5e31..e758a7f93a 100644 --- a/apps/web/modules/ee/sso/components/github-button.tsx +++ b/apps/web/modules/ee/sso/components/github-button.tsx @@ -1,5 +1,6 @@ "use client"; +import { getCallbackUrl } from "@/modules/ee/sso/lib/utils"; import { Button } from "@/modules/ui/components/button"; import { GithubIcon } from "@/modules/ui/components/icons"; import { useTranslate } from "@tolgee/react"; @@ -9,17 +10,20 @@ import { FORMBRICKS_LOGGED_IN_WITH_LS } from "@formbricks/lib/localStorage"; interface GithubButtonProps { inviteUrl?: string; lastUsed?: boolean; + source: "signin" | "signup"; } -export const GithubButton = ({ inviteUrl, lastUsed }: GithubButtonProps) => { +export const GithubButton = ({ inviteUrl, lastUsed, source }: GithubButtonProps) => { const { t } = useTranslate(); const handleLogin = async () => { if (typeof window !== "undefined") { localStorage.setItem(FORMBRICKS_LOGGED_IN_WITH_LS, "Github"); } + const callbackUrlWithSource = getCallbackUrl(inviteUrl, source); + await signIn("github", { redirect: true, - callbackUrl: inviteUrl ? inviteUrl : "/", // redirect after login to / + callbackUrl: callbackUrlWithSource, }); }; diff --git a/apps/web/modules/ee/sso/components/google-button.test.tsx b/apps/web/modules/ee/sso/components/google-button.test.tsx new file mode 100644 index 0000000000..c6a8055d55 --- /dev/null +++ b/apps/web/modules/ee/sso/components/google-button.test.tsx @@ -0,0 +1,76 @@ +import { cleanup, fireEvent, render, screen } from "@testing-library/react"; +import { signIn } from "next-auth/react"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { FORMBRICKS_LOGGED_IN_WITH_LS } from "@formbricks/lib/localStorage"; +import { GoogleButton } from "./google-button"; + +// Mock next-auth/react +vi.mock("next-auth/react", () => ({ + signIn: vi.fn(), +})); + +// Mock localStorage +const mockLocalStorage = { + setItem: vi.fn(), +}; +Object.defineProperty(window, "localStorage", { + value: mockLocalStorage, + writable: true, +}); + +describe("GoogleButton", () => { + const defaultProps = { + source: "signin" as const, + }; + + afterEach(() => { + cleanup(); + vi.clearAllMocks(); + }); + + it("renders correctly with default props", () => { + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_google" }); + expect(button).toBeInTheDocument(); + }); + + it("renders with last used indicator when lastUsed is true", () => { + render(); + expect(screen.getByText("auth.last_used")).toBeInTheDocument(); + }); + + it("sets localStorage item and calls signIn on click", async () => { + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_google" }); + fireEvent.click(button); + + expect(mockLocalStorage.setItem).toHaveBeenCalledWith(FORMBRICKS_LOGGED_IN_WITH_LS, "Google"); + expect(signIn).toHaveBeenCalledWith("google", { + redirect: true, + callbackUrl: "/?source=signin", + }); + }); + + it("uses inviteUrl in callbackUrl when provided", async () => { + const inviteUrl = "https://example.com/invite"; + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_google" }); + fireEvent.click(button); + + expect(signIn).toHaveBeenCalledWith("google", { + redirect: true, + callbackUrl: "https://example.com/invite?source=signin", + }); + }); + + it("handles signup source correctly", async () => { + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_google" }); + fireEvent.click(button); + + expect(signIn).toHaveBeenCalledWith("google", { + redirect: true, + callbackUrl: "/?source=signup", + }); + }); +}); diff --git a/apps/web/modules/ee/sso/components/google-button.tsx b/apps/web/modules/ee/sso/components/google-button.tsx index 7b8e679e1f..5379311ec0 100644 --- a/apps/web/modules/ee/sso/components/google-button.tsx +++ b/apps/web/modules/ee/sso/components/google-button.tsx @@ -1,5 +1,6 @@ "use client"; +import { getCallbackUrl } from "@/modules/ee/sso/lib/utils"; import { Button } from "@/modules/ui/components/button"; import { GoogleIcon } from "@/modules/ui/components/icons"; import { useTranslate } from "@tolgee/react"; @@ -9,17 +10,20 @@ import { FORMBRICKS_LOGGED_IN_WITH_LS } from "@formbricks/lib/localStorage"; interface GoogleButtonProps { inviteUrl?: string; lastUsed?: boolean; + source: "signin" | "signup"; } -export const GoogleButton = ({ inviteUrl, lastUsed }: GoogleButtonProps) => { +export const GoogleButton = ({ inviteUrl, lastUsed, source }: GoogleButtonProps) => { const { t } = useTranslate(); const handleLogin = async () => { if (typeof window !== "undefined") { localStorage.setItem(FORMBRICKS_LOGGED_IN_WITH_LS, "Google"); } + const callbackUrlWithSource = getCallbackUrl(inviteUrl, source); + await signIn("google", { redirect: true, - callbackUrl: inviteUrl ? inviteUrl : "/", // redirect after login to / + callbackUrl: callbackUrlWithSource, }); }; diff --git a/apps/web/modules/ee/sso/components/open-id-button.test.tsx b/apps/web/modules/ee/sso/components/open-id-button.test.tsx new file mode 100644 index 0000000000..4943794ec8 --- /dev/null +++ b/apps/web/modules/ee/sso/components/open-id-button.test.tsx @@ -0,0 +1,91 @@ +import { cleanup, fireEvent, render, screen } from "@testing-library/react"; +import { signIn } from "next-auth/react"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { FORMBRICKS_LOGGED_IN_WITH_LS } from "@formbricks/lib/localStorage"; +import { OpenIdButton } from "./open-id-button"; + +// Mock next-auth/react +vi.mock("next-auth/react", () => ({ + signIn: vi.fn(), +})); + +// Mock localStorage +const mockLocalStorage = { + setItem: vi.fn(), +}; +Object.defineProperty(window, "localStorage", { + value: mockLocalStorage, + writable: true, +}); + +describe("OpenIdButton", () => { + const defaultProps = { + source: "signin" as const, + }; + + afterEach(() => { + cleanup(); + vi.clearAllMocks(); + }); + + it("renders correctly with default props", () => { + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_openid" }); + expect(button).toBeInTheDocument(); + }); + + it("renders with custom text when provided", () => { + const customText = "Custom OpenID Text"; + render(); + const button = screen.getByRole("button", { name: customText }); + expect(button).toBeInTheDocument(); + }); + + it("renders with last used indicator when lastUsed is true", () => { + render(); + expect(screen.getByText("auth.last_used")).toBeInTheDocument(); + }); + + it("sets localStorage item and calls signIn on click", async () => { + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_openid" }); + fireEvent.click(button); + + expect(mockLocalStorage.setItem).toHaveBeenCalledWith(FORMBRICKS_LOGGED_IN_WITH_LS, "OpenID"); + expect(signIn).toHaveBeenCalledWith("openid", { + redirect: true, + callbackUrl: "/?source=signin", + }); + }); + + it("uses inviteUrl in callbackUrl when provided", async () => { + const inviteUrl = "https://example.com/invite"; + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_openid" }); + fireEvent.click(button); + + expect(signIn).toHaveBeenCalledWith("openid", { + redirect: true, + callbackUrl: "https://example.com/invite?source=signin", + }); + }); + + it("handles signup source correctly", async () => { + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_openid" }); + fireEvent.click(button); + + expect(signIn).toHaveBeenCalledWith("openid", { + redirect: true, + callbackUrl: "/?source=signup", + }); + }); + + it("triggers direct redirect when directRedirect is true", () => { + render(); + expect(signIn).toHaveBeenCalledWith("openid", { + redirect: true, + callbackUrl: "/?source=signin", + }); + }); +}); diff --git a/apps/web/modules/ee/sso/components/open-id-button.tsx b/apps/web/modules/ee/sso/components/open-id-button.tsx index 588b452f36..b07258c5e2 100644 --- a/apps/web/modules/ee/sso/components/open-id-button.tsx +++ b/apps/web/modules/ee/sso/components/open-id-button.tsx @@ -1,5 +1,6 @@ "use client"; +import { getCallbackUrl } from "@/modules/ee/sso/lib/utils"; import { Button } from "@/modules/ui/components/button"; import { useTranslate } from "@tolgee/react"; import { signIn } from "next-auth/react"; @@ -11,19 +12,28 @@ interface OpenIdButtonProps { lastUsed?: boolean; directRedirect?: boolean; text?: string; + source: "signin" | "signup"; } -export const OpenIdButton = ({ inviteUrl, lastUsed, directRedirect = false, text }: OpenIdButtonProps) => { +export const OpenIdButton = ({ + inviteUrl, + lastUsed, + directRedirect = false, + text, + source, +}: OpenIdButtonProps) => { const { t } = useTranslate(); const handleLogin = useCallback(async () => { if (typeof window !== "undefined") { localStorage.setItem(FORMBRICKS_LOGGED_IN_WITH_LS, "OpenID"); } + const callbackUrlWithSource = getCallbackUrl(inviteUrl, source); + await signIn("openid", { redirect: true, - callbackUrl: inviteUrl ? inviteUrl : "/", + callbackUrl: callbackUrlWithSource, }); - }, [inviteUrl]); + }, [inviteUrl, source]); useEffect(() => { if (directRedirect) { diff --git a/apps/web/modules/ee/sso/components/saml-button.test.tsx b/apps/web/modules/ee/sso/components/saml-button.test.tsx new file mode 100644 index 0000000000..5c7b707bc8 --- /dev/null +++ b/apps/web/modules/ee/sso/components/saml-button.test.tsx @@ -0,0 +1,130 @@ +import { doesSamlConnectionExistAction } from "@/modules/ee/sso/actions"; +import { cleanup, fireEvent, render, screen } from "@testing-library/react"; +import { signIn } from "next-auth/react"; +import toast from "react-hot-toast"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { FORMBRICKS_LOGGED_IN_WITH_LS } from "@formbricks/lib/localStorage"; +import { SamlButton } from "./saml-button"; + +// Mock next-auth/react +vi.mock("next-auth/react", () => ({ + signIn: vi.fn().mockResolvedValue(undefined), +})); + +// Mock localStorage +const mockLocalStorage = { + setItem: vi.fn(), +}; +Object.defineProperty(window, "localStorage", { + value: mockLocalStorage, + writable: true, +}); + +// Mock actions +vi.mock("@/modules/ee/sso/actions", () => ({ + doesSamlConnectionExistAction: vi.fn(), +})); + +// Mock toast +vi.mock("react-hot-toast", () => ({ + default: { + error: vi.fn(), + }, +})); + +describe("SamlButton", () => { + const defaultProps = { + source: "signin" as const, + samlTenant: "test-tenant", + samlProduct: "test-product", + }; + + afterEach(() => { + cleanup(); + vi.clearAllMocks(); + }); + + it("renders correctly with default props", () => { + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_saml" }); + expect(button).toBeInTheDocument(); + }); + + it("renders with last used indicator when lastUsed is true", () => { + render(); + expect(screen.getByText("auth.last_used")).toBeInTheDocument(); + }); + + it("sets localStorage item and calls signIn on click when SAML connection exists", async () => { + vi.mocked(doesSamlConnectionExistAction).mockResolvedValue({ data: true }); + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_saml" }); + + await fireEvent.click(button); + + expect(mockLocalStorage.setItem).toHaveBeenCalledWith(FORMBRICKS_LOGGED_IN_WITH_LS, "Saml"); + expect(signIn).toHaveBeenCalledWith( + "saml", + { + redirect: true, + callbackUrl: "/?source=signin", + }, + { + tenant: "test-tenant", + product: "test-product", + } + ); + }); + + it("shows error toast when SAML connection does not exist", async () => { + vi.mocked(doesSamlConnectionExistAction).mockResolvedValue({ data: false }); + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_saml" }); + + await fireEvent.click(button); + + expect(toast.error).toHaveBeenCalledWith("auth.saml_connection_error"); + expect(signIn).not.toHaveBeenCalled(); + }); + + it("uses inviteUrl in callbackUrl when provided", async () => { + vi.mocked(doesSamlConnectionExistAction).mockResolvedValue({ data: true }); + const inviteUrl = "https://example.com/invite"; + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_saml" }); + + await fireEvent.click(button); + + expect(signIn).toHaveBeenCalledWith( + "saml", + { + redirect: true, + callbackUrl: "https://example.com/invite?source=signin", + }, + { + tenant: "test-tenant", + product: "test-product", + } + ); + }); + + it("handles signup source correctly", async () => { + vi.mocked(doesSamlConnectionExistAction).mockResolvedValue({ data: true }); + render(); + const button = screen.getByRole("button", { name: "auth.continue_with_saml" }); + + await fireEvent.click(button); + + expect(signIn).toHaveBeenCalledWith( + "saml", + { + redirect: true, + callbackUrl: "/?source=signup", + }, + { + tenant: "test-tenant", + product: "test-product", + } + ); + }); +}); diff --git a/apps/web/modules/ee/sso/components/saml-button.tsx b/apps/web/modules/ee/sso/components/saml-button.tsx index 6167e1cb20..258a6b80ed 100644 --- a/apps/web/modules/ee/sso/components/saml-button.tsx +++ b/apps/web/modules/ee/sso/components/saml-button.tsx @@ -1,6 +1,7 @@ "use client"; import { doesSamlConnectionExistAction } from "@/modules/ee/sso/actions"; +import { getCallbackUrl } from "@/modules/ee/sso/lib/utils"; import { Button } from "@/modules/ui/components/button"; import { useTranslate } from "@tolgee/react"; import { LockIcon } from "lucide-react"; @@ -14,9 +15,10 @@ interface SamlButtonProps { lastUsed?: boolean; samlTenant: string; samlProduct: string; + source: "signin" | "signup"; } -export const SamlButton = ({ inviteUrl, lastUsed, samlTenant, samlProduct }: SamlButtonProps) => { +export const SamlButton = ({ inviteUrl, lastUsed, samlTenant, samlProduct, source }: SamlButtonProps) => { const { t } = useTranslate(); const [isLoading, setIsLoading] = useState(false); @@ -32,11 +34,13 @@ export const SamlButton = ({ inviteUrl, lastUsed, samlTenant, samlProduct }: Sam return; } + const callbackUrlWithSource = getCallbackUrl(inviteUrl, source); + signIn( "saml", { redirect: true, - callbackUrl: inviteUrl ? inviteUrl : "/", // redirect after login to / + callbackUrl: callbackUrlWithSource, }, { tenant: samlTenant, diff --git a/apps/web/modules/ee/sso/components/sso-options.test.tsx b/apps/web/modules/ee/sso/components/sso-options.test.tsx new file mode 100644 index 0000000000..458413d5a0 --- /dev/null +++ b/apps/web/modules/ee/sso/components/sso-options.test.tsx @@ -0,0 +1,137 @@ +import { cleanup, render, screen } from "@testing-library/react"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { SSOOptions } from "./sso-options"; + +// Mock environment variables +vi.mock("@formbricks/lib/env", () => ({ + env: { + IS_FORMBRICKS_CLOUD: "0", + }, +})); + +// Mock the translation hook +vi.mock("@tolgee/react", () => ({ + useTranslate: () => ({ + t: (key: string) => key, + }), +})); + +// Mock the individual SSO buttons +vi.mock("./google-button", () => ({ + GoogleButton: ({ lastUsed, source }: any) => ( +
+ Google Button +
+ ), +})); + +vi.mock("./github-button", () => ({ + GithubButton: ({ lastUsed, source }: any) => ( +
+ Github Button +
+ ), +})); + +vi.mock("./azure-button", () => ({ + AzureButton: ({ lastUsed, source }: any) => ( +
+ Azure Button +
+ ), +})); + +vi.mock("./open-id-button", () => ({ + OpenIdButton: ({ lastUsed, source, text }: any) => ( +
+ {text} +
+ ), +})); + +vi.mock("./saml-button", () => ({ + SamlButton: ({ lastUsed, source, samlTenant, samlProduct }: any) => ( +
+ Saml Button +
+ ), +})); + +describe("SSOOptions Component", () => { + afterEach(() => { + cleanup(); + vi.clearAllMocks(); + }); + + const defaultProps = { + googleOAuthEnabled: true, + githubOAuthEnabled: true, + azureOAuthEnabled: true, + oidcOAuthEnabled: true, + oidcDisplayName: "OpenID", + callbackUrl: "http://localhost:3000", + samlSsoEnabled: true, + samlTenant: "test-tenant", + samlProduct: "test-product", + source: "signin" as const, + }; + + it("renders all SSO options when all are enabled", () => { + render(); + + expect(screen.getByTestId("google-button")).toBeInTheDocument(); + expect(screen.getByTestId("github-button")).toBeInTheDocument(); + expect(screen.getByTestId("azure-button")).toBeInTheDocument(); + expect(screen.getByTestId("openid-button")).toBeInTheDocument(); + expect(screen.getByTestId("saml-button")).toBeInTheDocument(); + }); + + it("only renders enabled SSO options", () => { + render( + + ); + + expect(screen.queryByTestId("google-button")).not.toBeInTheDocument(); + expect(screen.queryByTestId("github-button")).not.toBeInTheDocument(); + expect(screen.queryByTestId("azure-button")).not.toBeInTheDocument(); + expect(screen.getByTestId("openid-button")).toBeInTheDocument(); + expect(screen.getByTestId("saml-button")).toBeInTheDocument(); + }); + + it("passes correct props to OpenID button", () => { + render(); + const openIdButton = screen.getByTestId("openid-button"); + + expect(openIdButton).toHaveAttribute("data-source", "signin"); + expect(openIdButton).toHaveTextContent("auth.continue_with_oidc"); + }); + + it("passes correct props to SAML button", () => { + render(); + const samlButton = screen.getByTestId("saml-button"); + + expect(samlButton).toHaveAttribute("data-source", "signin"); + expect(samlButton).toHaveAttribute("data-tenant", "test-tenant"); + expect(samlButton).toHaveAttribute("data-product", "test-product"); + }); + + it("passes correct source prop to all buttons", () => { + render(); + + expect(screen.getByTestId("google-button")).toHaveAttribute("data-source", "signup"); + expect(screen.getByTestId("github-button")).toHaveAttribute("data-source", "signup"); + expect(screen.getByTestId("azure-button")).toHaveAttribute("data-source", "signup"); + expect(screen.getByTestId("openid-button")).toHaveAttribute("data-source", "signup"); + expect(screen.getByTestId("saml-button")).toHaveAttribute("data-source", "signup"); + }); +}); diff --git a/apps/web/modules/ee/sso/components/sso-options.tsx b/apps/web/modules/ee/sso/components/sso-options.tsx index 8f3c490520..cf0cb40579 100644 --- a/apps/web/modules/ee/sso/components/sso-options.tsx +++ b/apps/web/modules/ee/sso/components/sso-options.tsx @@ -19,6 +19,7 @@ interface SSOOptionsProps { samlSsoEnabled: boolean; samlTenant: string; samlProduct: string; + source: "signin" | "signup"; } export const SSOOptions = ({ @@ -31,6 +32,7 @@ export const SSOOptions = ({ samlSsoEnabled, samlTenant, samlProduct, + source, }: SSOOptionsProps) => { const { t } = useTranslate(); const [lastLoggedInWith, setLastLoggedInWith] = useState(""); @@ -44,17 +46,20 @@ export const SSOOptions = ({ return (
{googleOAuthEnabled && ( - + )} {githubOAuthEnabled && ( - + + )} + {azureOAuthEnabled && ( + )} - {azureOAuthEnabled && } {oidcOAuthEnabled && ( )} {samlSsoEnabled && ( @@ -63,6 +68,7 @@ export const SSOOptions = ({ lastUsed={lastLoggedInWith === "Saml"} samlTenant={samlTenant} samlProduct={samlProduct} + source={source} /> )}
diff --git a/apps/web/modules/ee/sso/lib/sso-handlers.ts b/apps/web/modules/ee/sso/lib/sso-handlers.ts index 35bb9920a0..b5500f34cb 100644 --- a/apps/web/modules/ee/sso/lib/sso-handlers.ts +++ b/apps/web/modules/ee/sso/lib/sso-handlers.ts @@ -1,19 +1,34 @@ import { createBrevoCustomer } from "@/modules/auth/lib/brevo"; import { getUserByEmail, updateUser } from "@/modules/auth/lib/user"; import { createUser } from "@/modules/auth/lib/user"; +import { getIsValidInviteToken } from "@/modules/auth/signup/lib/invite"; import { TOidcNameFields, TSamlNameFields } from "@/modules/auth/types/auth"; -import { getIsSamlSsoEnabled, getisSsoEnabled } from "@/modules/ee/license-check/lib/utils"; +import { + getIsMultiOrgEnabled, + getIsSamlSsoEnabled, + getisSsoEnabled, +} from "@/modules/ee/license-check/lib/utils"; import type { IdentityProvider } from "@prisma/client"; import type { Account } from "next-auth"; import { prisma } from "@formbricks/database"; import { createAccount } from "@formbricks/lib/account/service"; import { DEFAULT_ORGANIZATION_ID, DEFAULT_ORGANIZATION_ROLE } from "@formbricks/lib/constants"; +import { verifyInviteToken } from "@formbricks/lib/jwt"; import { createMembership } from "@formbricks/lib/membership/service"; import { createOrganization, getOrganization } from "@formbricks/lib/organization/service"; import { findMatchingLocale } from "@formbricks/lib/utils/locale"; +import { logger } from "@formbricks/logger"; import type { TUser, TUserNotificationSettings } from "@formbricks/types/user"; -export const handleSsoCallback = async ({ user, account }: { user: TUser; account: Account }) => { +export const handleSsoCallback = async ({ + user, + account, + callbackUrl, +}: { + user: TUser; + account: Account; + callbackUrl: string; +}) => { const isSsoEnabled = await getisSsoEnabled(); if (!isSsoEnabled) { return false; @@ -102,6 +117,46 @@ export const handleSsoCallback = async ({ user, account }: { user: TUser; accoun } } + // Get multi-org license status + const isMultiOrgEnabled = await getIsMultiOrgEnabled(); + + // Reject if no callback URL and no default org in self-hosted environment + if (!callbackUrl && !DEFAULT_ORGANIZATION_ID && !isMultiOrgEnabled) { + return false; + } + + // Additional security checks for self-hosted instances without default org + if (!DEFAULT_ORGANIZATION_ID && !isMultiOrgEnabled) { + try { + // Parse and validate the callback URL + const isValidCallbackUrl = new URL(callbackUrl); + // Extract invite token and source from URL parameters + const inviteToken = isValidCallbackUrl.searchParams.get("token") || ""; + const source = isValidCallbackUrl.searchParams.get("source") || ""; + + // Allow sign-in if multi-org is enabled, otherwise check for invite token + if (source === "signin" && !inviteToken) { + return false; + } + + // If multi-org is enabled, skip invite token validation + // Verify invite token and check email match + const { email, inviteId } = verifyInviteToken(inviteToken); + if (email !== user.email) { + return false; + } + // Check if invite token is still valid + const isValidInviteToken = await getIsValidInviteToken(inviteId); + if (!isValidInviteToken) { + return false; + } + } catch (err) { + // Log and reject on any validation errors + logger.error(err, "Invalid callbackUrl"); + return false; + } + } + const userProfile = await createUser({ name: userName || diff --git a/apps/web/modules/ee/sso/lib/tests/sso-handlers.test.ts b/apps/web/modules/ee/sso/lib/tests/sso-handlers.test.ts index 318a4a9143..3bb8b5ef53 100644 --- a/apps/web/modules/ee/sso/lib/tests/sso-handlers.test.ts +++ b/apps/web/modules/ee/sso/lib/tests/sso-handlers.test.ts @@ -1,5 +1,6 @@ import { createBrevoCustomer } from "@/modules/auth/lib/brevo"; import { createUser, getUserByEmail, updateUser } from "@/modules/auth/lib/user"; +import type { TSamlNameFields } from "@/modules/auth/types/auth"; import { getIsSamlSsoEnabled, getisSsoEnabled } from "@/modules/ee/license-check/lib/utils"; import { beforeEach, describe, expect, it, vi } from "vitest"; import { prisma } from "@formbricks/database"; @@ -7,6 +8,7 @@ import { createAccount } from "@formbricks/lib/account/service"; import { createMembership } from "@formbricks/lib/membership/service"; import { createOrganization, getOrganization } from "@formbricks/lib/organization/service"; import { findMatchingLocale } from "@formbricks/lib/utils/locale"; +import type { TUser } from "@formbricks/types/user"; import { handleSsoCallback } from "../sso-handlers"; import { mockAccount, @@ -32,6 +34,7 @@ vi.mock("@/modules/auth/lib/user", () => ({ vi.mock("@/modules/ee/license-check/lib/utils", () => ({ getIsSamlSsoEnabled: vi.fn(), getisSsoEnabled: vi.fn(), + getIsMultiOrgEnabled: vi.fn().mockResolvedValue(true), })); vi.mock("@formbricks/database", () => ({ @@ -63,6 +66,7 @@ vi.mock("@formbricks/lib/utils/locale", () => ({ vi.mock("@formbricks/lib/constants", () => ({ DEFAULT_ORGANIZATION_ID: "org-123", DEFAULT_ORGANIZATION_ROLE: "member", + ENCRYPTION_KEY: "test-encryption-key-32-chars-long", })); describe("handleSsoCallback", () => { @@ -90,24 +94,31 @@ describe("handleSsoCallback", () => { it("should return false if SSO is not enabled", async () => { vi.mocked(getisSsoEnabled).mockResolvedValue(false); - const result = await handleSsoCallback({ user: mockUser, account: mockAccount }); + const result = await handleSsoCallback({ + user: mockUser, + account: mockAccount, + callbackUrl: "http://localhost:3000", + }); expect(result).toBe(false); - expect(getisSsoEnabled).toHaveBeenCalled(); }); it("should return false if user email is missing", async () => { - const userWithoutEmail = { ...mockUser, email: "" }; - - const result = await handleSsoCallback({ user: userWithoutEmail, account: mockAccount }); + const result = await handleSsoCallback({ + user: { ...mockUser, email: "" }, + account: mockAccount, + callbackUrl: "http://localhost:3000", + }); expect(result).toBe(false); }); it("should return false if account type is not oauth", async () => { - const nonOauthAccount = { ...mockAccount, type: "credentials" as const }; - - const result = await handleSsoCallback({ user: mockUser, account: nonOauthAccount }); + const result = await handleSsoCallback({ + user: mockUser, + account: { ...mockAccount, type: "credentials" }, + callbackUrl: "http://localhost:3000", + }); expect(result).toBe(false); }); @@ -115,10 +126,13 @@ describe("handleSsoCallback", () => { it("should return false if provider is SAML and SAML SSO is not enabled", async () => { vi.mocked(getIsSamlSsoEnabled).mockResolvedValue(false); - const result = await handleSsoCallback({ user: mockUser, account: mockSamlAccount }); + const result = await handleSsoCallback({ + user: mockUser, + account: mockSamlAccount, + callbackUrl: "http://localhost:3000", + }); expect(result).toBe(false); - expect(getIsSamlSsoEnabled).toHaveBeenCalled(); }); }); @@ -130,7 +144,11 @@ describe("handleSsoCallback", () => { accounts: [{ provider: mockAccount.provider }], }); - const result = await handleSsoCallback({ user: mockUser, account: mockAccount }); + const result = await handleSsoCallback({ + user: mockUser, + account: mockAccount, + callbackUrl: "http://localhost:3000", + }); expect(result).toBe(true); expect(prisma.user.findFirst).toHaveBeenCalledWith({ @@ -160,7 +178,11 @@ describe("handleSsoCallback", () => { vi.mocked(getUserByEmail).mockResolvedValue(null); vi.mocked(updateUser).mockResolvedValue({ ...existingUser, email: mockUser.email }); - const result = await handleSsoCallback({ user: mockUser, account: mockAccount }); + const result = await handleSsoCallback({ + user: mockUser, + account: mockAccount, + callbackUrl: "http://localhost:3000", + }); expect(result).toBe(true); expect(updateUser).toHaveBeenCalledWith(existingUser.id, { email: mockUser.email }); @@ -180,9 +202,16 @@ describe("handleSsoCallback", () => { email: mockUser.email, emailVerified: mockUser.emailVerified, locale: mockUser.locale, + isActive: true, }); - await expect(handleSsoCallback({ user: mockUser, account: mockAccount })).rejects.toThrow( + await expect( + handleSsoCallback({ + user: mockUser, + account: mockAccount, + callbackUrl: "http://localhost:3000", + }) + ).rejects.toThrow( "Looks like you updated your email somewhere else. A user with this new email exists already." ); }); @@ -194,9 +223,14 @@ describe("handleSsoCallback", () => { email: mockUser.email, emailVerified: mockUser.emailVerified, locale: mockUser.locale, + isActive: true, }); - const result = await handleSsoCallback({ user: mockUser, account: mockAccount }); + const result = await handleSsoCallback({ + user: mockUser, + account: mockAccount, + callbackUrl: "http://localhost:3000", + }); expect(result).toBe(true); }); @@ -208,7 +242,11 @@ describe("handleSsoCallback", () => { vi.mocked(getUserByEmail).mockResolvedValue(null); vi.mocked(createUser).mockResolvedValue(mockCreatedUser()); - const result = await handleSsoCallback({ user: mockUser, account: mockAccount }); + const result = await handleSsoCallback({ + user: mockUser, + account: mockAccount, + callbackUrl: "http://localhost:3000", + }); expect(result).toBe(true); expect(createUser).toHaveBeenCalledWith({ @@ -228,7 +266,11 @@ describe("handleSsoCallback", () => { vi.mocked(createUser).mockResolvedValue(mockCreatedUser()); vi.mocked(getOrganization).mockResolvedValue(null); - const result = await handleSsoCallback({ user: mockUser, account: mockAccount }); + const result = await handleSsoCallback({ + user: mockUser, + account: mockAccount, + callbackUrl: "http://localhost:3000", + }); expect(result).toBe(true); expect(createOrganization).toHaveBeenCalledWith({ @@ -255,7 +297,11 @@ describe("handleSsoCallback", () => { vi.mocked(getUserByEmail).mockResolvedValue(null); vi.mocked(createUser).mockResolvedValue(mockCreatedUser()); - const result = await handleSsoCallback({ user: mockUser, account: mockAccount }); + const result = await handleSsoCallback({ + user: mockUser, + account: mockAccount, + callbackUrl: "http://localhost:3000", + }); expect(result).toBe(true); expect(createOrganization).not.toHaveBeenCalled(); @@ -276,14 +322,21 @@ describe("handleSsoCallback", () => { vi.mocked(createUser).mockResolvedValue(mockCreatedUser("Direct Name")); - const result = await handleSsoCallback({ user: openIdUser, account: mockOpenIdAccount }); + const result = await handleSsoCallback({ + user: openIdUser, + account: mockOpenIdAccount, + callbackUrl: "http://localhost:3000", + }); expect(result).toBe(true); expect(createUser).toHaveBeenCalledWith( expect.objectContaining({ name: "Direct Name", email: openIdUser.email, + emailVerified: expect.any(Date), identityProvider: "openid", + identityProviderAccountId: mockOpenIdAccount.providerAccountId, + locale: "en-US", }) ); }); @@ -297,14 +350,21 @@ describe("handleSsoCallback", () => { vi.mocked(createUser).mockResolvedValue(mockCreatedUser("John Doe")); - const result = await handleSsoCallback({ user: openIdUser, account: mockOpenIdAccount }); + const result = await handleSsoCallback({ + user: openIdUser, + account: mockOpenIdAccount, + callbackUrl: "http://localhost:3000", + }); expect(result).toBe(true); expect(createUser).toHaveBeenCalledWith( expect.objectContaining({ name: "John Doe", email: openIdUser.email, + emailVerified: expect.any(Date), identityProvider: "openid", + identityProviderAccountId: mockOpenIdAccount.providerAccountId, + locale: "en-US", }) ); }); @@ -319,14 +379,21 @@ describe("handleSsoCallback", () => { vi.mocked(createUser).mockResolvedValue(mockCreatedUser("preferred.user")); - const result = await handleSsoCallback({ user: openIdUser, account: mockOpenIdAccount }); + const result = await handleSsoCallback({ + user: openIdUser, + account: mockOpenIdAccount, + callbackUrl: "http://localhost:3000", + }); expect(result).toBe(true); expect(createUser).toHaveBeenCalledWith( expect.objectContaining({ name: "preferred.user", email: openIdUser.email, + emailVerified: expect.any(Date), identityProvider: "openid", + identityProviderAccountId: mockOpenIdAccount.providerAccountId, + locale: "en-US", }) ); }); @@ -340,18 +407,152 @@ describe("handleSsoCallback", () => { email: "test.user@example.com", }); - vi.mocked(createUser).mockResolvedValue(mockCreatedUser("test.user")); + vi.mocked(createUser).mockResolvedValue(mockCreatedUser("test user")); - const result = await handleSsoCallback({ user: openIdUser, account: mockOpenIdAccount }); + const result = await handleSsoCallback({ + user: openIdUser, + account: mockOpenIdAccount, + callbackUrl: "http://localhost:3000", + }); expect(result).toBe(true); - expect(createUser).toHaveBeenCalledWith( expect.objectContaining({ + name: "test user", email: openIdUser.email, + emailVerified: expect.any(Date), identityProvider: "openid", + identityProviderAccountId: mockOpenIdAccount.providerAccountId, + locale: "en-US", }) ); }); }); + + describe("SAML name handling", () => { + it("should use samlUser.name when available", async () => { + const samlUser = { + ...mockUser, + name: "Direct Name", + firstName: "John", + lastName: "Doe", + } as TUser & TSamlNameFields; + + vi.mocked(createUser).mockResolvedValue(mockCreatedUser("Direct Name")); + + const result = await handleSsoCallback({ + user: samlUser, + account: mockSamlAccount, + callbackUrl: "http://localhost:3000", + }); + + expect(result).toBe(true); + expect(createUser).toHaveBeenCalledWith( + expect.objectContaining({ + name: "Direct Name", + email: samlUser.email, + emailVerified: expect.any(Date), + identityProvider: "saml", + identityProviderAccountId: mockSamlAccount.providerAccountId, + locale: "en-US", + }) + ); + }); + + it("should use firstName + lastName when name is not available", async () => { + const samlUser = { + ...mockUser, + name: "", + firstName: "John", + lastName: "Doe", + } as TUser & TSamlNameFields; + + vi.mocked(createUser).mockResolvedValue(mockCreatedUser("John Doe")); + + const result = await handleSsoCallback({ + user: samlUser, + account: mockSamlAccount, + callbackUrl: "http://localhost:3000", + }); + + expect(result).toBe(true); + expect(createUser).toHaveBeenCalledWith( + expect.objectContaining({ + name: "John Doe", + email: samlUser.email, + emailVerified: expect.any(Date), + identityProvider: "saml", + identityProviderAccountId: mockSamlAccount.providerAccountId, + locale: "en-US", + }) + ); + }); + }); + + describe("Organization handling", () => { + it("should handle invalid DEFAULT_ORGANIZATION_ID gracefully", async () => { + vi.mocked(prisma.user.findFirst).mockResolvedValue(null); + vi.mocked(getUserByEmail).mockResolvedValue(null); + vi.mocked(createUser).mockResolvedValue(mockCreatedUser()); + vi.mocked(getOrganization).mockResolvedValue(null); + vi.mocked(createOrganization).mockRejectedValue(new Error("Invalid organization ID")); + + await expect( + handleSsoCallback({ + user: mockUser, + account: mockAccount, + callbackUrl: "http://localhost:3000", + }) + ).rejects.toThrow("Invalid organization ID"); + + expect(createOrganization).toHaveBeenCalled(); + expect(createMembership).not.toHaveBeenCalled(); + }); + + it("should handle membership creation failure gracefully", async () => { + vi.mocked(prisma.user.findFirst).mockResolvedValue(null); + vi.mocked(getUserByEmail).mockResolvedValue(null); + vi.mocked(createUser).mockResolvedValue(mockCreatedUser()); + vi.mocked(createMembership).mockRejectedValue(new Error("Failed to create membership")); + + await expect( + handleSsoCallback({ + user: mockUser, + account: mockAccount, + callbackUrl: "http://localhost:3000", + }) + ).rejects.toThrow("Failed to create membership"); + + expect(createMembership).toHaveBeenCalled(); + }); + }); + + describe("Error handling", () => { + it("should handle prisma errors gracefully", async () => { + vi.mocked(prisma.user.findFirst).mockRejectedValue(new Error("Database error")); + + await expect( + handleSsoCallback({ + user: mockUser, + account: mockAccount, + callbackUrl: "http://localhost:3000", + }) + ).rejects.toThrow("Database error"); + }); + + it("should handle locale finding errors gracefully", async () => { + vi.mocked(findMatchingLocale).mockRejectedValue(new Error("Locale error")); + vi.mocked(prisma.user.findFirst).mockResolvedValue(null); + vi.mocked(getUserByEmail).mockResolvedValue(null); + vi.mocked(createUser).mockResolvedValue(mockCreatedUser()); + + await expect( + handleSsoCallback({ + user: mockUser, + account: mockAccount, + callbackUrl: "http://localhost:3000", + }) + ).rejects.toThrow("Locale error"); + }); + }); }); diff --git a/apps/web/modules/ee/sso/lib/tests/utils.test.ts b/apps/web/modules/ee/sso/lib/tests/utils.test.ts new file mode 100644 index 0000000000..6d263ef4e0 --- /dev/null +++ b/apps/web/modules/ee/sso/lib/tests/utils.test.ts @@ -0,0 +1,29 @@ +import { describe, expect, it } from "vitest"; +import { getCallbackUrl } from "../utils"; + +describe("getCallbackUrl", () => { + it("should return base URL with source when no inviteUrl is provided", () => { + const result = getCallbackUrl(undefined, "test-source"); + expect(result).toBe("/?source=test-source"); + }); + + it("should append source parameter to inviteUrl with existing query parameters", () => { + const result = getCallbackUrl("https://example.com/invite?param=value", "test-source"); + expect(result).toBe("https://example.com/invite?param=value&source=test-source"); + }); + + it("should append source parameter to inviteUrl without existing query parameters", () => { + const result = getCallbackUrl("https://example.com/invite", "test-source"); + expect(result).toBe("https://example.com/invite?source=test-source"); + }); + + it("should handle empty source parameter", () => { + const result = getCallbackUrl("https://example.com/invite", ""); + expect(result).toBe("https://example.com/invite?source="); + }); + + it("should handle undefined source parameter", () => { + const result = getCallbackUrl("https://example.com/invite", undefined); + expect(result).toBe("https://example.com/invite?source=undefined"); + }); +}); diff --git a/apps/web/modules/ee/sso/lib/utils.ts b/apps/web/modules/ee/sso/lib/utils.ts new file mode 100644 index 0000000000..a0f9bef776 --- /dev/null +++ b/apps/web/modules/ee/sso/lib/utils.ts @@ -0,0 +1,5 @@ +export const getCallbackUrl = (inviteUrl?: string, source?: string) => { + return inviteUrl + ? `${inviteUrl}${inviteUrl.includes("?") ? "&" : "?"}source=${source}` + : `/?source=${source}`; +}; diff --git a/apps/web/modules/survey/hooks/useSingleUseId.test.tsx b/apps/web/modules/survey/hooks/useSingleUseId.test.tsx index 206b5d01d1..e429ac2ca1 100644 --- a/apps/web/modules/survey/hooks/useSingleUseId.test.tsx +++ b/apps/web/modules/survey/hooks/useSingleUseId.test.tsx @@ -8,7 +8,7 @@ import { useSingleUseId } from "./useSingleUseId"; // Mock external functions vi.mock("@/modules/survey/list/actions", () => ({ - generateSingleUseIdAction: vi.fn(), + generateSingleUseIdAction: vi.fn().mockResolvedValue({ data: "initialId" }), })); vi.mock("@/lib/utils/helper", () => ({ @@ -88,8 +88,10 @@ describe("useSingleUseId", () => { vi.mocked(generateSingleUseIdAction).mockResolvedValueOnce({ data: "initialId" }); const { result } = renderHook(() => useSingleUseId(mockSurvey)); - // Wait for initial - await new Promise((r) => setTimeout(r, 0)); + // Wait for initial value to be set + await act(async () => { + await new Promise((r) => setTimeout(r, 0)); + }); expect(result.current.singleUseId).toBe("initialId"); vi.mocked(generateSingleUseIdAction).mockResolvedValueOnce({ data: "refreshedId" }); diff --git a/apps/web/vite.config.mts b/apps/web/vite.config.mts index aa6fe1497d..6fa694f853 100644 --- a/apps/web/vite.config.mts +++ b/apps/web/vite.config.mts @@ -19,7 +19,8 @@ export default defineConfig({ "modules/api/v2/**/*.ts", "modules/api/v2/**/*.tsx", "modules/auth/lib/**/*.ts", - "modules/signup/lib/**/*.ts", + "modules/auth/signup/lib/**/*.ts", + "modules/auth/signup/**/*.tsx", "modules/ee/whitelabel/email-customization/components/*.tsx", "modules/ee/role-management/components/*.tsx", "modules/organization/settings/teams/components/edit-memberships/organization-actions.tsx", @@ -52,6 +53,7 @@ export default defineConfig({ "modules/survey/list/components/survey-card.tsx", "modules/survey/list/components/survey-dropdown-menu.tsx", "modules/ee/contacts/api/v2/management/contacts/bulk/lib/contact.ts", + "modules/ee/sso/components/**/*.tsx", ], exclude: [ "**/.next/**", diff --git a/packages/lib/membership/service.ts b/packages/lib/membership/service.ts index 02a1544f50..2254371a81 100644 --- a/packages/lib/membership/service.ts +++ b/packages/lib/membership/service.ts @@ -54,6 +54,19 @@ export const createMembership = async ( validateInputs([organizationId, ZString], [userId, ZString], [data, ZMembership.partial()]); try { + const existingMembership = await prisma.membership.findUnique({ + where: { + userId_organizationId: { + userId, + organizationId, + }, + }, + }); + + if (existingMembership) { + return existingMembership; + } + const membership = await prisma.membership.create({ data: { userId,