feat(inpainting): add inpainting endpoint, wire ImageGenerationFunc and return generated image URL (#7328)

feat(inpainting): add inpainting endpoint with automatic model selection

Signed-off-by: Greg <marianigregory@pm.me>
This commit is contained in:
Gregory Mariani
2025-11-24 21:13:54 +01:00
committed by GitHub
parent 7e01aa8faa
commit 745c31e013
7 changed files with 434 additions and 3 deletions

View File

@@ -40,3 +40,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
return fn, nil
}
// ImageGenerationFunc is a test-friendly indirection to call image generation logic.
// It defaults to ImageGeneration. Tests can override this variable to provide a
// stub implementation; because it is package-level state, an override should be
// restored once the test is done.
var ImageGenerationFunc = ImageGeneration

View File

@@ -0,0 +1,268 @@
package openai
import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"time"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
model "github.com/mudler/LocalAI/pkg/model"
)
// InpaintingEndpoint handles POST /v1/images/inpainting
//
// The handler reads the `model`, `prompt` and optional `steps` form fields
// together with the uploaded `image` and `mask` files. It writes the inputs to
// temporary files inside appConfig.GeneratedContentDir, invokes
// backend.ImageGenerationFunc with a fixed 512x512 size and responds with the
// URL of the generated PNG under `generated-images`.
//
// The temporary input files (JSON payload and reference images) are removed
// when the request finishes, whether it succeeded or not. The generated PNG is
// kept on success and removed on failure.
//
// Swagger / OpenAPI docstring (swaggo):
// @Summary Image inpainting
// @Description Perform image inpainting. Accepts multipart/form-data with `image` and `mask` files.
// @Tags images
// @Accept multipart/form-data
// @Produce application/json
// @Param model formData string true "Model identifier"
// @Param prompt formData string true "Text prompt guiding the generation"
// @Param steps formData int false "Number of inference steps (default 25)"
// @Param image formData file true "Original image file"
// @Param mask formData file true "Mask image file (white = area to inpaint)"
// @Success 200 {object} schema.OpenAIResponse
// @Failure 400 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /v1/images/inpainting [post]
func InpaintingEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
	return func(c echo.Context) error {
		// Parse basic form values.
		modelName := c.FormValue("model")
		prompt := c.FormValue("prompt")
		if modelName == "" || prompt == "" {
			log.Error().Msg("Inpainting Endpoint - missing model or prompt")
			return echo.ErrBadRequest
		}

		// steps is optional; an unparsable value keeps the default but is
		// logged so that the fallback is not silent.
		steps := 25
		if stepsStr := c.FormValue("steps"); stepsStr != "" {
			v, err := strconv.Atoi(stepsStr)
			if err != nil {
				log.Warn().Err(err).Msgf("Inpainting Endpoint - invalid steps value %q, using default %d", stepsStr, steps)
			} else {
				steps = v
			}
		}

		// Get uploaded files.
		imageFile, err := c.FormFile("image")
		if err != nil {
			log.Error().Err(err).Msg("Inpainting Endpoint - missing image file")
			return echo.NewHTTPError(http.StatusBadRequest, "missing image file")
		}
		maskFile, err := c.FormFile("mask")
		if err != nil {
			log.Error().Err(err).Msg("Inpainting Endpoint - missing mask file")
			return echo.NewHTTPError(http.StatusBadRequest, "missing mask file")
		}

		// Get model config from context (middleware set it). Checked before
		// the uploads are read so an unusable request does no extra work.
		cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
		if !ok || cfg == nil {
			log.Error().Msg("Inpainting Endpoint - model config not found in context")
			return echo.ErrBadRequest
		}

		// Read files into memory (small files expected).
		imgSrc, err := imageFile.Open()
		if err != nil {
			return fmt.Errorf("opening image upload: %w", err)
		}
		defer imgSrc.Close()
		imgBytes, err := io.ReadAll(imgSrc)
		if err != nil {
			return fmt.Errorf("reading image upload: %w", err)
		}

		maskSrc, err := maskFile.Open()
		if err != nil {
			return fmt.Errorf("opening mask upload: %w", err)
		}
		defer maskSrc.Close()
		maskBytes, err := io.ReadAll(maskSrc)
		if err != nil {
			return fmt.Errorf("reading mask upload: %w", err)
		}

		// Use the GeneratedContentDir so the generated PNG is placed where the
		// HTTP static handler serves `/generated-images`.
		tmpDir := appConfig.GeneratedContentDir
		if err := os.MkdirAll(tmpDir, 0750); err != nil {
			log.Error().Err(err).Msgf("Inpainting Endpoint - failed to create generated content dir: %s", tmpDir)
			return echo.NewHTTPError(http.StatusInternalServerError, "failed to prepare storage")
		}

		id := uuid.New().String()
		jsonPath := filepath.Join(tmpDir, fmt.Sprintf("inpaint_%s.json", id))

		// Cleanup: input files are only needed while the (blocking)
		// generation runs, so they are always removed. The output file is
		// removed only when the request did not succeed.
		var (
			inputFiles []string
			dst        string
			success    bool
		)
		defer func() {
			for _, p := range inputFiles {
				removeInpaintingFile(p, "temp input file")
			}
			if !success {
				removeInpaintingFile(dst, "dst file")
			}
		}()

		// Write original image and mask to disk as ref images so backends that
		// accept reference image files can use them.
		origRef, err := writeInpaintingTempFile(tmpDir, "refimg_", imgBytes)
		if err != nil {
			return fmt.Errorf("writing reference image: %w", err)
		}
		inputFiles = append(inputFiles, origRef)

		maskRef, err := writeInpaintingTempFile(tmpDir, "refmask_", maskBytes)
		if err != nil {
			return fmt.Errorf("writing reference mask: %w", err)
		}
		inputFiles = append(inputFiles, maskRef)

		// Write the JSON document with the base64 fields expected by the backend.
		payload, err := json.Marshal(map[string]string{
			"image":      base64.StdEncoding.EncodeToString(imgBytes),
			"mask_image": base64.StdEncoding.EncodeToString(maskBytes),
		})
		if err != nil {
			return fmt.Errorf("encoding inpainting payload: %w", err)
		}
		// Registered before writing so a partially written file is removed too.
		inputFiles = append(inputFiles, jsonPath)
		if err := os.WriteFile(jsonPath, payload, 0600); err != nil {
			return fmt.Errorf("writing inpainting payload: %w", err)
		}

		// Prepare dst: reserve a unique name, then give it a .png extension.
		outTmp, err := os.CreateTemp(tmpDir, "out_")
		if err != nil {
			return fmt.Errorf("creating output file: %w", err)
		}
		outName := outTmp.Name()
		if cerr := outTmp.Close(); cerr != nil {
			log.Warn().Err(cerr).Msg("Inpainting Endpoint - failed to close out temp file")
		}
		pngPath := outName + ".png"
		if err := os.Rename(outName, pngPath); err != nil {
			// The placeholder still carries its original name; do not leak it.
			removeInpaintingFile(outName, "out temp file")
			return fmt.Errorf("renaming output file: %w", err)
		}
		dst = pngPath

		// Determine width/height default.
		width := 512
		height := 512

		// Call backend image generation via indirection so tests can stub it.
		// Also pass ref images (orig + mask) so backends that support ref
		// images can use them.
		refImages := []string{origRef, maskRef}
		fn, err := backend.ImageGenerationFunc(height, width, 0, steps, 0, prompt, "", jsonPath, dst, ml, *cfg, appConfig, refImages)
		if err != nil {
			return err
		}

		// Execute generation function (blocking).
		if err := fn(); err != nil {
			return err
		}

		// On success, build response URL using BaseURL middleware helper and
		// the same `generated-images` prefix used by the server static mount.
		baseURL := middleware.BaseURL(c)
		imgPath, err := url.JoinPath(baseURL, "generated-images", filepath.Base(dst))
		if err != nil {
			return err
		}

		resp := &schema.OpenAIResponse{
			ID:      id,
			Created: int(time.Now().Unix()),
			Data: []schema.Item{{
				URL: imgPath,
			}},
		}

		// Mark success so the deferred cleanup keeps the generated image.
		success = true
		return c.JSON(http.StatusOK, resp)
	}
}

// writeInpaintingTempFile creates a temporary file in dir using pattern, writes
// data to it and returns the file name. The file is removed again when writing
// or closing fails.
func writeInpaintingTempFile(dir, pattern string, data []byte) (string, error) {
	f, err := os.CreateTemp(dir, pattern)
	if err != nil {
		return "", err
	}
	name := f.Name()
	if _, err := f.Write(data); err != nil {
		// Best-effort cleanup; the write error is the one worth reporting.
		_ = f.Close()
		_ = os.Remove(name)
		return "", err
	}
	if err := f.Close(); err != nil {
		// Best-effort cleanup; the close error is the one worth reporting.
		_ = os.Remove(name)
		return "", err
	}
	return name, nil
}

// removeInpaintingFile removes path on a best-effort basis. An empty path or a
// file that no longer exists is ignored; any other failure is logged.
func removeInpaintingFile(path, what string) {
	if path == "" {
		return
	}
	if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
		log.Warn().Err(err).Msgf("Inpainting Endpoint - failed to remove %s %s in cleanup", what, path)
	}
}

View File

@@ -0,0 +1,107 @@
package openai
import (
"bytes"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/stretchr/testify/require"
)
// makeMultipartRequest builds a POST request to /v1/images/inpainting with a
// multipart/form-data body.
//
// Every entry of fields becomes a plain form field. Every entry of files
// becomes a file part whose form name is the map key and whose file name is
// the key with a ".png" suffix. It returns the request, with its Content-Type
// header already set, and the content type string itself.
func makeMultipartRequest(t *testing.T, fields map[string]string, files map[string][]byte) (*http.Request, string) {
	t.Helper()

	body := &bytes.Buffer{}
	w := multipart.NewWriter(body)
	for k, v := range fields {
		require.NoError(t, w.WriteField(k, v))
	}
	for fname, content := range files {
		fw, err := w.CreateFormFile(fname, fname+".png")
		require.NoError(t, err)
		_, err = fw.Write(content)
		require.NoError(t, err)
	}
	// Close writes the trailing boundary; the body is incomplete without it.
	require.NoError(t, w.Close())

	contentType := w.FormDataContentType()
	req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", body)
	req.Header.Set("Content-Type", contentType)
	return req, contentType
}
// TestInpainting_MissingFiles verifies that a request carrying the required
// form fields but no uploaded files is rejected with 400 Bad Request.
func TestInpainting_MissingFiles(t *testing.T) {
	e := echo.New()
	// cl and ml are not used before the file checks, so nil is sufficient here.
	h := InpaintingEndpoint(nil, nil, config.NewApplicationConfig())

	// Provide model and prompt so that validation reaches the file checks
	// instead of stopping at the missing model/prompt check.
	fields := map[string]string{"model": "dreamshaper-8-inpainting", "prompt": "A test"}
	req, _ := makeMultipartRequest(t, fields, nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	err := h(c)
	require.Error(t, err)

	var httpErr *echo.HTTPError
	require.ErrorAs(t, err, &httpErr)
	require.Equal(t, http.StatusBadRequest, httpErr.Code)
}
// TestInpainting_HappyPath drives the handler with a stubbed image generation
// backend and checks that the response points at a file that exists in the
// generated content directory.
func TestInpainting_HappyPath(t *testing.T) {
	// Generated content goes to a throwaway directory.
	contentDir, err := os.MkdirTemp("", "gencontent")
	require.NoError(t, err)
	defer os.RemoveAll(contentDir)

	appConf := config.NewApplicationConfig(config.WithGeneratedContentDir(contentDir))

	// Swap in a stub for backend.ImageGenerationFunc and put the real one back
	// when the test returns.
	previous := backend.ImageGenerationFunc
	defer func() { backend.ImageGenerationFunc = previous }()
	backend.ImageGenerationFunc = func(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
		// The stub only drops a fake png at dst.
		return func() error {
			return os.WriteFile(dst, []byte("PNGDATA"), 0644)
		}, nil
	}

	// Multipart request carrying both the image and the mask.
	req, _ := makeMultipartRequest(t,
		map[string]string{"model": "dreamshaper-8-inpainting", "prompt": "A test"},
		map[string][]byte{"image": []byte("IMAGEDATA"), "mask": []byte("MASKDATA")},
	)
	rec := httptest.NewRecorder()
	ctx := echo.New().NewContext(req, rec)

	// The handler expects the middleware to have stored a model config.
	ctx.Set(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG, &config.ModelConfig{Backend: "diffusers"})

	handler := InpaintingEndpoint(nil, nil, appConf)
	err = handler(ctx)
	require.NoError(t, err)
	require.Equal(t, http.StatusOK, rec.Code)

	// The response must mention the generated-images path.
	require.Contains(t, rec.Body.String(), "generated-images")

	// Naive extraction of the file name: everything after "generated-images/"
	// up to the next quote, comma, brace or newline.
	const marker = "generated-images/"
	raw := rec.Body.Bytes()
	start := bytes.Index(raw, []byte(marker))
	require.True(t, start >= 0)
	tail := raw[start:]
	stop := bytes.IndexAny(tail, "\",}\n")
	if stop == -1 {
		stop = len(tail)
	}
	name := string(tail[len(marker):stop])

	// The referenced file must exist in the generated content directory.
	_, err = os.Stat(filepath.Join(contentDir, name))
	require.NoError(t, err)
}

View File

@@ -108,11 +108,11 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
log.Debug().Msgf("[model agent] [model: %s] Reasoning: %s", config.Name, s)
}),
cogito.WithToolCallBack(func(t *cogito.ToolChoice) bool {
log.Debug().Msgf("[model agent] [model: %s] Tool call: %s, reasoning: %s, arguments: %+v", t.Name, t.Reasoning, t.Arguments)
log.Debug().Msgf("[model agent] [model: %s] Tool call: %s, reasoning: %s, arguments: %+v", config.Name, t.Name, t.Reasoning, t.Arguments)
return true
}),
cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
log.Debug().Msgf("[model agent] [model: %s] Tool call result: %s, tool arguments: %+v", t.Name, t.Result, t.ToolArguments)
log.Debug().Msgf("[model agent] [model: %s] Tool call result: %s, result: %s, tool arguments: %+v", config.Name, t.Name, t.Result, t.ToolArguments)
}),
)

View File

@@ -55,6 +55,11 @@ func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) {
model = c.QueryParam("model")
}
// Check FormValue for multipart/form-data requests (e.g., /v1/images/inpainting)
if model == "" {
model = c.FormValue("model")
}
if model == "" {
// Set model from bearer token, if available
auth := c.Request().Header.Get("Authorization")

View File

@@ -140,7 +140,8 @@ func RegisterOpenAIRoutes(app *echo.Echo,
// images
imageHandler := openai.ImageEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
imageMiddleware := []echo.MiddlewareFunc{
re.BuildConstantDefaultModelNameMiddleware("stablediffusion"),
// Default: use the first available image generation model
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_IMAGE)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
@@ -155,6 +156,11 @@ func RegisterOpenAIRoutes(app *echo.Echo,
app.POST("/v1/images/generations", imageHandler, imageMiddleware...)
app.POST("/images/generations", imageHandler, imageMiddleware...)
// inpainting endpoint (image + mask) - reuse same middleware config as images
inpaintingHandler := openai.InpaintingEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
app.POST("/v1/images/inpainting", inpaintingHandler, imageMiddleware...)
app.POST("/images/inpainting", inpaintingHandler, imageMiddleware...)
// videos (OpenAI-compatible endpoints mapped to LocalAI video handler)
videoHandler := openai.VideoEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
videoMiddleware := []echo.MiddlewareFunc{

View File

@@ -1197,6 +1197,47 @@ paths:
schema:
$ref: '#/definitions/schema.OpenAIResponse'
summary: Creates an image given a prompt.
/v1/images/inpainting:
post:
consumes:
- multipart/form-data
- application/json
parameters:
- in: formData
name: model
type: string
description: Model name (eg. dreamshaper-8-inpainting)
required: true
- in: formData
name: prompt
type: string
description: Positive prompt text
required: true
- in: formData
name: image
type: file
description: Source image (PNG/JPEG)
required: true
- in: formData
name: mask
type: file
description: Mask image (PNG). White = area to inpaint (the exact convention may depend on the backend)
required: true
- in: formData
name: steps
type: integer
description: Number of inference steps
- in: body
name: request
description: "Alternative JSON payload with base64 fields: { image: '<b64>', mask: '<b64>', model, prompt }"
schema:
$ref: '#/definitions/schema.OpenAIRequest'
responses:
"200":
description: Successful inpainting
schema:
$ref: '#/definitions/schema.OpenAIResponse'
summary: Creates an inpainted image given an image + mask + prompt.
/v1/mcp/chat/completions:
post:
parameters: