mirror of
https://github.com/formbricks/formbricks.git
synced 2025-12-30 10:19:51 -06:00
fix embeddings creation, add embeddings retrieval
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
"use server";
|
||||
|
||||
import { getEmailTemplateHtml } from "@/app/(app)/environments/[environmentId]/surveys/[surveyId]/(analysis)/summary/lib/emailTemplate";
|
||||
import { get } from "lodash";
|
||||
import { customAlphabet } from "nanoid";
|
||||
import { getServerSession } from "next-auth";
|
||||
import { sendEmbedSurveyPreviewEmail } from "@formbricks/email";
|
||||
import { authOptions } from "@formbricks/lib/authOptions";
|
||||
import { getEmbeddingsByTypeAndReferenceId } from "@formbricks/lib/embedding/service";
|
||||
import { getQuestionResponseReferenceId } from "@formbricks/lib/embedding/utils";
|
||||
import { canUserAccessSurvey } from "@formbricks/lib/survey/auth";
|
||||
import { getSurvey, updateSurvey } from "@formbricks/lib/survey/service";
|
||||
import { getUser } from "@formbricks/lib/user/service";
|
||||
@@ -105,3 +108,20 @@ export const getEmailHtmlAction = async (surveyId: string) => {
|
||||
|
||||
return await getEmailTemplateHtml(surveyId);
|
||||
};
|
||||
|
||||
export const getOpenTextSummaryAction = async (surveyId: string, questionId: string) => {
|
||||
const session = await getServerSession(authOptions);
|
||||
if (!session) throw new AuthorizationError("Not authorized");
|
||||
|
||||
const hasUserSurveyAccess = await canUserAccessSurvey(session.user.id, surveyId);
|
||||
if (!hasUserSurveyAccess) throw new AuthorizationError("Not authorized");
|
||||
|
||||
const embeddings = await getEmbeddingsByTypeAndReferenceId(
|
||||
"questionResponse",
|
||||
getQuestionResponseReferenceId(surveyId, questionId)
|
||||
);
|
||||
|
||||
console.log(embeddings);
|
||||
|
||||
return;
|
||||
};
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { getOpenTextSummaryAction } from "@/app/(app)/environments/[environmentId]/surveys/[surveyId]/(analysis)/summary/actions";
|
||||
import Link from "next/link";
|
||||
import { useState } from "react";
|
||||
import { getPersonIdentifier } from "@formbricks/lib/person/utils";
|
||||
@@ -30,6 +31,11 @@ export const OpenTextSummary = ({
|
||||
);
|
||||
};
|
||||
|
||||
const getOpenTextSummary = async () => {
|
||||
// This function is not implemented yet
|
||||
await getOpenTextSummaryAction(survey.id, questionSummary.question.id);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="rounded-xl border border-slate-200 bg-white shadow-sm">
|
||||
<QuestionSummaryHeader
|
||||
@@ -37,6 +43,7 @@ export const OpenTextSummary = ({
|
||||
survey={survey}
|
||||
attributeClasses={attributeClasses}
|
||||
/>
|
||||
<Button onClick={() => getOpenTextSummary()}>Create Summary</Button>
|
||||
<div className="">
|
||||
<div className="grid h-10 grid-cols-4 items-center border-y border-slate-200 bg-slate-100 text-sm font-bold text-slate-600">
|
||||
<div className="pl-4 md:pl-6">User</div>
|
||||
|
||||
@@ -6,12 +6,12 @@ import { prisma } from "@formbricks/database";
|
||||
import { sendResponseFinishedEmail } from "@formbricks/email";
|
||||
import { embeddingsModel } from "@formbricks/lib/ai";
|
||||
import { INTERNAL_SECRET, IS_AI_ENABLED, IS_FORMBRICKS_CLOUD } from "@formbricks/lib/constants";
|
||||
import { createEmbedding } from "@formbricks/lib/embedding/service";
|
||||
import { getQuestionResponseReferenceId } from "@formbricks/lib/embedding/utils";
|
||||
import { getIntegrations } from "@formbricks/lib/integration/service";
|
||||
import { getOrganizationByEnvironmentId } from "@formbricks/lib/organization/service";
|
||||
import { getProductByEnvironmentId } from "@formbricks/lib/product/service";
|
||||
import { updateResponseEmbedding } from "@formbricks/lib/response/embedding";
|
||||
import { getResponseCountBySurveyId } from "@formbricks/lib/response/service";
|
||||
import { getResponseAsDocumentString } from "@formbricks/lib/response/utils";
|
||||
import { getSurvey, updateSurvey } from "@formbricks/lib/survey/service";
|
||||
import { convertDatesInObject } from "@formbricks/lib/time";
|
||||
import { ZPipelineInput } from "@formbricks/types/pipelines";
|
||||
@@ -165,20 +165,34 @@ export const POST = async (request: Request) => {
|
||||
|
||||
// generate embeddings for all open text question responses for enterprise and scale plans
|
||||
const hasSurveyOpenTextQuestions = survey.questions.some((question) => question.type === "openText");
|
||||
console.log("hasSurveyOpenTextQuestions", hasSurveyOpenTextQuestions);
|
||||
console.log("is Cloud", hasSurveyOpenTextQuestions && IS_FORMBRICKS_CLOUD && IS_AI_ENABLED);
|
||||
if (hasSurveyOpenTextQuestions && IS_FORMBRICKS_CLOUD && IS_AI_ENABLED) {
|
||||
const organization = await getOrganizationByEnvironmentId(environmentId);
|
||||
if (!organization) {
|
||||
throw new Error("Organization not found");
|
||||
}
|
||||
console.log(
|
||||
"valid billing plan",
|
||||
organization.billing.plan === "enterprise" || organization.billing.plan === "scale"
|
||||
);
|
||||
if (organization.billing.plan === "enterprise" || organization.billing.plan === "scale") {
|
||||
for (const question of survey.questions) {
|
||||
if (question.type === "openText") {
|
||||
const isQuestionAnswered = response[question.id] !== undefined;
|
||||
const isQuestionAnswered = response.data[question.id] !== undefined;
|
||||
console.log("isQuestionAnswered", isQuestionAnswered);
|
||||
if (!isQuestionAnswered) {
|
||||
continue;
|
||||
}
|
||||
const responseEmbedding = await embeddingsModel.embed(response[question.id]);
|
||||
await updateResponseEmbedding(response.id, question.id, responseEmbedding);
|
||||
const { embedding } = await embed({
|
||||
model: embeddingsModel,
|
||||
value: `${question.headline.default} Answer: ${response.data[question.id]}`,
|
||||
});
|
||||
await createEmbedding({
|
||||
referenceId: getQuestionResponseReferenceId(survey.id, question.id),
|
||||
type: "questionResponse",
|
||||
vector: embedding,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,8 +16,5 @@ CREATE TABLE "Embedding" (
|
||||
CONSTRAINT "Embedding_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "Embedding_referenceId_key" ON "Embedding"("referenceId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "Embedding_type_referenceId_idx" ON "Embedding"("type", "referenceId");
|
||||
@@ -663,10 +663,11 @@ enum EmbeddingType {
|
||||
}
|
||||
|
||||
model Embedding {
|
||||
referenceId String @id
|
||||
id String @id @default(cuid())
|
||||
createdAt DateTime @default(now()) @map(name: "created_at")
|
||||
updatedAt DateTime @updatedAt @map(name: "updated_at")
|
||||
type EmbeddingType
|
||||
referenceId String
|
||||
vector Unsupported("vector(512)")?
|
||||
|
||||
@@index([type, referenceId])
|
||||
|
||||
@@ -8,5 +8,5 @@ const azure = createAzure({
|
||||
|
||||
export const llmModel = azure(env.AI_AZURE_LLM_DEPLOYMENT_ID || "llm");
|
||||
export const embeddingsModel = azure.embedding(env.AI_AZURE_EMBEDDINGS_DEPLOYMENT_ID || "embeddings", {
|
||||
dimensions: 1024,
|
||||
dimensions: 512,
|
||||
});
|
||||
|
||||
@@ -210,5 +210,8 @@ export const BILLING_LIMITS = {
|
||||
},
|
||||
} as const;
|
||||
|
||||
export const IS_AI_ENABLED =
|
||||
env.AI_AZURE_RESSOURCE_NAME && env.AI_AZURE_API_KEY && env.AI_AZURE_EMBEDDINGS_DEPLOYMENT_ID;
|
||||
export const IS_AI_ENABLED = !!(
|
||||
env.AI_AZURE_RESSOURCE_NAME &&
|
||||
env.AI_AZURE_API_KEY &&
|
||||
env.AI_AZURE_EMBEDDINGS_DEPLOYMENT_ID
|
||||
);
|
||||
|
||||
@@ -1,18 +1,26 @@
|
||||
import { revalidateTag } from "next/cache";
|
||||
|
||||
interface RevalidateProps {
|
||||
id?: string;
|
||||
type?: string;
|
||||
referenceId?: string;
|
||||
}
|
||||
|
||||
export const embeddingCache = {
|
||||
tag: {
|
||||
byReferenceId(id: string) {
|
||||
return `environments-${id}`;
|
||||
byId(id: string) {
|
||||
return `embeddings-${id}`;
|
||||
},
|
||||
byTypeAndReferenceId(type: string, id: string) {
|
||||
return `embeddings-${type}-${id}`;
|
||||
},
|
||||
},
|
||||
revalidate({ referenceId }: RevalidateProps): void {
|
||||
if (referenceId) {
|
||||
revalidateTag(this.tag.byReferenceId(referenceId));
|
||||
revalidate({ id, type, referenceId }: RevalidateProps): void {
|
||||
if (id) {
|
||||
revalidateTag(this.tag.byId(referenceId));
|
||||
}
|
||||
if (type && referenceId) {
|
||||
revalidateTag(this.tag.byTypeAndReferenceId(type, referenceId));
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
@@ -1,30 +1,46 @@
|
||||
import "server-only";
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { cache as reactCache } from "react";
|
||||
import { prisma } from "@formbricks/database";
|
||||
import { TEmbedding, TEmbeddingCreateInput, ZEmbeddingCreateInput } from "@formbricks/types/embedding";
|
||||
import { ZId } from "@formbricks/types/environment";
|
||||
import { ZString } from "@formbricks/types/common";
|
||||
import {
|
||||
TEmbedding,
|
||||
TEmbeddingCreateInput,
|
||||
ZEmbeddingCreateInput,
|
||||
ZEmbeddingType,
|
||||
} from "@formbricks/types/embedding";
|
||||
import { DatabaseError } from "@formbricks/types/errors";
|
||||
import { cache } from "../cache";
|
||||
import { validateInputs } from "../utils/validate";
|
||||
import { embeddingCache } from "./cache";
|
||||
|
||||
export const createEmbedding = async (
|
||||
productId: string,
|
||||
embeddingInput: TEmbeddingCreateInput
|
||||
): Promise<TEmbedding> => {
|
||||
validateInputs([productId, ZId], [embeddingInput, ZEmbeddingCreateInput]);
|
||||
export type TPrismaEmbedding = Omit<TEmbedding, "vector"> & {
|
||||
vector: string;
|
||||
};
|
||||
|
||||
export const createEmbedding = async (embeddingInput: TEmbeddingCreateInput): Promise<TEmbedding> => {
|
||||
validateInputs([embeddingInput, ZEmbeddingCreateInput]);
|
||||
|
||||
try {
|
||||
const vectorString = embeddingInput.vector.join(",");
|
||||
const { vector, ...data } = embeddingInput;
|
||||
|
||||
const result = await prisma.$executeRaw`
|
||||
INSERT INTO Embedding (referenceId, created_at, updated_at, type, vector)
|
||||
VALUES (${embeddingInput.referenceId}, NOW(), NOW(), ${embeddingInput.type}, '${vectorString}')
|
||||
const prismaEmbedding = await prisma.embedding.create({
|
||||
data,
|
||||
});
|
||||
|
||||
const embedding = {
|
||||
...prismaEmbedding,
|
||||
vector,
|
||||
};
|
||||
|
||||
// update vector
|
||||
const vectorString = `[${vector.join(",")}]`;
|
||||
await prisma.$executeRaw`
|
||||
UPDATE "Embedding"
|
||||
SET "vector" = ${vectorString}::vector(512)
|
||||
WHERE "id" = ${embedding.id};
|
||||
`;
|
||||
|
||||
const embedding: TEmbedding = await prisma.$queryRaw`
|
||||
SELECT * FROM Embedding WHERE referenceId = ${embeddingInput.referenceId}
|
||||
`;
|
||||
|
||||
embeddingCache.revalidate({
|
||||
referenceId: embedding.referenceId,
|
||||
});
|
||||
@@ -37,3 +53,51 @@ SELECT * FROM Embedding WHERE referenceId = ${embeddingInput.referenceId}
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
export const getEmbeddingsByTypeAndReferenceId = reactCache(
|
||||
(type: string, referenceId: string): Promise<TEmbedding[]> =>
|
||||
cache(
|
||||
async () => {
|
||||
validateInputs([type, ZEmbeddingType], [referenceId, ZString]);
|
||||
|
||||
try {
|
||||
const prismaEmbeddings: TPrismaEmbedding[] = await prisma.$queryRaw`
|
||||
SELECT
|
||||
id,
|
||||
created_at AS "createdAt",
|
||||
updated_at AS "updatedAt",
|
||||
type,
|
||||
"referenceId",
|
||||
vector::text
|
||||
FROM "Embedding" e
|
||||
WHERE e."type" = ${type}::"EmbeddingType"
|
||||
AND e."referenceId" = ${referenceId}
|
||||
`;
|
||||
|
||||
const embeddings = prismaEmbeddings.map((prismaEmbedding) => {
|
||||
// Convert the string representation of the embedding back to an array of numbers
|
||||
const vector = prismaEmbedding.vector
|
||||
.slice(1, -1) // Remove the surrounding square brackets
|
||||
.split(",") // Split the string into an array of strings
|
||||
.map(Number); // Convert each string to a number
|
||||
return {
|
||||
...prismaEmbedding,
|
||||
vector,
|
||||
};
|
||||
});
|
||||
|
||||
return embeddings;
|
||||
} catch (error) {
|
||||
if (error instanceof Prisma.PrismaClientKnownRequestError) {
|
||||
console.error(error);
|
||||
throw new DatabaseError(error.message);
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[`getEmbeddingsByTypeAndReferenceId-${type}-${referenceId}`],
|
||||
{
|
||||
tags: [embeddingCache.tag.byTypeAndReferenceId(type, referenceId)],
|
||||
}
|
||||
)()
|
||||
);
|
||||
|
||||
3
packages/lib/embedding/utils.ts
Normal file
3
packages/lib/embedding/utils.ts
Normal file
@@ -0,0 +1,3 @@
|
||||
export const getQuestionResponseReferenceId = (surveyId: string, questionId: string) => {
|
||||
return `${surveyId}-${questionId}`;
|
||||
};
|
||||
@@ -1,16 +0,0 @@
|
||||
import { prisma } from "@formbricks/database";
|
||||
import { type TEmbedding, ZEmbedding } from "@formbricks/types/embedding";
|
||||
import { ZId } from "@formbricks/types/environment";
|
||||
import { validateInputs } from "../utils/validate";
|
||||
|
||||
export const updateResponseEmbedding = async (id: string, embedding: TEmbedding) => {
|
||||
validateInputs([id, ZId], [embedding, ZEmbedding]);
|
||||
// Convert the embedding array to a string representation that PostgreSQL understands
|
||||
const embeddingString = `[${embedding.join(",")}]`;
|
||||
|
||||
await prisma.$executeRaw`
|
||||
UPDATE "Response"
|
||||
SET "embedding" = ${embeddingString}::vector(1024)
|
||||
WHERE "id" = ${id};
|
||||
`;
|
||||
};
|
||||
25482
pnpm-lock.yaml
generated
Normal file
25482
pnpm-lock.yaml
generated
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user