mirror of
https://github.com/mudler/LocalAI.git
synced 2026-02-11 20:49:15 -06:00
feat(api): Add transcribe response format request parameter & adjust STT backends (#8318)
* WIP response format implementation for audio transcriptions
  (cherry picked from commit e271dd764bbc13846accf3beb8b6522153aa276f)
  Signed-off-by: Andres Smith <andressmithdev@pm.me>
* Rework transcript response_format and add more formats
  (cherry picked from commit 6a93a8f63e2ee5726bca2980b0c9cf4ef8b7aeb8)
  Signed-off-by: Andres Smith <andressmithdev@pm.me>
* Add test and replace go-openai package with official openai go client
  (cherry picked from commit f25d1a04e46526429c89db4c739e1e65942ca893)
  Signed-off-by: Andres Smith <andressmithdev@pm.me>
* Fix faster-whisper backend and refactor transcription formatting to also work on CLI
  Signed-off-by: Andres Smith <andressmithdev@pm.me>
  (cherry picked from commit 69a93977d5e113eb7172bd85a0f918592d3d2168)
  Signed-off-by: Andres Smith <andressmithdev@pm.me>
---------
Signed-off-by: Andres Smith <andressmithdev@pm.me>
Co-authored-by: nanoandrew4 <nanoandrew4@gmail.com>
Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
@@ -11,13 +11,14 @@ import (
|
||||
"github.com/docker/go-connections/nat"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/option"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
)
|
||||
|
||||
var container testcontainers.Container
|
||||
var client *openai.Client
|
||||
var client openai.Client
|
||||
|
||||
var containerImage = os.Getenv("LOCALAI_IMAGE")
|
||||
var containerImageTag = os.Getenv("LOCALAI_IMAGE_TAG")
|
||||
@@ -37,26 +38,22 @@ func TestLocalAI(t *testing.T) {
|
||||
|
||||
var _ = BeforeSuite(func() {
|
||||
|
||||
var defaultConfig openai.ClientConfig
|
||||
if apiEndpoint == "" {
|
||||
startDockerImage()
|
||||
apiPort, err := container.MappedPort(context.Background(), nat.Port(defaultApiPort))
|
||||
apiPort, err := container.MappedPort(context.Background(), defaultApiPort)
|
||||
Expect(err).To(Not(HaveOccurred()))
|
||||
|
||||
defaultConfig = openai.DefaultConfig(apiKey)
|
||||
apiEndpoint = "http://localhost:" + apiPort.Port() + "/v1" // So that other tests can reference this value safely.
|
||||
defaultConfig.BaseURL = apiEndpoint
|
||||
} else {
|
||||
GinkgoWriter.Printf("docker apiEndpoint set from env: %q\n", apiEndpoint)
|
||||
defaultConfig = openai.DefaultConfig(apiKey)
|
||||
defaultConfig.BaseURL = apiEndpoint
|
||||
}
|
||||
opts := []option.RequestOption{option.WithAPIKey(apiKey), option.WithBaseURL(apiEndpoint)}
|
||||
|
||||
// Wait for API to be ready
|
||||
client = openai.NewClientWithConfig(defaultConfig)
|
||||
client = openai.NewClient(opts...)
|
||||
|
||||
Eventually(func() error {
|
||||
_, err := client.ListModels(context.TODO())
|
||||
_, err := client.Models.List(context.TODO())
|
||||
return err
|
||||
}, "50m").ShouldNot(HaveOccurred())
|
||||
})
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/option"
|
||||
)
|
||||
|
||||
var _ = Describe("E2E test", func() {
|
||||
@@ -30,14 +30,13 @@ var _ = Describe("E2E test", func() {
|
||||
Context("text", func() {
|
||||
It("correctly", func() {
|
||||
model := "gpt-4"
|
||||
resp, err := client.CreateChatCompletion(context.TODO(),
|
||||
openai.ChatCompletionRequest{
|
||||
Model: model, Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "How much is 2+2?",
|
||||
},
|
||||
}})
|
||||
resp, err := client.Chat.Completions.New(context.TODO(),
|
||||
openai.ChatCompletionNewParams{
|
||||
Model: model,
|
||||
Messages: []openai.ChatCompletionMessageParamUnion{
|
||||
openai.UserMessage("How much is 2+2?"),
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp))
|
||||
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("4"), ContainSubstring("four")), fmt.Sprint(resp.Choices[0].Message.Content))
|
||||
@@ -46,39 +45,36 @@ var _ = Describe("E2E test", func() {
|
||||
|
||||
Context("function calls", func() {
|
||||
It("correctly invoke", func() {
|
||||
params := jsonschema.Definition{
|
||||
Type: jsonschema.Object,
|
||||
Properties: map[string]jsonschema.Definition{
|
||||
"location": {
|
||||
Type: jsonschema.String,
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
params := openai.FunctionParameters{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"location": map[string]string{
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
Type: jsonschema.String,
|
||||
Enum: []string{"celsius", "fahrenheit"},
|
||||
"unit": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
Required: []string{"location"},
|
||||
"required": []string{"location"},
|
||||
}
|
||||
|
||||
f := openai.FunctionDefinition{
|
||||
Name: "get_current_weather",
|
||||
Description: "Get the current weather in a given location",
|
||||
Parameters: params,
|
||||
}
|
||||
t := openai.Tool{
|
||||
Type: openai.ToolTypeFunction,
|
||||
Function: &f,
|
||||
tool := openai.ChatCompletionToolUnionParam{
|
||||
OfFunction: &openai.ChatCompletionFunctionToolParam{
|
||||
Function: openai.FunctionDefinitionParam{
|
||||
Name: "get_current_weather",
|
||||
Description: openai.String("Get the current weather in a given location"),
|
||||
Parameters: params,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dialogue := []openai.ChatCompletionMessage{
|
||||
{Role: openai.ChatMessageRoleUser, Content: "What is the weather in Boston today?"},
|
||||
}
|
||||
resp, err := client.CreateChatCompletion(context.TODO(),
|
||||
openai.ChatCompletionRequest{
|
||||
Model: openai.GPT4,
|
||||
Messages: dialogue,
|
||||
Tools: []openai.Tool{t},
|
||||
resp, err := client.Chat.Completions.New(context.TODO(),
|
||||
openai.ChatCompletionNewParams{
|
||||
Model: openai.ChatModelGPT4,
|
||||
Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage("What is the weather in Boston today?")},
|
||||
Tools: []openai.ChatCompletionToolUnionParam{tool},
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -90,23 +86,21 @@ var _ = Describe("E2E test", func() {
|
||||
Expect(msg.ToolCalls[0].Function.Arguments).To(ContainSubstring("Boston"), fmt.Sprint(msg.ToolCalls[0].Function.Arguments))
|
||||
})
|
||||
})
|
||||
|
||||
Context("json", func() {
|
||||
It("correctly", func() {
|
||||
model := "gpt-4"
|
||||
|
||||
req := openai.ChatCompletionRequest{
|
||||
ResponseFormat: &openai.ChatCompletionResponseFormat{Type: openai.ChatCompletionResponseFormatTypeJSONObject},
|
||||
Model: model,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
|
||||
Role: "user",
|
||||
Content: "Generate a JSON object of an animal with 'name', 'gender' and 'legs' fields",
|
||||
resp, err := client.Chat.Completions.New(context.TODO(),
|
||||
openai.ChatCompletionNewParams{
|
||||
Model: model,
|
||||
Messages: []openai.ChatCompletionMessageParamUnion{
|
||||
openai.UserMessage("Generate a JSON object of an animal with 'name', 'gender' and 'legs' fields"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.CreateChatCompletion(context.TODO(), req)
|
||||
ResponseFormat: openai.ChatCompletionNewParamsResponseFormatUnion{
|
||||
OfJSONObject: &openai.ResponseFormatJSONObjectParam{},
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp))
|
||||
|
||||
@@ -121,23 +115,23 @@ var _ = Describe("E2E test", func() {
|
||||
|
||||
Context("images", func() {
|
||||
It("correctly", func() {
|
||||
req := openai.ImageRequest{
|
||||
Prompt: "test",
|
||||
Quality: "1",
|
||||
Size: openai.CreateImageSize256x256,
|
||||
}
|
||||
resp, err := client.CreateImage(context.TODO(), req)
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("error sending image request %+v", req))
|
||||
resp, err := client.Images.Generate(context.TODO(),
|
||||
openai.ImageGenerateParams{
|
||||
Prompt: "test",
|
||||
Size: openai.ImageGenerateParamsSize256x256,
|
||||
Quality: openai.ImageGenerateParamsQualityLow,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("error sending image request"))
|
||||
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
|
||||
Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL))
|
||||
})
|
||||
It("correctly changes the response format to url", func() {
|
||||
resp, err := client.CreateImage(context.TODO(),
|
||||
openai.ImageRequest{
|
||||
resp, err := client.Images.Generate(context.TODO(),
|
||||
openai.ImageGenerateParams{
|
||||
Prompt: "test",
|
||||
Size: openai.CreateImageSize256x256,
|
||||
Quality: "1",
|
||||
ResponseFormat: openai.CreateImageResponseFormatURL,
|
||||
Size: openai.ImageGenerateParamsSize256x256,
|
||||
ResponseFormat: openai.ImageGenerateParamsResponseFormatURL,
|
||||
Quality: openai.ImageGenerateParamsQualityLow,
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -145,12 +139,11 @@ var _ = Describe("E2E test", func() {
|
||||
Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL))
|
||||
})
|
||||
It("correctly changes the response format to base64", func() {
|
||||
resp, err := client.CreateImage(context.TODO(),
|
||||
openai.ImageRequest{
|
||||
resp, err := client.Images.Generate(context.TODO(),
|
||||
openai.ImageGenerateParams{
|
||||
Prompt: "test",
|
||||
Size: openai.CreateImageSize256x256,
|
||||
Quality: "1",
|
||||
ResponseFormat: openai.CreateImageResponseFormatB64JSON,
|
||||
Size: openai.ImageGenerateParamsSize256x256,
|
||||
ResponseFormat: openai.ImageGenerateParamsResponseFormatB64JSON,
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -158,22 +151,27 @@ var _ = Describe("E2E test", func() {
|
||||
Expect(resp.Data[0].B64JSON).ToNot(BeEmpty(), fmt.Sprint(resp.Data[0].B64JSON))
|
||||
})
|
||||
})
|
||||
|
||||
Context("embeddings", func() {
|
||||
It("correctly", func() {
|
||||
resp, err := client.CreateEmbeddings(context.TODO(),
|
||||
openai.EmbeddingRequestStrings{
|
||||
Input: []string{"doc"},
|
||||
Model: openai.AdaEmbeddingV2,
|
||||
resp, err := client.Embeddings.New(context.TODO(),
|
||||
openai.EmbeddingNewParams{
|
||||
Input: openai.EmbeddingNewParamsInputUnion{
|
||||
OfArrayOfStrings: []string{"doc"},
|
||||
},
|
||||
Model: openai.EmbeddingModelTextEmbeddingAda002,
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
|
||||
Expect(resp.Data[0].Embedding).ToNot(BeEmpty())
|
||||
|
||||
resp2, err := client.CreateEmbeddings(context.TODO(),
|
||||
openai.EmbeddingRequestStrings{
|
||||
Input: []string{"cat"},
|
||||
Model: openai.AdaEmbeddingV2,
|
||||
resp2, err := client.Embeddings.New(context.TODO(),
|
||||
openai.EmbeddingNewParams{
|
||||
Input: openai.EmbeddingNewParamsInputUnion{
|
||||
OfArrayOfStrings: []string{"cat"},
|
||||
},
|
||||
Model: openai.EmbeddingModelTextEmbeddingAda002,
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -181,10 +179,12 @@ var _ = Describe("E2E test", func() {
|
||||
Expect(resp2.Data[0].Embedding).ToNot(BeEmpty())
|
||||
Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[0].Embedding))
|
||||
|
||||
resp3, err := client.CreateEmbeddings(context.TODO(),
|
||||
openai.EmbeddingRequestStrings{
|
||||
Input: []string{"doc", "cat"},
|
||||
Model: openai.AdaEmbeddingV2,
|
||||
resp3, err := client.Embeddings.New(context.TODO(),
|
||||
openai.EmbeddingNewParams{
|
||||
Input: openai.EmbeddingNewParamsInputUnion{
|
||||
OfArrayOfStrings: []string{"doc", "cat"},
|
||||
},
|
||||
Model: openai.EmbeddingModelTextEmbeddingAda002,
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -195,66 +195,101 @@ var _ = Describe("E2E test", func() {
|
||||
Expect(resp3.Data[0].Embedding).ToNot(Equal(resp3.Data[1].Embedding))
|
||||
})
|
||||
})
|
||||
|
||||
Context("vision", func() {
|
||||
It("correctly", func() {
|
||||
model := "gpt-4o"
|
||||
resp, err := client.CreateChatCompletion(context.TODO(),
|
||||
openai.ChatCompletionRequest{
|
||||
Model: model, Messages: []openai.ChatCompletionMessage{
|
||||
resp, err := client.Chat.Completions.New(context.TODO(),
|
||||
openai.ChatCompletionNewParams{
|
||||
Model: model,
|
||||
Messages: []openai.ChatCompletionMessageParamUnion{
|
||||
{
|
||||
|
||||
Role: "user",
|
||||
MultiContent: []openai.ChatMessagePart{
|
||||
{
|
||||
Type: openai.ChatMessagePartTypeText,
|
||||
Text: "What is in the image?",
|
||||
},
|
||||
{
|
||||
Type: openai.ChatMessagePartTypeImageURL,
|
||||
ImageURL: &openai.ChatMessageImageURL{
|
||||
URL: "https://picsum.photos/id/22/4434/3729",
|
||||
Detail: openai.ImageURLDetailLow,
|
||||
OfUser: &openai.ChatCompletionUserMessageParam{
|
||||
Role: "user",
|
||||
Content: openai.ChatCompletionUserMessageParamContentUnion{
|
||||
OfArrayOfContentParts: []openai.ChatCompletionContentPartUnionParam{
|
||||
{
|
||||
OfText: &openai.ChatCompletionContentPartTextParam{
|
||||
Type: "text",
|
||||
Text: "What is in the image?",
|
||||
},
|
||||
},
|
||||
{
|
||||
OfImageURL: &openai.ChatCompletionContentPartImageParam{
|
||||
ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
|
||||
URL: "https://picsum.photos/id/22/4434/3729",
|
||||
Detail: "low",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}})
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp))
|
||||
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("man"), ContainSubstring("road")), fmt.Sprint(resp.Choices[0].Message.Content))
|
||||
})
|
||||
})
|
||||
|
||||
Context("text to audio", func() {
|
||||
It("correctly", func() {
|
||||
res, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{
|
||||
Model: openai.TTSModel1,
|
||||
res, err := client.Audio.Speech.New(context.Background(), openai.AudioSpeechNewParams{
|
||||
Model: openai.SpeechModelTTS1,
|
||||
Input: "Hello!",
|
||||
Voice: openai.VoiceAlloy,
|
||||
Voice: openai.AudioSpeechNewParamsVoiceAlloy,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer res.Close()
|
||||
defer res.Body.Close()
|
||||
|
||||
_, err = io.ReadAll(res)
|
||||
_, err = io.ReadAll(res.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
})
|
||||
})
|
||||
|
||||
Context("audio to text", func() {
|
||||
It("correctly", func() {
|
||||
|
||||
downloadURL := "https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav"
|
||||
file, err := downloadHttpFile(downloadURL)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req := openai.AudioRequest{
|
||||
Model: openai.Whisper1,
|
||||
FilePath: file,
|
||||
}
|
||||
resp, err := client.CreateTranscription(context.Background(), req)
|
||||
fileHandle, err := os.Open(file)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer fileHandle.Close()
|
||||
|
||||
transcriptionResp, err := client.Audio.Transcriptions.New(context.Background(), openai.AudioTranscriptionNewParams{
|
||||
Model: openai.AudioModelWhisper1,
|
||||
File: fileHandle,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
resp := transcriptionResp.AsTranscription()
|
||||
Expect(resp.Text).To(ContainSubstring("This is the"), fmt.Sprint(resp.Text))
|
||||
})
|
||||
|
||||
It("with VTT format", func() {
|
||||
downloadURL := "https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav"
|
||||
file, err := downloadHttpFile(downloadURL)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
fileHandle, err := os.Open(file)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer fileHandle.Close()
|
||||
|
||||
var resp string
|
||||
_, err = client.Audio.Transcriptions.New(context.Background(), openai.AudioTranscriptionNewParams{
|
||||
Model: openai.AudioModelWhisper1,
|
||||
File: fileHandle,
|
||||
ResponseFormat: openai.AudioResponseFormatVTT,
|
||||
}, option.WithResponseBodyInto(&resp))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp).To(ContainSubstring("This is the"), resp)
|
||||
Expect(resp).To(ContainSubstring("WEBVTT"), resp)
|
||||
Expect(resp).To(ContainSubstring("00:00:00.000 -->"), resp)
|
||||
})
|
||||
})
|
||||
|
||||
Context("vad", func() {
|
||||
It("correctly", func() {
|
||||
modelName := "silero-vad"
|
||||
@@ -283,6 +318,7 @@ var _ = Describe("E2E test", func() {
|
||||
Expect(deserializedResponse.Segments).ToNot(BeZero())
|
||||
})
|
||||
})
|
||||
|
||||
Context("reranker", func() {
|
||||
It("correctly", func() {
|
||||
modelName := "jina-reranker-v1-base-en"
|
||||
@@ -317,7 +353,6 @@ var _ = Describe("E2E test", func() {
|
||||
Expect(err).To(BeNil())
|
||||
Expect(deserializedResponse).ToNot(BeZero())
|
||||
Expect(deserializedResponse.Model).To(Equal(modelName))
|
||||
//Expect(len(deserializedResponse.Results)).To(BeNumerically(">", 0))
|
||||
Expect(len(deserializedResponse.Results)).To(Equal(expectResults))
|
||||
// Assert that relevance scores are in decreasing order
|
||||
for i := 1; i < len(deserializedResponse.Results); i++ {
|
||||
|
||||
@@ -17,14 +17,14 @@ import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/phayes/freeport"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/option"
|
||||
)
|
||||
|
||||
var (
|
||||
localAIURL string
|
||||
anthropicBaseURL string
|
||||
tmpDir string
|
||||
backendPath string
|
||||
@@ -33,7 +33,7 @@ var (
|
||||
app *echo.Echo
|
||||
appCtx context.Context
|
||||
appCancel context.CancelFunc
|
||||
client *openai.Client
|
||||
client openai.Client
|
||||
apiPort int
|
||||
apiURL string
|
||||
mockBackendPath string
|
||||
@@ -129,7 +129,6 @@ var _ = BeforeSuite(func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
apiPort = port
|
||||
apiURL = fmt.Sprintf("http://127.0.0.1:%d/v1", apiPort)
|
||||
localAIURL = apiURL
|
||||
// Anthropic SDK appends /v1/messages to base URL; use base without /v1 so requests go to /v1/messages
|
||||
anthropicBaseURL = fmt.Sprintf("http://127.0.0.1:%d", apiPort)
|
||||
|
||||
@@ -141,12 +140,10 @@ var _ = BeforeSuite(func() {
|
||||
}()
|
||||
|
||||
// Wait for server to be ready
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
defaultConfig.BaseURL = apiURL
|
||||
client = openai.NewClientWithConfig(defaultConfig)
|
||||
client = openai.NewClient(option.WithBaseURL(apiURL))
|
||||
|
||||
Eventually(func() error {
|
||||
_, err := client.ListModels(context.TODO())
|
||||
_, err := client.Models.List(context.TODO())
|
||||
return err
|
||||
}, "2m").ShouldNot(HaveOccurred())
|
||||
})
|
||||
|
||||
@@ -9,22 +9,19 @@ import (
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/openai/openai-go/v3"
|
||||
)
|
||||
|
||||
var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() {
|
||||
Describe("Text Generation APIs", func() {
|
||||
Context("Predict (Chat Completions)", func() {
|
||||
It("should return mocked response", func() {
|
||||
resp, err := client.CreateChatCompletion(
|
||||
resp, err := client.Chat.Completions.New(
|
||||
context.TODO(),
|
||||
openai.ChatCompletionRequest{
|
||||
openai.ChatCompletionNewParams{
|
||||
Model: "mock-model",
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Hello",
|
||||
},
|
||||
Messages: []openai.ChatCompletionMessageParamUnion{
|
||||
openai.UserMessage("Hello"),
|
||||
},
|
||||
},
|
||||
)
|
||||
@@ -36,31 +33,23 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() {
|
||||
|
||||
Context("PredictStream (Streaming Chat Completions)", func() {
|
||||
It("should stream mocked tokens", func() {
|
||||
stream, err := client.CreateChatCompletionStream(
|
||||
stream := client.Chat.Completions.NewStreaming(
|
||||
context.TODO(),
|
||||
openai.ChatCompletionRequest{
|
||||
openai.ChatCompletionNewParams{
|
||||
Model: "mock-model",
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Hello",
|
||||
},
|
||||
Messages: []openai.ChatCompletionMessageParamUnion{
|
||||
openai.UserMessage("Hello"),
|
||||
},
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer stream.Close()
|
||||
|
||||
hasContent := false
|
||||
for {
|
||||
response, err := stream.Recv()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
for stream.Next() {
|
||||
response := stream.Current()
|
||||
if len(response.Choices) > 0 && response.Choices[0].Delta.Content != "" {
|
||||
hasContent = true
|
||||
}
|
||||
}
|
||||
Expect(stream.Err()).ToNot(HaveOccurred())
|
||||
Expect(hasContent).To(BeTrue())
|
||||
})
|
||||
})
|
||||
@@ -68,11 +57,13 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() {
|
||||
|
||||
Describe("Embeddings API", func() {
|
||||
It("should return mocked embeddings", func() {
|
||||
resp, err := client.CreateEmbeddings(
|
||||
resp, err := client.Embeddings.New(
|
||||
context.TODO(),
|
||||
openai.EmbeddingRequest{
|
||||
openai.EmbeddingNewParams{
|
||||
Model: "mock-model",
|
||||
Input: []string{"test"},
|
||||
Input: openai.EmbeddingNewParamsInputUnion{
|
||||
OfArrayOfStrings: []string{"test"},
|
||||
},
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Reference in New Issue
Block a user