add embedding service

This commit is contained in:
Matthias Nannt
2024-08-05 16:47:18 +02:00
parent ab1fe677d9
commit 14e3bb07ec
5 changed files with 75 additions and 22 deletions

View File

@@ -663,11 +663,10 @@ enum EmbeddingType {
}
model Embedding {
id String @id @default(cuid())
referenceId String @id
createdAt DateTime @default(now()) @map(name: "created_at")
updatedAt DateTime @updatedAt @map(name: "updated_at")
type EmbeddingType
referenceId String @unique
vector Unsupported("vector(512)")?
@@index([type, referenceId])

View File

@@ -0,0 +1,18 @@
import { revalidateTag } from "next/cache";
interface RevalidateProps {
referenceId?: string;
}
export const embeddingCache = {
tag: {
byReferenceId(id: string) {
return `environments-${id}`;
},
},
revalidate({ referenceId }: RevalidateProps): void {
if (referenceId) {
revalidateTag(this.tag.byReferenceId(referenceId));
}
},
};

View File

@@ -0,0 +1,39 @@
import "server-only";
import { Prisma } from "@prisma/client";
import { prisma } from "@formbricks/database";
import { TEmbedding, TEmbeddingCreateInput, ZEmbeddingCreateInput } from "@formbricks/types/embedding";
import { ZId } from "@formbricks/types/environment";
import { DatabaseError } from "@formbricks/types/errors";
import { validateInputs } from "../utils/validate";
import { embeddingCache } from "./cache";
export const createEmbedding = async (
productId: string,
embeddingInput: TEmbeddingCreateInput
): Promise<TEmbedding> => {
validateInputs([productId, ZId], [embeddingInput, ZEmbeddingCreateInput]);
try {
const vectorString = embeddingInput.vector.join(",");
const result = await prisma.$executeRaw`
INSERT INTO Embedding (referenceId, created_at, updated_at, type, vector)
VALUES (${embeddingInput.referenceId}, NOW(), NOW(), ${embeddingInput.type}, '${vectorString}')
`;
const embedding: TEmbedding = await prisma.$queryRaw`
SELECT * FROM Embedding WHERE referenceId = ${embeddingInput.referenceId}
`;
embeddingCache.revalidate({
referenceId: embedding.referenceId,
});
return embedding;
} catch (error) {
if (error instanceof Prisma.PrismaClientKnownRequestError) {
throw new DatabaseError(error.message);
}
throw error;
}
};

View File

@@ -1319,22 +1319,3 @@ export const getResponseHiddenFields = (
throw error;
}
};
export const getResponseAsDocumentString = (response: TResponse, survey: TSurvey): string => {
// generate text representation of response
let text = "";
let first = true;
survey.questions.forEach((question) => {
if (first) {
first = false;
} else {
text += "\n\n";
}
const answer = response.data[question.id];
if (answer) {
text += `${getLocalizedValue(question.headline, response.language || "default")}\nAnswer: ${answer}`;
}
});
return text;
};

View File

@@ -1,5 +1,21 @@
import { z } from "zod";
export const ZEmbedding = z.array(z.number()).length(1024);
export const ZEmbeddingType = z.enum(["questionResponse"]);
export const ZEmbedding = z.object({
referenceId: z.string(),
createdAt: z.date(),
updatedAt: z.date(),
type: ZEmbeddingType,
vector: z.array(z.number()).length(512),
});
export type TEmbedding = z.infer<typeof ZEmbedding>;
export const ZEmbeddingCreateInput = z.object({
type: ZEmbeddingType,
referenceId: z.string(),
vector: z.array(z.number()).length(512),
});
export type TEmbeddingCreateInput = z.infer<typeof ZEmbeddingCreateInput>;