mirror of
https://github.com/mudler/LocalAI.git
synced 2026-01-04 09:40:32 -06:00
feat(api): OpenAI video create endpoint integration (#6777)
* feat: add OpenAI-compatible /v1/videos endpoint - Add VideoEndpoint handler with OpenAI request mapping - Add MapOpenAIToVideo function to convert OpenAI format to LocalAI VideoRequest - Add Swagger documentation for API endpoint - Add Ginkgo unit tests for mapping logic - Add Ginkgo integration test with embedded fake backend Signed-off-by: Greg <marianigregory@pm.me> * Apply suggestion from @mudler Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com> * Apply suggestion from @mudler Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com> * Apply suggestion from @mudler Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com> * Apply suggestion from @mudler Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com> * Apply suggestion from @mudler Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com> * Apply suggestion from @mudler Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com> --------- Signed-off-by: Greg <marianigregory@pm.me> Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com> Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
136
core/http/endpoints/openai/video.go
Normal file
136
core/http/endpoints/openai/video.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
var raw map[string]interface{}
|
||||
if body := c.Body(); len(body) > 0 {
|
||||
_ = json.Unmarshal(body, &raw)
|
||||
}
|
||||
// Build VideoRequest using shared mapper
|
||||
vr := MapOpenAIToVideo(input, raw)
|
||||
// Place VideoRequest into locals so localai.VideoEndpoint can consume it
|
||||
c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr)
|
||||
// Delegate to existing localai handler
|
||||
return localai.VideoEndpoint(cl, ml, appConfig)(c)
|
||||
}
|
||||
}
|
||||
|
||||
// VideoEndpoint godoc
|
||||
// @Summary Generate a video from an OpenAI-compatible request
|
||||
// @Description Accepts an OpenAI-style request and delegates to the LocalAI video generator
|
||||
// @Tags openai
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body schema.OpenAIRequest true "OpenAI-style request"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Failure 400 {object} map[string]interface{}
|
||||
// @Router /v1/videos [post]
|
||||
|
||||
func MapOpenAIToVideo(input *schema.OpenAIRequest, raw map[string]interface{}) *schema.VideoRequest {
|
||||
vr := &schema.VideoRequest{}
|
||||
if input == nil {
|
||||
return vr
|
||||
}
|
||||
|
||||
if input.Model != "" {
|
||||
vr.Model = input.Model
|
||||
}
|
||||
|
||||
// Prompt mapping
|
||||
switch p := input.Prompt.(type) {
|
||||
case string:
|
||||
vr.Prompt = p
|
||||
case []interface{}:
|
||||
if len(p) > 0 {
|
||||
if s, ok := p[0].(string); ok {
|
||||
vr.Prompt = s
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Size
|
||||
size := input.Size
|
||||
if size == "" && raw != nil {
|
||||
if v, ok := raw["size"].(string); ok {
|
||||
size = v
|
||||
}
|
||||
}
|
||||
if size != "" {
|
||||
parts := strings.SplitN(size, "x", 2)
|
||||
if len(parts) == 2 {
|
||||
if wi, err := strconv.Atoi(parts[0]); err == nil {
|
||||
vr.Width = int32(wi)
|
||||
}
|
||||
if hi, err := strconv.Atoi(parts[1]); err == nil {
|
||||
vr.Height = int32(hi)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// seconds -> num frames
|
||||
secondsStr := ""
|
||||
if raw != nil {
|
||||
if v, ok := raw["seconds"].(string); ok {
|
||||
secondsStr = v
|
||||
} else if v, ok := raw["seconds"].(float64); ok {
|
||||
secondsStr = fmt.Sprintf("%v", int(v))
|
||||
}
|
||||
}
|
||||
fps := int32(30)
|
||||
if raw != nil {
|
||||
if rawFPS, ok := raw["fps"]; ok {
|
||||
switch rf := rawFPS.(type) {
|
||||
case float64:
|
||||
fps = int32(rf)
|
||||
case string:
|
||||
if fi, err := strconv.Atoi(rf); err == nil {
|
||||
fps = int32(fi)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if secondsStr != "" {
|
||||
if secF, err := strconv.Atoi(secondsStr); err == nil {
|
||||
vr.FPS = fps
|
||||
vr.NumFrames = int32(secF) * fps
|
||||
}
|
||||
}
|
||||
|
||||
// input_reference
|
||||
if raw != nil {
|
||||
if v, ok := raw["input_reference"].(string); ok {
|
||||
vr.StartImage = v
|
||||
}
|
||||
}
|
||||
|
||||
// response format
|
||||
if input.ResponseFormat != nil {
|
||||
if rf, ok := input.ResponseFormat.(string); ok {
|
||||
vr.ResponseFormat = rf
|
||||
}
|
||||
}
|
||||
|
||||
if input.Step != 0 {
|
||||
vr.Step = int32(input.Step)
|
||||
}
|
||||
|
||||
return vr
|
||||
}
|
||||
75
core/http/openai_mapping_test.go
Normal file
75
core/http/openai_mapping_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package http_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
openai "github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("MapOpenAIToVideo", func() {
|
||||
It("maps size and seconds correctly", func() {
|
||||
cases := []struct {
|
||||
name string
|
||||
input *schema.OpenAIRequest
|
||||
raw map[string]interface{}
|
||||
expectsW int32
|
||||
expectsH int32
|
||||
expectsF int32
|
||||
expectsN int32
|
||||
}{
|
||||
{
|
||||
name: "size in input",
|
||||
input: &schema.OpenAIRequest{
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: "m"},
|
||||
},
|
||||
Size: "256x128",
|
||||
},
|
||||
expectsW: 256,
|
||||
expectsH: 128,
|
||||
},
|
||||
{
|
||||
name: "size in raw and seconds as string",
|
||||
input: &schema.OpenAIRequest{PredictionOptions: schema.PredictionOptions{BasicModelRequest: schema.BasicModelRequest{Model: "m"}}},
|
||||
raw: map[string]interface{}{"size": "720x480", "seconds": "2"},
|
||||
expectsW: 720,
|
||||
expectsH: 480,
|
||||
expectsF: 30,
|
||||
expectsN: 60,
|
||||
},
|
||||
{
|
||||
name: "seconds as number and fps override",
|
||||
input: &schema.OpenAIRequest{PredictionOptions: schema.PredictionOptions{BasicModelRequest: schema.BasicModelRequest{Model: "m"}}},
|
||||
raw: map[string]interface{}{"seconds": 3.0, "fps": 24.0},
|
||||
expectsF: 24,
|
||||
expectsN: 72,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
By(c.name)
|
||||
vr := openai.MapOpenAIToVideo(c.input, c.raw)
|
||||
if c.expectsW != 0 {
|
||||
Expect(vr.Width).To(Equal(c.expectsW))
|
||||
}
|
||||
if c.expectsH != 0 {
|
||||
Expect(vr.Height).To(Equal(c.expectsH))
|
||||
}
|
||||
if c.expectsF != 0 {
|
||||
Expect(vr.FPS).To(Equal(c.expectsF))
|
||||
}
|
||||
if c.expectsN != 0 {
|
||||
Expect(vr.NumFrames).To(Equal(c.expectsN))
|
||||
}
|
||||
|
||||
b, err := json.Marshal(vr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_ = b
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
162
core/http/openai_videos_test.go
Normal file
162
core/http/openai_videos_test.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package http_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"fmt"
|
||||
. "github.com/mudler/LocalAI/core/http"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// testAPIKey is the bearer token the test server is configured with and
// that every test request must present.
const testAPIKey = "joshua"

// fakeAI is a stub backend satisfying the LocalAI gRPC backend interface.
// Every method is a no-op returning zero values, which lets the HTTP-layer
// tests exercise routing and request mapping without loading a real model.
// NOTE(review): GenerateVideo returns nil without writing any output file —
// presumably the HTTP handler creates the output path itself; verify against
// the video endpoint implementation.
type fakeAI struct{}

func (f *fakeAI) Busy() bool { return false }
func (f *fakeAI) Lock() {}
func (f *fakeAI) Unlock() {}
func (f *fakeAI) Locking() bool { return false }
func (f *fakeAI) Predict(*pb.PredictOptions) (string, error) { return "", nil }
func (f *fakeAI) PredictStream(*pb.PredictOptions, chan string) error {
	return nil
}
func (f *fakeAI) Load(*pb.ModelOptions) error { return nil }
func (f *fakeAI) Embeddings(*pb.PredictOptions) ([]float32, error) { return nil, nil }
func (f *fakeAI) GenerateImage(*pb.GenerateImageRequest) error { return nil }
func (f *fakeAI) GenerateVideo(*pb.GenerateVideoRequest) error { return nil }
func (f *fakeAI) Detect(*pb.DetectOptions) (pb.DetectResponse, error) { return pb.DetectResponse{}, nil }
func (f *fakeAI) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) {
	return pb.TranscriptResult{}, nil
}
func (f *fakeAI) TTS(*pb.TTSRequest) error { return nil }
func (f *fakeAI) SoundGeneration(*pb.SoundGenerationRequest) error { return nil }
func (f *fakeAI) TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error) {
	return pb.TokenizationResponse{}, nil
}
func (f *fakeAI) Status() (pb.StatusResponse, error) { return pb.StatusResponse{}, nil }
func (f *fakeAI) StoresSet(*pb.StoresSetOptions) error { return nil }
func (f *fakeAI) StoresDelete(*pb.StoresDeleteOptions) error { return nil }
func (f *fakeAI) StoresGet(*pb.StoresGetOptions) (pb.StoresGetResult, error) {
	return pb.StoresGetResult{}, nil
}
func (f *fakeAI) StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error) {
	return pb.StoresFindResult{}, nil
}
func (f *fakeAI) VAD(*pb.VADRequest) (pb.VADResponse, error) { return pb.VADResponse{}, nil }
|
||||
|
||||
// Integration test: spins up the full LocalAI HTTP API with an embedded
// fake backend and exercises the OpenAI-compatible /v1/videos endpoint
// end to end (auth, request mapping, delegation, response shape).
var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
	var tmpdir string
	var appServer *application.Application
	var app *fiber.App
	var ctx context.Context
	var cancel context.CancelFunc

	BeforeEach(func() {
		var err error
		// Fresh temp directory per test; also serves as the generated-content dir.
		tmpdir, err = os.MkdirTemp("", "")
		Expect(err).ToNot(HaveOccurred())

		modelDir := filepath.Join(tmpdir, "models")
		err = os.Mkdir(modelDir, 0750)
		Expect(err).ToNot(HaveOccurred())

		ctx, cancel = context.WithCancel(context.Background())

		systemState, err := system.GetSystemState(
			system.WithModelPath(modelDir),
		)
		Expect(err).ToNot(HaveOccurred())

		// Register the in-process fake backend under the embedded:// scheme
		// so no external gRPC process is needed.
		grpc.Provide("embedded://fake", &fakeAI{})

		appServer, err = application.New(
			config.WithContext(ctx),
			config.WithSystemState(systemState),
			config.WithApiKeys([]string{testAPIKey}),
			config.WithGeneratedContentDir(tmpdir),
			config.WithExternalBackend("fake", "embedded://fake"),
		)
		Expect(err).ToNot(HaveOccurred())
	})

	AfterEach(func() {
		// Cancel the application context before shutting down the HTTP server,
		// then remove all on-disk artifacts. Errors are best-effort ignored.
		cancel()
		if app != nil {
			_ = app.Shutdown()
		}
		_ = os.RemoveAll(tmpdir)
	})

	It("accepts OpenAI-style video create and delegates to backend", func() {
		var err error
		app, err = API(appServer)
		Expect(err).ToNot(HaveOccurred())
		// NOTE(review): Listen error is discarded and the port is hardcoded —
		// a busy 9091 would surface only as the Eventually below timing out.
		go app.Listen("127.0.0.1:9091")

		// wait for server: poll /v1/models until it answers without an error status
		client := &http.Client{Timeout: 5 * time.Second}
		Eventually(func() error {
			req, _ := http.NewRequest("GET", "http://127.0.0.1:9091/v1/models", nil)
			req.Header.Set("Authorization", "Bearer "+testAPIKey)
			resp, err := client.Do(req)
			if err != nil {
				return err
			}
			defer resp.Body.Close()
			if resp.StatusCode >= 400 {
				return fmt.Errorf("bad status: %d", resp.StatusCode)
			}
			return nil
		}, "30s", "500ms").Should(Succeed())

		// OpenAI-style payload; "seconds" is a string on purpose to exercise
		// the mapper's string branch.
		body := map[string]interface{}{
			"model":   "fake-model",
			"backend": "fake",
			"prompt":  "a test video",
			"size":    "256x256",
			"seconds": "1",
		}
		payload, err := json.Marshal(body)
		Expect(err).ToNot(HaveOccurred())

		req, err := http.NewRequest("POST", "http://127.0.0.1:9091/v1/videos", bytes.NewBuffer(payload))
		Expect(err).ToNot(HaveOccurred())
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Authorization", "Bearer "+testAPIKey)

		resp, err := client.Do(req)
		Expect(err).ToNot(HaveOccurred())
		defer resp.Body.Close()
		Expect(resp.StatusCode).To(Equal(200))

		dat, err := io.ReadAll(resp.Body)
		Expect(err).ToNot(HaveOccurred())

		// Response shape: {"data": [{"url": ".../generated-videos/...mp4"}, ...]}
		var out map[string]interface{}
		err = json.Unmarshal(dat, &out)
		Expect(err).ToNot(HaveOccurred())
		data, ok := out["data"].([]interface{})
		Expect(ok).To(BeTrue())
		Expect(len(data)).To(BeNumerically(">", 0))
		first := data[0].(map[string]interface{})
		url, ok := first["url"].(string)
		Expect(ok).To(BeTrue())
		Expect(url).To(ContainSubstring("/generated-videos/"))
		Expect(url).To(ContainSubstring(".mp4"))
	})
})
|
||||
@@ -108,6 +108,19 @@ func RegisterOpenAIRoutes(app *fiber.App,
|
||||
imageChain...)
|
||||
app.Post("/images/generations", imageChain...)
|
||||
|
||||
// videos (OpenAI-compatible endpoints mapped to LocalAI video handler)
|
||||
videoChain := []fiber.Handler{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VIDEO)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.VideoEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
|
||||
}
|
||||
|
||||
// OpenAI-style create video endpoint
|
||||
app.Post("/v1/videos", videoChain...)
|
||||
app.Post("/v1/videos/generations", videoChain...)
|
||||
app.Post("/videos", videoChain...)
|
||||
|
||||
// List models
|
||||
app.Get("/v1/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Get("/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
|
||||
@@ -35,6 +35,9 @@ type VideoRequest struct {
|
||||
Height int32 `json:"height" yaml:"height"`
|
||||
NumFrames int32 `json:"num_frames" yaml:"num_frames"`
|
||||
FPS int32 `json:"fps" yaml:"fps"`
|
||||
Seconds string `json:"seconds,omitempty" yaml:"seconds,omitempty"`
|
||||
Size string `json:"size,omitempty" yaml:"size,omitempty"`
|
||||
InputReference string `json:"input_reference,omitempty" yaml:"input_reference,omitempty"`
|
||||
Seed int32 `json:"seed" yaml:"seed"`
|
||||
CFGScale float32 `json:"cfg_scale" yaml:"cfg_scale"`
|
||||
Step int32 `json:"step" yaml:"step"`
|
||||
|
||||
Reference in New Issue
Block a user