diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go
index afb279dce..17baed364 100644
--- a/core/http/endpoints/openai/chat.go
+++ b/core/http/endpoints/openai/chat.go
@@ -31,7 +31,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 	var id, textContentToReturn string
 	var created int
 
-	process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
+	process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
 		initialMessage := schema.OpenAIResponse{
 			ID:      id,
 			Created: created,
@@ -41,7 +41,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 		}
 		responses <- initialMessage
 
-		ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
+		_, _, err := ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
 			usage := schema.OpenAIUsage{
 				PromptTokens:     tokenUsage.Prompt,
 				CompletionTokens: tokenUsage.Completion,
@@ -65,16 +65,19 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 			return true
 		})
 		close(responses)
+		return err
 	}
-	processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
+	processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
 		result := ""
-		_, tokenUsage, _ := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
+		_, tokenUsage, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
 			result += s
 			// TODO: Change generated BNF grammar to be compliant with the schema so we can
 			// stream the result token by token here.
 			return true
 		})
-
+		if err != nil {
+			return err
+		}
 		textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig)
 		result = functions.CleanupLLMResult(result, config.FunctionsConfig)
 		functionResults := functions.ParseFunctionCall(result, config.FunctionsConfig)
@@ -95,7 +98,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 			result, err := handleQuestion(config, cl, req, ml, startupOptions, functionResults, result, prompt)
 			if err != nil {
 				log.Error().Err(err).Msg("error handling question")
-				return
+				return err
 			}
 			usage := schema.OpenAIUsage{
 				PromptTokens:     tokenUsage.Prompt,
@@ -169,6 +172,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 		}
 
 		close(responses)
+		return err
 	}
 
 	return func(c *fiber.Ctx) error {
@@ -223,9 +227,11 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 			if err != nil {
 				return err
 			}
-			if d.Type == "json_object" {
+
+			switch d.Type {
+			case "json_object":
 				input.Grammar = functions.JSONBNF
-			} else if d.Type == "json_schema" {
+			case "json_schema":
 				d := schema.JsonSchemaRequest{}
 				dat, err := json.Marshal(config.ResponseFormatMap)
 				if err != nil {
@@ -326,31 +332,69 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 			c.Set("X-Correlation-ID", id)
 
 			responses := make(chan schema.OpenAIResponse)
+			ended := make(chan error, 1)
 
-			if !shouldUseFn {
-				go process(predInput, input, config, ml, responses, extraUsage)
-			} else {
-				go processTools(noActionName, predInput, input, config, ml, responses, extraUsage)
-			}
+			go func() {
+				if !shouldUseFn {
+					ended <- process(predInput, input, config, ml, responses, extraUsage)
+				} else {
+					ended <- processTools(noActionName, predInput, input, config, ml, responses, extraUsage)
+				}
+			}()
 
 			c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
 				usage := &schema.OpenAIUsage{}
 				toolsCalled := false
-				for ev := range responses {
-					usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
-					if len(ev.Choices[0].Delta.ToolCalls) > 0 {
-						toolsCalled = true
+
+			LOOP:
+				for {
+					select {
+					case ev := <-responses:
+						if len(ev.Choices) == 0 {
+							log.Debug().Msgf("No choices in the response, skipping")
+							continue
+						}
+						usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
+						if len(ev.Choices[0].Delta.ToolCalls) > 0 {
+							toolsCalled = true
+						}
+						var buf bytes.Buffer
+						enc := json.NewEncoder(&buf)
+						enc.Encode(ev)
+						log.Debug().Msgf("Sending chunk: %s", buf.String())
+						_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
+						if err != nil {
+							log.Debug().Msgf("Sending chunk failed: %v", err)
+							input.Cancel()
+						}
+						w.Flush()
+					case err := <-ended:
+						if err == nil {
+							break LOOP
+						}
+						log.Error().Msgf("Stream ended with error: %v", err)
+
+						resp := &schema.OpenAIResponse{
+							ID:      id,
+							Created: created,
+							Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
+							Choices: []schema.Choice{
+								{
+									FinishReason: "stop",
+									Index:        0,
+									Delta:        &schema.Message{Content: "Internal error: " + err.Error()},
+								}},
+							Object: "chat.completion.chunk",
+							Usage:  *usage,
+						}
+						respData, _ := json.Marshal(resp)
+
+						w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
+						w.WriteString("data: [DONE]\n\n")
+						w.Flush()
+
+						return
 					}
-					var buf bytes.Buffer
-					enc := json.NewEncoder(&buf)
-					enc.Encode(ev)
-					log.Debug().Msgf("Sending chunk: %s", buf.String())
-					_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
-					if err != nil {
-						log.Debug().Msgf("Sending chunk failed: %v", err)
-						input.Cancel()
-					}
-					w.Flush()
 				}
 
 				finishReason := "stop"
@@ -378,7 +422,9 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 				w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
 				w.WriteString("data: [DONE]\n\n")
 				w.Flush()
+				log.Debug().Msgf("Stream ended")
 			}))
+			return nil
 
 			// no streaming mode
diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go
index 7700a4559..3b7fccc85 100644
--- a/core/http/endpoints/openai/completion.go
+++ b/core/http/endpoints/openai/completion.go
@@ -30,7 +30,7 @@ import (
 func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	created := int(time.Now().Unix())
 
-	process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
+	process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
 		tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
 			usage := schema.OpenAIUsage{
 				PromptTokens: tokenUsage.Prompt,
@@ -59,8 +59,9 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
 			responses <- resp
 			return true
 		}
-		ComputeChoices(req, s, config, cl, appConfig, loader, func(s string, c *[]schema.Choice) {}, tokenCallback)
+		_, _, err := ComputeChoices(req, s, config, cl, appConfig, loader, func(s string, c *[]schema.Choice) {}, tokenCallback)
 		close(responses)
+		return err
 	}
 
 	return func(c *fiber.Ctx) error {
@@ -121,18 +122,37 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
 			responses := make(chan schema.OpenAIResponse)
 
-			go process(id, predInput, input, config, ml, responses, extraUsage)
+			ended := make(chan error)
+			go func() {
+				ended <- process(id, predInput, input, config, ml, responses, extraUsage)
+			}()
 
 			c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
-				for ev := range responses {
-					var buf bytes.Buffer
-					enc := json.NewEncoder(&buf)
-					enc.Encode(ev)
+			LOOP:
+				for {
+					select {
+					case ev := <-responses:
+						if len(ev.Choices) == 0 {
+							log.Debug().Msgf("No choices in the response, skipping")
+							continue
+						}
+						var buf bytes.Buffer
+						enc := json.NewEncoder(&buf)
+						enc.Encode(ev)
 
-					log.Debug().Msgf("Sending chunk: %s", buf.String())
-					fmt.Fprintf(w, "data: %v\n", buf.String())
-					w.Flush()
+						log.Debug().Msgf("Sending chunk: %s", buf.String())
+						fmt.Fprintf(w, "data: %v\n", buf.String())
+						w.Flush()
+					case err := <-ended:
+						if err == nil {
+							break LOOP
+						}
+						log.Error().Msgf("Stream ended with error: %v", err)
+						fmt.Fprintf(w, "data: %v\n", "Internal error: "+err.Error())
+						w.Flush()
+						break LOOP
+					}
 				}
 
 				resp := &schema.OpenAIResponse{
 					ID:      id,
 					Created: created,
@@ -153,7 +173,7 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
 				w.WriteString("data: [DONE]\n\n")
 				w.Flush()
 			}))
-			return nil
+			return <-ended
 		}
 
 		var result []schema.Choice
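
Review note: both endpoints now share the same shape. The generation goroutine streams chunks over `responses` and reports the final `ComputeChoices` error exactly once over `ended`, while the fasthttp stream writer drains both channels in a labeled `select` loop and surfaces failures in-band as a last SSE chunk before `data: [DONE]`. One asymmetry worth double-checking: chat.go buffers the error channel (`make(chan error, 1)`) and the handler returns `nil`, whereas completion.go leaves `ended` unbuffered and receives from it in two places (the writer's `case err := <-ended:` and the handler's `return <-ended`) even though only one value is ever sent, so one of the two receivers can block.

For reference, a minimal standalone sketch of the pattern outside LocalAI's types; `produce` and `Chunk` are illustrative names, not code from this patch:

```go
package main

import (
	"errors"
	"fmt"
)

type Chunk struct{ Data string }

// produce streams chunks and reports its final status exactly once; it
// always closes the chunk channel so the consumer cannot wait on it forever.
func produce(chunks chan<- Chunk) error {
	defer close(chunks)
	for i := 0; i < 3; i++ {
		chunks <- Chunk{Data: fmt.Sprintf("token-%d", i)}
	}
	return errors.New("backend failed mid-stream")
}

func main() {
	chunks := make(chan Chunk)
	// Buffered, as in chat.go, so the producer's single send completes even
	// if the consumer has already stopped reading.
	ended := make(chan error, 1)

	go func() { ended <- produce(chunks) }()

	var in <-chan Chunk = chunks
LOOP:
	for {
		select {
		case ev, ok := <-in:
			if !ok {
				// Channel closed: disable this case (a nil channel blocks
				// in select) and wait for the final status on ended.
				in = nil
				continue
			}
			fmt.Printf("data: %s\n\n", ev.Data)
		case err := <-ended:
			if err == nil {
				break LOOP
			}
			// Surface the failure in-band as a final chunk instead of
			// silently truncating the stream.
			fmt.Printf("data: {\"error\": %q}\n\n", err.Error())
			break LOOP
		}
	}
	fmt.Print("data: [DONE]\n\n")
}
```

Setting the drained channel to `nil` disables its `select` case, which avoids busy-looping on a closed channel; the patch's `len(ev.Choices) == 0` guard handles the same closed-channel zero value, but keeps spinning until the `ended` case wins the select.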