fix embeddings creation, add embeddings retrieval

This commit is contained in:
Matthias Nannt
2024-08-05 20:37:45 +02:00
parent 14e3bb07ec
commit cc9ea82e5c
12 changed files with 25631 additions and 48 deletions

View File

@@ -1,10 +1,13 @@
"use server";
import { getEmailTemplateHtml } from "@/app/(app)/environments/[environmentId]/surveys/[surveyId]/(analysis)/summary/lib/emailTemplate";
import { get } from "lodash";
import { customAlphabet } from "nanoid";
import { getServerSession } from "next-auth";
import { sendEmbedSurveyPreviewEmail } from "@formbricks/email";
import { authOptions } from "@formbricks/lib/authOptions";
import { getEmbeddingsByTypeAndReferenceId } from "@formbricks/lib/embedding/service";
import { getQuestionResponseReferenceId } from "@formbricks/lib/embedding/utils";
import { canUserAccessSurvey } from "@formbricks/lib/survey/auth";
import { getSurvey, updateSurvey } from "@formbricks/lib/survey/service";
import { getUser } from "@formbricks/lib/user/service";
@@ -105,3 +108,20 @@ export const getEmailHtmlAction = async (surveyId: string) => {
return await getEmailTemplateHtml(surveyId);
};
/**
 * Server action that looks up the stored "questionResponse" embeddings for one
 * survey question and logs them.
 *
 * @param surveyId - id of the survey the question belongs to
 * @param questionId - id of the question whose response embeddings are loaded
 * @throws AuthorizationError when there is no session or the session user
 *         cannot access the survey
 */
export const getOpenTextSummaryAction = async (surveyId: string, questionId: string) => {
  const currentSession = await getServerSession(authOptions);
  if (!currentSession) {
    throw new AuthorizationError("Not authorized");
  }

  const isAllowed = await canUserAccessSurvey(currentSession.user.id, surveyId);
  if (!isAllowed) {
    throw new AuthorizationError("Not authorized");
  }

  // Embeddings are keyed by a reference id derived from the survey and question ids.
  const referenceId = getQuestionResponseReferenceId(surveyId, questionId);
  const storedEmbeddings = await getEmbeddingsByTypeAndReferenceId("questionResponse", referenceId);

  // The embeddings are only logged for now; nothing is returned to the caller.
  console.log(storedEmbeddings);
};

View File

@@ -1,3 +1,4 @@
import { getOpenTextSummaryAction } from "@/app/(app)/environments/[environmentId]/surveys/[surveyId]/(analysis)/summary/actions";
import Link from "next/link";
import { useState } from "react";
import { getPersonIdentifier } from "@formbricks/lib/person/utils";
@@ -30,6 +31,11 @@ export const OpenTextSummary = ({
);
};
const getOpenTextSummary = async () => {
// Summary generation is not implemented yet: this only invokes the server action
// for the current survey question and does not use whatever it resolves to.
await getOpenTextSummaryAction(survey.id, questionSummary.question.id);
};
return (
<div className="rounded-xl border border-slate-200 bg-white shadow-sm">
<QuestionSummaryHeader
@@ -37,6 +43,7 @@ export const OpenTextSummary = ({
survey={survey}
attributeClasses={attributeClasses}
/>
<Button onClick={() => getOpenTextSummary()}>Create Summary</Button>
<div className="">
<div className="grid h-10 grid-cols-4 items-center border-y border-slate-200 bg-slate-100 text-sm font-bold text-slate-600">
<div className="pl-4 md:pl-6">User</div>

View File

@@ -6,12 +6,12 @@ import { prisma } from "@formbricks/database";
import { sendResponseFinishedEmail } from "@formbricks/email";
import { embeddingsModel } from "@formbricks/lib/ai";
import { INTERNAL_SECRET, IS_AI_ENABLED, IS_FORMBRICKS_CLOUD } from "@formbricks/lib/constants";
import { createEmbedding } from "@formbricks/lib/embedding/service";
import { getQuestionResponseReferenceId } from "@formbricks/lib/embedding/utils";
import { getIntegrations } from "@formbricks/lib/integration/service";
import { getOrganizationByEnvironmentId } from "@formbricks/lib/organization/service";
import { getProductByEnvironmentId } from "@formbricks/lib/product/service";
import { updateResponseEmbedding } from "@formbricks/lib/response/embedding";
import { getResponseCountBySurveyId } from "@formbricks/lib/response/service";
import { getResponseAsDocumentString } from "@formbricks/lib/response/utils";
import { getSurvey, updateSurvey } from "@formbricks/lib/survey/service";
import { convertDatesInObject } from "@formbricks/lib/time";
import { ZPipelineInput } from "@formbricks/types/pipelines";
@@ -165,20 +165,34 @@ export const POST = async (request: Request) => {
// generate embeddings for all open text question responses for enterprise and scale plans
const hasSurveyOpenTextQuestions = survey.questions.some((question) => question.type === "openText");
console.log("hasSurveyOpenTextQuestions", hasSurveyOpenTextQuestions);
console.log("is Cloud", hasSurveyOpenTextQuestions && IS_FORMBRICKS_CLOUD && IS_AI_ENABLED);
if (hasSurveyOpenTextQuestions && IS_FORMBRICKS_CLOUD && IS_AI_ENABLED) {
const organization = await getOrganizationByEnvironmentId(environmentId);
if (!organization) {
throw new Error("Organization not found");
}
console.log(
"valid billing plan",
organization.billing.plan === "enterprise" || organization.billing.plan === "scale"
);
if (organization.billing.plan === "enterprise" || organization.billing.plan === "scale") {
for (const question of survey.questions) {
if (question.type === "openText") {
const isQuestionAnswered = response[question.id] !== undefined;
const isQuestionAnswered = response.data[question.id] !== undefined;
console.log("isQuestionAnswered", isQuestionAnswered);
if (!isQuestionAnswered) {
continue;
}
const responseEmbedding = await embeddingsModel.embed(response[question.id]);
await updateResponseEmbedding(response.id, question.id, responseEmbedding);
const { embedding } = await embed({
model: embeddingsModel,
value: `${question.headline.default} Answer: ${response.data[question.id]}`,
});
await createEmbedding({
referenceId: getQuestionResponseReferenceId(survey.id, question.id),
type: "questionResponse",
vector: embedding,
});
}
}
}

View File

@@ -16,8 +16,5 @@ CREATE TABLE "Embedding" (
CONSTRAINT "Embedding_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "Embedding_referenceId_key" ON "Embedding"("referenceId");
-- CreateIndex
CREATE INDEX "Embedding_type_referenceId_idx" ON "Embedding"("type", "referenceId");

View File

@@ -663,10 +663,11 @@ enum EmbeddingType {
}
model Embedding {
referenceId String @id
id String @id @default(cuid())
createdAt DateTime @default(now()) @map(name: "created_at")
updatedAt DateTime @updatedAt @map(name: "updated_at")
type EmbeddingType
referenceId String
vector Unsupported("vector(512)")?
@@index([type, referenceId])

View File

@@ -8,5 +8,5 @@ const azure = createAzure({
export const llmModel = azure(env.AI_AZURE_LLM_DEPLOYMENT_ID || "llm");
export const embeddingsModel = azure.embedding(env.AI_AZURE_EMBEDDINGS_DEPLOYMENT_ID || "embeddings", {
dimensions: 1024,
dimensions: 512,
});

View File

@@ -210,5 +210,8 @@ export const BILLING_LIMITS = {
},
} as const;
export const IS_AI_ENABLED =
env.AI_AZURE_RESSOURCE_NAME && env.AI_AZURE_API_KEY && env.AI_AZURE_EMBEDDINGS_DEPLOYMENT_ID;
// Boolean flag: true only when the Azure resource name, the API key and the
// embeddings deployment id are all truthy in `env`.
export const IS_AI_ENABLED = Boolean(
  env.AI_AZURE_RESSOURCE_NAME && env.AI_AZURE_API_KEY && env.AI_AZURE_EMBEDDINGS_DEPLOYMENT_ID
);

View File

@@ -1,18 +1,26 @@
import { revalidateTag } from "next/cache";
// Identifiers that select which embedding cache tags get invalidated.
interface RevalidateProps {
// Id of a single embedding.
id?: string;
// Embedding type; only used in combination with `referenceId`.
type?: string;
// Reference id of the embedded entity; only used in combination with `type`.
referenceId?: string;
}
export const embeddingCache = {
tag: {
byReferenceId(id: string) {
return `environments-${id}`;
byId(id: string) {
return `embeddings-${id}`;
},
byTypeAndReferenceId(type: string, id: string) {
return `embeddings-${type}-${id}`;
},
},
revalidate({ referenceId }: RevalidateProps): void {
if (referenceId) {
revalidateTag(this.tag.byReferenceId(referenceId));
/**
 * Invalidates the cache tags that match the given identifiers.
 *
 * - `id` invalidates the tag of that single embedding.
 * - `type` together with `referenceId` invalidates the tag of that pair.
 *
 * A `referenceId` passed without `type` matches no tag, so nothing is
 * invalidated for it.
 */
revalidate({ id, type, referenceId }: RevalidateProps): void {
  if (id) {
    // The per-embedding tag must be built from `id`; building it from
    // `referenceId` would invalidate the wrong tag (or "embeddings-undefined"
    // when no referenceId was passed).
    revalidateTag(this.tag.byId(id));
  }
  if (type && referenceId) {
    revalidateTag(this.tag.byTypeAndReferenceId(type, referenceId));
  }
},
};

View File

@@ -1,30 +1,46 @@
import "server-only";
import { Prisma } from "@prisma/client";
import { cache as reactCache } from "react";
import { prisma } from "@formbricks/database";
import { TEmbedding, TEmbeddingCreateInput, ZEmbeddingCreateInput } from "@formbricks/types/embedding";
import { ZId } from "@formbricks/types/environment";
import { ZString } from "@formbricks/types/common";
import {
TEmbedding,
TEmbeddingCreateInput,
ZEmbeddingCreateInput,
ZEmbeddingType,
} from "@formbricks/types/embedding";
import { DatabaseError } from "@formbricks/types/errors";
import { cache } from "../cache";
import { validateInputs } from "../utils/validate";
import { embeddingCache } from "./cache";
export const createEmbedding = async (
productId: string,
embeddingInput: TEmbeddingCreateInput
): Promise<TEmbedding> => {
validateInputs([productId, ZId], [embeddingInput, ZEmbeddingCreateInput]);
// Shape of an embedding row whose `vector` is still a string, as produced when
// the vector column is selected as text in a raw SQL query; all other fields
// are taken over from TEmbedding.
export type TPrismaEmbedding = Omit<TEmbedding, "vector"> & {
vector: string;
};
export const createEmbedding = async (embeddingInput: TEmbeddingCreateInput): Promise<TEmbedding> => {
validateInputs([embeddingInput, ZEmbeddingCreateInput]);
try {
const vectorString = embeddingInput.vector.join(",");
const { vector, ...data } = embeddingInput;
const result = await prisma.$executeRaw`
INSERT INTO Embedding (referenceId, created_at, updated_at, type, vector)
VALUES (${embeddingInput.referenceId}, NOW(), NOW(), ${embeddingInput.type}, '${vectorString}')
const prismaEmbedding = await prisma.embedding.create({
data,
});
const embedding = {
...prismaEmbedding,
vector,
};
// update vector
const vectorString = `[${vector.join(",")}]`;
await prisma.$executeRaw`
UPDATE "Embedding"
SET "vector" = ${vectorString}::vector(512)
WHERE "id" = ${embedding.id};
`;
const embedding: TEmbedding = await prisma.$queryRaw`
SELECT * FROM Embedding WHERE referenceId = ${embeddingInput.referenceId}
`;
embeddingCache.revalidate({
referenceId: embedding.referenceId,
});
@@ -37,3 +53,51 @@ SELECT * FROM Embedding WHERE referenceId = ${embeddingInput.referenceId}
throw error;
}
};
/**
 * Loads every embedding stored for the given type and reference id.
 *
 * The vector column is selected as text (e.g. "[0.1,0.2]") and parsed back into
 * an array of numbers. Rows whose vector is still NULL are skipped: the column
 * is nullable and the vector is written in a separate statement after the row
 * is created, so such rows can exist and would otherwise make parsing throw and
 * fail the whole read.
 *
 * The result is cached under a key built from `type` and `referenceId` and
 * tagged with `embeddingCache.tag.byTypeAndReferenceId(type, referenceId)`.
 *
 * @param type - embedding type; validated against ZEmbeddingType
 * @param referenceId - reference id the embeddings were stored under
 * @returns the matching embeddings with `vector` as a number array
 * @throws DatabaseError when Prisma reports a known request error; any other
 *         error is rethrown unchanged
 */
export const getEmbeddingsByTypeAndReferenceId = reactCache(
  (type: string, referenceId: string): Promise<TEmbedding[]> =>
    cache(
      async () => {
        validateInputs([type, ZEmbeddingType], [referenceId, ZString]);

        try {
          // `vector` is nullable in the database, so the raw row type must allow null.
          const prismaEmbeddings: (Omit<TPrismaEmbedding, "vector"> & { vector: string | null })[] =
            await prisma.$queryRaw`
SELECT
id,
created_at AS "createdAt",
updated_at AS "updatedAt",
type,
"referenceId",
vector::text
FROM "Embedding" e
WHERE e."type" = ${type}::"EmbeddingType"
AND e."referenceId" = ${referenceId}
`;

          const embeddings = prismaEmbeddings.flatMap((prismaEmbedding) => {
            const vectorText = prismaEmbedding.vector;
            if (!vectorText) {
              // No vector stored (yet) for this row; there is nothing to parse.
              return [];
            }

            // Convert the string representation of the embedding back to an array of numbers
            const vector = vectorText
              .slice(1, -1) // Remove the surrounding square brackets
              .split(",") // Split the string into an array of strings
              .map(Number); // Convert each string to a number

            return [
              {
                ...prismaEmbedding,
                vector,
              },
            ];
          });

          return embeddings;
        } catch (error) {
          if (error instanceof Prisma.PrismaClientKnownRequestError) {
            console.error(error);
            throw new DatabaseError(error.message);
          }

          throw error;
        }
      },
      [`getEmbeddingsByTypeAndReferenceId-${type}-${referenceId}`],
      {
        tags: [embeddingCache.tag.byTypeAndReferenceId(type, referenceId)],
      }
    )()
);

View File

@@ -0,0 +1,3 @@
/**
 * Builds the embedding reference id for a survey question by joining the
 * survey id and the question id with a hyphen.
 *
 * @param surveyId - id of the survey
 * @param questionId - id of the question within that survey
 * @returns `<surveyId>-<questionId>`
 */
export const getQuestionResponseReferenceId = (surveyId: string, questionId: string) =>
  [surveyId, questionId].join("-");

View File

@@ -1,16 +0,0 @@
import { prisma } from "@formbricks/database";
import { type TEmbedding, ZEmbedding } from "@formbricks/types/embedding";
import { ZId } from "@formbricks/types/environment";
import { validateInputs } from "../utils/validate";
/**
 * Writes an embedding into the "embedding" column of the "Response" row with
 * the given id.
 *
 * @param id - id of the response row to update; validated against ZId
 * @param embedding - embedding values; validated against ZEmbedding
 */
export const updateResponseEmbedding = async (id: string, embedding: TEmbedding) => {
  validateInputs([id, ZId], [embedding, ZEmbedding]);

  // Serialize the values as a bracketed, comma-separated list so the database
  // can cast the bound parameter to a vector.
  const vectorLiteral = "[" + embedding.join(",") + "]";

  await prisma.$executeRaw`
UPDATE "Response"
SET "embedding" = ${vectorLiteral}::vector(1024)
WHERE "id" = ${id};
`;
};

25482
pnpm-lock.yaml generated Normal file

File diff suppressed because it is too large Load Diff