mirror of
https://github.com/mudler/LocalAI.git
synced 2025-12-21 09:20:14 -06:00
feat(inpainting): add inpainting endpoint, wire ImageGenerationFunc and return generated image URL (#7328)
feat(inpainting): add inpainting endpoint with automatic model selection Signed-off-by: Greg <marianigregory@pm.me>
This commit is contained in:
@@ -40,3 +40,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
|
||||
|
||||
return fn, nil
|
||||
}
|
||||
|
||||
// ImageGenerationFunc is a test-friendly indirection to call image generation logic.
|
||||
// Tests can override this variable to provide a stub implementation.
// It is a plain package-level variable with no synchronization, so a test
// that replaces it should restore the previous value when it finishes.
|
||||
var ImageGenerationFunc = ImageGeneration
|
||||
|
||||
268
core/http/endpoints/openai/inpainting.go
Normal file
268
core/http/endpoints/openai/inpainting.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// InpaintingEndpoint handles POST /v1/images/inpainting
|
||||
//
|
||||
// Swagger / OpenAPI docstring (swaggo):
|
||||
// @Summary Image inpainting
|
||||
// @Description Perform image inpainting. Accepts multipart/form-data with `image` and `mask` files.
|
||||
// @Tags images
|
||||
// @Accept multipart/form-data
|
||||
// @Produce application/json
|
||||
// @Param model formData string true "Model identifier"
|
||||
// @Param prompt formData string true "Text prompt guiding the generation"
|
||||
// @Param steps formData int false "Number of inference steps (default 25)"
|
||||
// @Param image formData file true "Original image file"
|
||||
// @Param mask formData file true "Mask image file (white = area to inpaint)"
|
||||
// @Success 200 {object} schema.OpenAIResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /v1/images/inpainting [post]
|
||||
func InpaintingEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// Parse basic form values
|
||||
modelName := c.FormValue("model")
|
||||
prompt := c.FormValue("prompt")
|
||||
stepsStr := c.FormValue("steps")
|
||||
|
||||
if modelName == "" || prompt == "" {
|
||||
log.Error().Msg("Inpainting Endpoint - missing model or prompt")
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
// steps default
|
||||
steps := 25
|
||||
if stepsStr != "" {
|
||||
if v, err := strconv.Atoi(stepsStr); err == nil {
|
||||
steps = v
|
||||
}
|
||||
}
|
||||
|
||||
// Get uploaded files
|
||||
imageFile, err := c.FormFile("image")
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Inpainting Endpoint - missing image file")
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "missing image file")
|
||||
}
|
||||
maskFile, err := c.FormFile("mask")
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Inpainting Endpoint - missing mask file")
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "missing mask file")
|
||||
}
|
||||
|
||||
// Read files into memory (small files expected)
|
||||
imgSrc, err := imageFile.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer imgSrc.Close()
|
||||
imgBytes, err := io.ReadAll(imgSrc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
maskSrc, err := maskFile.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer maskSrc.Close()
|
||||
maskBytes, err := io.ReadAll(maskSrc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create JSON with base64 fields expected by backend
|
||||
b64Image := base64.StdEncoding.EncodeToString(imgBytes)
|
||||
b64Mask := base64.StdEncoding.EncodeToString(maskBytes)
|
||||
|
||||
// get model config from context (middleware set it)
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
log.Error().Msg("Inpainting Endpoint - model config not found in context")
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
// Use the GeneratedContentDir so the generated PNG is placed where the
|
||||
// HTTP static handler serves `/generated-images`.
|
||||
tmpDir := appConfig.GeneratedContentDir
|
||||
// Ensure the directory exists
|
||||
if err := os.MkdirAll(tmpDir, 0750); err != nil {
|
||||
log.Error().Err(err).Msgf("Inpainting Endpoint - failed to create generated content dir: %s", tmpDir)
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to prepare storage")
|
||||
}
|
||||
id := uuid.New().String()
|
||||
jsonPath := filepath.Join(tmpDir, fmt.Sprintf("inpaint_%s.json", id))
|
||||
jsonFile := map[string]string{
|
||||
"image": b64Image,
|
||||
"mask_image": b64Mask,
|
||||
}
|
||||
jf, err := os.CreateTemp(tmpDir, "inpaint_")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// setup cleanup on error; if everything succeeds we set success = true
|
||||
success := false
|
||||
var dst string
|
||||
var origRef string
|
||||
var maskRef string
|
||||
defer func() {
|
||||
if !success {
|
||||
// Best-effort cleanup; log any failures
|
||||
if jf != nil {
|
||||
if cerr := jf.Close(); cerr != nil {
|
||||
log.Warn().Err(cerr).Msg("Inpainting Endpoint - failed to close temp json file in cleanup")
|
||||
}
|
||||
if name := jf.Name(); name != "" {
|
||||
if rerr := os.Remove(name); rerr != nil && !os.IsNotExist(rerr) {
|
||||
log.Warn().Err(rerr).Msgf("Inpainting Endpoint - failed to remove temp json file %s in cleanup", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
if jsonPath != "" {
|
||||
if rerr := os.Remove(jsonPath); rerr != nil && !os.IsNotExist(rerr) {
|
||||
log.Warn().Err(rerr).Msgf("Inpainting Endpoint - failed to remove json file %s in cleanup", jsonPath)
|
||||
}
|
||||
}
|
||||
if dst != "" {
|
||||
if rerr := os.Remove(dst); rerr != nil && !os.IsNotExist(rerr) {
|
||||
log.Warn().Err(rerr).Msgf("Inpainting Endpoint - failed to remove dst file %s in cleanup", dst)
|
||||
}
|
||||
}
|
||||
if origRef != "" {
|
||||
if rerr := os.Remove(origRef); rerr != nil && !os.IsNotExist(rerr) {
|
||||
log.Warn().Err(rerr).Msgf("Inpainting Endpoint - failed to remove orig ref file %s in cleanup", origRef)
|
||||
}
|
||||
}
|
||||
if maskRef != "" {
|
||||
if rerr := os.Remove(maskRef); rerr != nil && !os.IsNotExist(rerr) {
|
||||
log.Warn().Err(rerr).Msgf("Inpainting Endpoint - failed to remove mask ref file %s in cleanup", maskRef)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// write original image and mask to disk as ref images so backends that
|
||||
// accept reference image files can use them (maintainer request).
|
||||
origTmp, err := os.CreateTemp(tmpDir, "refimg_")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := origTmp.Write(imgBytes); err != nil {
|
||||
_ = origTmp.Close()
|
||||
_ = os.Remove(origTmp.Name())
|
||||
return err
|
||||
}
|
||||
if cerr := origTmp.Close(); cerr != nil {
|
||||
log.Warn().Err(cerr).Msg("Inpainting Endpoint - failed to close orig temp file")
|
||||
}
|
||||
origRef = origTmp.Name()
|
||||
|
||||
maskTmp, err := os.CreateTemp(tmpDir, "refmask_")
|
||||
if err != nil {
|
||||
// cleanup origTmp on error
|
||||
_ = os.Remove(origRef)
|
||||
return err
|
||||
}
|
||||
if _, err := maskTmp.Write(maskBytes); err != nil {
|
||||
_ = maskTmp.Close()
|
||||
_ = os.Remove(maskTmp.Name())
|
||||
_ = os.Remove(origRef)
|
||||
return err
|
||||
}
|
||||
if cerr := maskTmp.Close(); cerr != nil {
|
||||
log.Warn().Err(cerr).Msg("Inpainting Endpoint - failed to close mask temp file")
|
||||
}
|
||||
maskRef = maskTmp.Name()
|
||||
// write JSON
|
||||
enc := json.NewEncoder(jf)
|
||||
if err := enc.Encode(jsonFile); err != nil {
|
||||
if cerr := jf.Close(); cerr != nil {
|
||||
log.Warn().Err(cerr).Msg("Inpainting Endpoint - failed to close temp json file after encode error")
|
||||
}
|
||||
return err
|
||||
}
|
||||
if cerr := jf.Close(); cerr != nil {
|
||||
log.Warn().Err(cerr).Msg("Inpainting Endpoint - failed to close temp json file")
|
||||
}
|
||||
// rename to desired name
|
||||
if err := os.Rename(jf.Name(), jsonPath); err != nil {
|
||||
return err
|
||||
}
|
||||
// prepare dst
|
||||
outTmp, err := os.CreateTemp(tmpDir, "out_")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cerr := outTmp.Close(); cerr != nil {
|
||||
log.Warn().Err(cerr).Msg("Inpainting Endpoint - failed to close out temp file")
|
||||
}
|
||||
dst = outTmp.Name() + ".png"
|
||||
if err := os.Rename(outTmp.Name(), dst); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Determine width/height default
|
||||
width := 512
|
||||
height := 512
|
||||
|
||||
// Call backend image generation via indirection so tests can stub it
|
||||
// Note: ImageGenerationFunc will call into the loaded model's GenerateImage which expects src JSON
|
||||
// Also pass ref images (orig + mask) so backends that support ref images can use them.
|
||||
refImages := []string{origRef, maskRef}
|
||||
fn, err := backend.ImageGenerationFunc(height, width, 0, steps, 0, prompt, "", jsonPath, dst, ml, *cfg, appConfig, refImages)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute generation function (blocking)
|
||||
if err := fn(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// On success, build response URL using BaseURL middleware helper and
|
||||
// the same `generated-images` prefix used by the server static mount.
|
||||
baseURL := middleware.BaseURL(c)
|
||||
|
||||
// Build response using url.JoinPath for correct URL escaping
|
||||
imgPath, err := url.JoinPath(baseURL, "generated-images", filepath.Base(dst))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
created := int(time.Now().Unix())
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Data: []schema.Item{{
|
||||
URL: imgPath,
|
||||
}},
|
||||
}
|
||||
|
||||
// mark success so defer cleanup will not remove output files
|
||||
success = true
|
||||
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
}
|
||||
107
core/http/endpoints/openai/inpainting_test.go
Normal file
107
core/http/endpoints/openai/inpainting_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func makeMultipartRequest(t *testing.T, fields map[string]string, files map[string][]byte) (*http.Request, string) {
|
||||
b := &bytes.Buffer{}
|
||||
w := multipart.NewWriter(b)
|
||||
for k, v := range fields {
|
||||
_ = w.WriteField(k, v)
|
||||
}
|
||||
for fname, content := range files {
|
||||
fw, err := w.CreateFormFile(fname, fname+".png")
|
||||
require.NoError(t, err)
|
||||
_, err = fw.Write(content)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.NoError(t, w.Close())
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
return req, w.FormDataContentType()
|
||||
}
|
||||
|
||||
func TestInpainting_MissingFiles(t *testing.T) {
|
||||
e := echo.New()
|
||||
// handler requires cl, ml, appConfig but this test verifies missing files early
|
||||
h := InpaintingEndpoint(nil, nil, config.NewApplicationConfig())
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := h(c)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestInpainting_HappyPath(t *testing.T) {
|
||||
// Setup temp generated content dir
|
||||
tmpDir, err := os.MkdirTemp("", "gencontent")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
appConf := config.NewApplicationConfig(config.WithGeneratedContentDir(tmpDir))
|
||||
|
||||
// stub the backend.ImageGenerationFunc
|
||||
orig := backend.ImageGenerationFunc
|
||||
backend.ImageGenerationFunc = func(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
|
||||
fn := func() error {
|
||||
// write a fake png file to dst
|
||||
return os.WriteFile(dst, []byte("PNGDATA"), 0644)
|
||||
}
|
||||
return fn, nil
|
||||
}
|
||||
defer func() { backend.ImageGenerationFunc = orig }()
|
||||
|
||||
// prepare multipart request with image and mask
|
||||
fields := map[string]string{"model": "dreamshaper-8-inpainting", "prompt": "A test"}
|
||||
files := map[string][]byte{"image": []byte("IMAGEDATA"), "mask": []byte("MASKDATA")}
|
||||
reqBuf, _ := makeMultipartRequest(t, fields, files)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
e := echo.New()
|
||||
c := e.NewContext(reqBuf, rec)
|
||||
|
||||
// set a minimal model config in context as handler expects
|
||||
c.Set(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG, &config.ModelConfig{Backend: "diffusers"})
|
||||
|
||||
h := InpaintingEndpoint(nil, nil, appConf)
|
||||
|
||||
// call handler
|
||||
err = h(c)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// verify response body contains generated-images path
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "generated-images")
|
||||
|
||||
// confirm the file was created in tmpDir
|
||||
// parse out filename from response (naive search)
|
||||
// find "generated-images/" and extract until closing quote or brace
|
||||
idx := bytes.Index(rec.Body.Bytes(), []byte("generated-images/"))
|
||||
require.True(t, idx >= 0)
|
||||
rest := rec.Body.Bytes()[idx:]
|
||||
end := bytes.IndexAny(rest, "\",}\n")
|
||||
if end == -1 {
|
||||
end = len(rest)
|
||||
}
|
||||
fname := string(rest[len("generated-images/"):end])
|
||||
// ensure file exists
|
||||
_, err = os.Stat(filepath.Join(tmpDir, fname))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -108,11 +108,11 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
|
||||
log.Debug().Msgf("[model agent] [model: %s] Reasoning: %s", config.Name, s)
|
||||
}),
|
||||
cogito.WithToolCallBack(func(t *cogito.ToolChoice) bool {
|
||||
log.Debug().Msgf("[model agent] [model: %s] Tool call: %s, reasoning: %s, arguments: %+v", t.Name, t.Reasoning, t.Arguments)
|
||||
log.Debug().Msgf("[model agent] [model: %s] Tool call: %s, reasoning: %s, arguments: %+v", config.Name, t.Name, t.Reasoning, t.Arguments)
|
||||
return true
|
||||
}),
|
||||
cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
|
||||
log.Debug().Msgf("[model agent] [model: %s] Tool call result: %s, tool arguments: %+v", t.Name, t.Result, t.ToolArguments)
|
||||
log.Debug().Msgf("[model agent] [model: %s] Tool call result: %s, result: %s, tool arguments: %+v", config.Name, t.Name, t.Result, t.ToolArguments)
|
||||
}),
|
||||
)
|
||||
|
||||
|
||||
@@ -55,6 +55,11 @@ func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) {
|
||||
model = c.QueryParam("model")
|
||||
}
|
||||
|
||||
// Check FormValue for multipart/form-data requests (e.g., /v1/images/inpainting)
|
||||
if model == "" {
|
||||
model = c.FormValue("model")
|
||||
}
|
||||
|
||||
if model == "" {
|
||||
// Set model from bearer token, if available
|
||||
auth := c.Request().Header.Get("Authorization")
|
||||
|
||||
@@ -140,7 +140,8 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
// images
|
||||
imageHandler := openai.ImageEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
imageMiddleware := []echo.MiddlewareFunc{
|
||||
re.BuildConstantDefaultModelNameMiddleware("stablediffusion"),
|
||||
// Default: use the first available image generation model
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_IMAGE)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
@@ -155,6 +156,11 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
app.POST("/v1/images/generations", imageHandler, imageMiddleware...)
|
||||
app.POST("/images/generations", imageHandler, imageMiddleware...)
|
||||
|
||||
// inpainting endpoint (image + mask) - reuse same middleware config as images
|
||||
inpaintingHandler := openai.InpaintingEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
app.POST("/v1/images/inpainting", inpaintingHandler, imageMiddleware...)
|
||||
app.POST("/images/inpainting", inpaintingHandler, imageMiddleware...)
|
||||
|
||||
// videos (OpenAI-compatible endpoints mapped to LocalAI video handler)
|
||||
videoHandler := openai.VideoEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
videoMiddleware := []echo.MiddlewareFunc{
|
||||
|
||||
@@ -1197,6 +1197,47 @@ paths:
|
||||
schema:
|
||||
$ref: '#/definitions/schema.OpenAIResponse'
|
||||
summary: Creates an image given a prompt.
|
||||
/v1/images/inpainting:
|
||||
post:
|
||||
consumes:
|
||||
- multipart/form-data
|
||||
- application/json
|
||||
parameters:
|
||||
- in: formData
|
||||
name: model
|
||||
type: string
|
||||
description: Model name (eg. dreamshaper-8-inpainting)
|
||||
required: true
|
||||
- in: formData
|
||||
name: prompt
|
||||
type: string
|
||||
description: Positive prompt text
|
||||
required: true
|
||||
- in: formData
|
||||
name: image
|
||||
type: file
|
||||
description: Source image (PNG/JPEG)
|
||||
required: true
|
||||
- in: formData
|
||||
name: mask
|
||||
type: file
|
||||
description: Mask image (PNG). White = area to inpaint, black = area to keep (exact interpretation may depend on the backend)
|
||||
required: true
|
||||
- in: formData
|
||||
name: steps
|
||||
type: integer
|
||||
description: Number of inference steps
|
||||
- in: body
|
||||
name: request
|
||||
description: "Alternative JSON payload with base64 fields: { image: '<b64>', mask: '<b64>', model, prompt }"
|
||||
schema:
|
||||
$ref: '#/definitions/schema.OpenAIRequest'
|
||||
responses:
|
||||
"200":
|
||||
description: Successful inpainting
|
||||
schema:
|
||||
$ref: '#/definitions/schema.OpenAIResponse'
|
||||
summary: Creates an inpainted image given an image + mask + prompt.
|
||||
/v1/mcp/chat/completions:
|
||||
post:
|
||||
parameters:
|
||||
|
||||
Reference in New Issue
Block a user