From 9621edb4c53e699c8eab51041cf4e42cd901c36e Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Thu, 28 Aug 2025 10:26:42 +0200
Subject: [PATCH] feat(diffusers): add support for wan2.2 (#6153)

* feat(diffusers): add support for wan2.2

Signed-off-by: Ettore Di Giacinto

* chore(ci): use ttl.sh for PRs

Signed-off-by: Ettore Di Giacinto

* Add ftfy deps

Signed-off-by: Ettore Di Giacinto

* Revert "chore(ci): use ttl.sh for PRs"

This reverts commit c9fc3ecf288dd9d454a38d78840d214b72a140ca.

* Simplify

Signed-off-by: Ettore Di Giacinto

* chore: do not pin torch/torchvision on cuda12

Signed-off-by: Ettore Di Giacinto

---------

Signed-off-by: Ettore Di Giacinto
---
 backend/backend.proto                             |  20 +--
 backend/python/diffusers/backend.py               | 132 ++++++++++++++++--
 backend/python/diffusers/requirements-cpu.txt     |   3 +-
 .../diffusers/requirements-cublas11.txt           |   7 +-
 .../diffusers/requirements-cublas12.txt           |   7 +-
 backend/python/diffusers/requirements-hipblas.txt |   3 +-
 backend/python/diffusers/requirements-intel.txt   |   3 +-
 backend/python/diffusers/requirements-l4t.txt     |   3 +-
 backend/python/diffusers/requirements-mps.txt     |   3 +-
 backend/python/mlx-audio/backend.py               |  15 +-
 backend/python/mlx-vlm/backend.py                 |  15 +-
 backend/python/mlx/backend.py                     |  15 +-
 core/backend/video.go                             |  20 ++-
 core/http/endpoints/localai/video.go              |  20 ++-
 core/schema/localai.go                            |   2 +
 15 files changed, 195 insertions(+), 73 deletions(-)

diff --git a/backend/backend.proto b/backend/backend.proto
index 77bf3fefa..77791b7ee 100644
--- a/backend/backend.proto
+++ b/backend/backend.proto
@@ -312,15 +312,17 @@ message GenerateImageRequest {
 
 message GenerateVideoRequest {
   string prompt = 1;
-  string start_image = 2; // Path or base64 encoded image for the start frame
-  string end_image = 3; // Path or base64 encoded image for the end frame
-  int32 width = 4;
-  int32 height = 5;
-  int32 num_frames = 6; // Number of frames to generate
-  int32 fps = 7; // Frames per second
-  int32 seed = 8;
-  float cfg_scale = 9; // Classifier-free guidance scale
-  string dst = 10; // Output path for the generated video
+  string negative_prompt = 2; // Negative prompt for video generation
+  string start_image = 3; // Path or base64 encoded image for the start frame
+  string end_image = 4; // Path or base64 encoded image for the end frame
+  int32 width = 5;
+  int32 height = 6;
+  int32 num_frames = 7; // Number of frames to generate
+  int32 fps = 8; // Frames per second
+  int32 seed = 9;
+  float cfg_scale = 10; // Classifier-free guidance scale
+  int32 step = 11; // Number of inference steps
+  string dst = 12; // Output path for the generated video
 }
 
 message TTSRequest {
diff --git a/backend/python/diffusers/backend.py b/backend/python/diffusers/backend.py
index 2c7ef2b24..c2f59dd11 100755
--- a/backend/python/diffusers/backend.py
+++ b/backend/python/diffusers/backend.py
@@ -18,7 +18,7 @@ import backend_pb2_grpc
 import grpc
 
 from diffusers import SanaPipeline, StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, \
-    EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline
+    EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline, AutoencoderKLWan, WanPipeline, WanImageToVideoPipeline
 from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline, Lumina2Text2ImgPipeline
 from diffusers.pipelines.stable_diffusion import safety_checker
 from diffusers.utils import load_image, export_to_video
@@ -72,13 +72,6 @@ def is_float(s):
     except ValueError:
         return False
 
-def is_int(s):
-    try:
-        int(s)
-        return True
-    except ValueError:
-        return False
-
 # The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39
 # Credits to https://github.com/neggles
 # See https://github.com/huggingface/diffusers/issues/4167 for more details on sched mapping from A1111
@@ -184,9 +177,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 key, value = opt.split(":")
                 # if value is a number, convert it to the appropriate type
                 if is_float(value):
-                    value = float(value)
-                elif is_int(value):
-                    value = int(value)
+                    if float(value).is_integer():
+                        value = int(value)
+                    else:
+                        value = float(value)
                 self.options[key] = value
 
         # From options, extract if present "torch_dtype" and set it to the appropriate type
@@ -334,6 +328,32 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 torch_dtype=torch.bfloat16)
             self.pipe.vae.to(torch.bfloat16)
             self.pipe.text_encoder.to(torch.bfloat16)
+        elif request.PipelineType == "WanPipeline":
+            # WAN2.2 pipeline requires special VAE handling
+            vae = AutoencoderKLWan.from_pretrained(
+                request.Model,
+                subfolder="vae",
+                torch_dtype=torch.float32
+            )
+            self.pipe = WanPipeline.from_pretrained(
+                request.Model,
+                vae=vae,
+                torch_dtype=torchType
+            )
+            self.txt2vid = True  # WAN2.2 is a text-to-video pipeline
+        elif request.PipelineType == "WanImageToVideoPipeline":
+            # WAN2.2 image-to-video pipeline
+            vae = AutoencoderKLWan.from_pretrained(
+                request.Model,
+                subfolder="vae",
+                torch_dtype=torch.float32
+            )
+            self.pipe = WanImageToVideoPipeline.from_pretrained(
+                request.Model,
+                vae=vae,
+                torch_dtype=torchType
+            )
+            self.img2vid = True  # WAN2.2 image-to-video pipeline
 
         if CLIPSKIP and request.CLIPSkip != 0:
             self.clip_skip = request.CLIPSkip
@@ -575,6 +595,102 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 
         return backend_pb2.Result(message="Media generated", success=True)
 
+    def GenerateVideo(self, request, context):
+        try:
+            prompt = request.prompt
+            if not prompt:
+                return backend_pb2.Result(success=False, message="No prompt provided for video generation")
+
+            # Set default values from request or use defaults
+            num_frames = request.num_frames if request.num_frames > 0 else 81
+            fps = request.fps if request.fps > 0 else 16
+            cfg_scale = request.cfg_scale if request.cfg_scale > 0 else 4.0
+            num_inference_steps = request.step if request.step > 0 else 40
+
+            # Prepare generation parameters
+            kwargs = {
+                "prompt": prompt,
+                "negative_prompt": request.negative_prompt if request.negative_prompt else "",
+                "height": request.height if request.height > 0 else 720,
+                "width": request.width if request.width > 0 else 1280,
+                "num_frames": num_frames,
+                "guidance_scale": cfg_scale,
+                "num_inference_steps": num_inference_steps,
+            }
+
+            # Add custom options from self.options (including guidance_scale_2 if specified)
+            kwargs.update(self.options)
+
+            # Set seed if provided
+            if request.seed > 0:
+                kwargs["generator"] = torch.Generator(device=self.device).manual_seed(request.seed)
+
+            # Handle start and end images for video generation
+            if request.start_image:
+                kwargs["start_image"] = load_image(request.start_image)
+            if request.end_image:
+                kwargs["end_image"] = load_image(request.end_image)
+
+            print(f"Generating video with {kwargs=}", file=sys.stderr)
+
+            # Generate video frames based on pipeline type
+            if self.PipelineType == "WanPipeline":
+                # WAN2.2 text-to-video generation
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]  # WAN2.2 returns frames in this format
+            elif self.PipelineType == "WanImageToVideoPipeline":
+                # WAN2.2 image-to-video generation
+                if request.start_image:
+                    # Load and resize the input image according to WAN2.2 requirements
+                    image = load_image(request.start_image)
+                    # Use request dimensions or defaults, but respect WAN2.2 constraints
+                    request_height = request.height if request.height > 0 else 480
+                    request_width = request.width if request.width > 0 else 832
+                    max_area = request_height * request_width
+                    aspect_ratio = image.height / image.width
+                    mod_value = self.pipe.vae_scale_factor_spatial * self.pipe.transformer.config.patch_size[1]
+                    height = round((max_area * aspect_ratio) ** 0.5 / mod_value) * mod_value
+                    width = round((max_area / aspect_ratio) ** 0.5 / mod_value) * mod_value
+                    image = image.resize((width, height))
+                    kwargs["image"] = image
+                    kwargs["height"] = height
+                    kwargs["width"] = width
+                    # The Wan pipeline takes the image via "image"; drop the
+                    # generic entries its __call__ does not accept
+                    kwargs.pop("start_image", None)
+                    kwargs.pop("end_image", None)
+
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]
+            elif self.img2vid:
+                # Generic image-to-video generation
+                if request.start_image:
+                    image = load_image(request.start_image)
+                    image = image.resize((request.width if request.width > 0 else 1024,
+                                          request.height if request.height > 0 else 576))
+                    kwargs["image"] = image
+                    kwargs.pop("start_image", None)
+                    kwargs.pop("end_image", None)
+
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]
+            elif self.txt2vid:
+                # Generic text-to-video generation
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]
+            else:
+                return backend_pb2.Result(success=False, message=f"Pipeline {self.PipelineType} does not support video generation")
+
+            # Export video
+            export_to_video(frames, request.dst, fps=fps)
+
+            return backend_pb2.Result(message="Video generated successfully", success=True)
+
+        except Exception as err:
+            print(f"Error generating video: {err}", file=sys.stderr)
+            traceback.print_exc()
+            return backend_pb2.Result(success=False, message=f"Error generating video: {err}")
+
 
 def serve(address):
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
diff --git a/backend/python/diffusers/requirements-cpu.txt b/backend/python/diffusers/requirements-cpu.txt
index 659d2b8e0..fceda06d2 100644
--- a/backend/python/diffusers/requirements-cpu.txt
+++ b/backend/python/diffusers/requirements-cpu.txt
@@ -8,4 +8,5 @@ compel
 peft
 sentencepiece
 torch==2.7.1
-optimum-quanto
\ No newline at end of file
+optimum-quanto
+ftfy
\ No newline at end of file
diff --git a/backend/python/diffusers/requirements-cublas11.txt b/backend/python/diffusers/requirements-cublas11.txt
index 142234a7c..7b77f7f68 100644
--- a/backend/python/diffusers/requirements-cublas11.txt
+++ b/backend/python/diffusers/requirements-cublas11.txt
@@ -1,11 +1,12 @@
 --extra-index-url https://download.pytorch.org/whl/cu118
-torch==2.7.1+cu118
-torchvision==0.22.1+cu118
 git+https://github.com/huggingface/diffusers
 opencv-python
 transformers
+torchvision==0.22.1
 accelerate
 compel
 peft
 sentencepiece
-optimum-quanto
\ No newline at end of file
+torch==2.7.1
+optimum-quanto
+ftfy
\ No newline at end of file
diff --git a/backend/python/diffusers/requirements-cublas12.txt b/backend/python/diffusers/requirements-cublas12.txt
index 77dd54b73..2c90df9bb 100644
--- a/backend/python/diffusers/requirements-cublas12.txt
+++ b/backend/python/diffusers/requirements-cublas12.txt
@@ -1,10 +1,11 @@
-torch==2.7.1
-torchvision==0.22.1
+--extra-index-url https://download.pytorch.org/whl/cu121
 git+https://github.com/huggingface/diffusers
 opencv-python
 transformers
+torchvision
 accelerate
 compel
 peft
 sentencepiece
-optimum-quanto
\ No newline at end of file
+torch
+ftfy
\ No newline at end of file
diff --git a/backend/python/diffusers/requirements-hipblas.txt b/backend/python/diffusers/requirements-hipblas.txt
index 2bab13494..aeea37563 100644
--- a/backend/python/diffusers/requirements-hipblas.txt
+++ b/backend/python/diffusers/requirements-hipblas.txt
@@ -8,4 +8,5 @@ accelerate
 compel
 peft
 sentencepiece
-optimum-quanto
\ No newline at end of file
+optimum-quanto
+ftfy
\ No newline at end of file
diff --git a/backend/python/diffusers/requirements-intel.txt b/backend/python/diffusers/requirements-intel.txt
index 98ccac7b6..fec4d9df7 100644
--- a/backend/python/diffusers/requirements-intel.txt
+++ b/backend/python/diffusers/requirements-intel.txt
@@ -12,4 +12,5 @@ accelerate
 compel
 peft
 sentencepiece
-optimum-quanto
\ No newline at end of file
+optimum-quanto
+ftfy
\ No newline at end of file
diff --git a/backend/python/diffusers/requirements-l4t.txt b/backend/python/diffusers/requirements-l4t.txt
index 0540cd27d..11c095342 100644
--- a/backend/python/diffusers/requirements-l4t.txt
+++ b/backend/python/diffusers/requirements-l4t.txt
@@ -8,4 +8,5 @@ peft
 optimum-quanto
 numpy<2
 sentencepiece
-torchvision
\ No newline at end of file
+torchvision
+ftfy
\ No newline at end of file
diff --git a/backend/python/diffusers/requirements-mps.txt b/backend/python/diffusers/requirements-mps.txt
index 77dd54b73..8b7c2413b 100644
--- a/backend/python/diffusers/requirements-mps.txt
+++ b/backend/python/diffusers/requirements-mps.txt
@@ -7,4 +7,5 @@ accelerate
 compel
 peft
 sentencepiece
-optimum-quanto
\ No newline at end of file
+optimum-quanto
+ftfy
\ No newline at end of file
diff --git a/backend/python/mlx-audio/backend.py b/backend/python/mlx-audio/backend.py
index a098b8872..d8c1a807a 100644
--- a/backend/python/mlx-audio/backend.py
+++ b/backend/python/mlx-audio/backend.py
@@ -40,14 +40,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         except ValueError:
             return False
 
-    def _is_int(self, s):
-        """Check if a string can be converted to int."""
-        try:
-            int(s)
-            return True
-        except ValueError:
-            return False
-
     def Health(self, request, context):
         """
         Returns a health check message.
@@ -89,9 +81,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 
                 # Convert numeric values to appropriate types
                 if self._is_float(value):
-                    value = float(value)
-                elif self._is_int(value):
-                    value = int(value)
+                    if float(value).is_integer():
+                        value = int(value)
+                    else:
+                        value = float(value)
                 elif value.lower() in ["true", "false"]:
                     value = value.lower() == "true"
 
diff --git a/backend/python/mlx-vlm/backend.py b/backend/python/mlx-vlm/backend.py
index 02730c814..b010a1d27 100644
--- a/backend/python/mlx-vlm/backend.py
+++ b/backend/python/mlx-vlm/backend.py
@@ -40,14 +40,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         except ValueError:
             return False
 
-    def _is_int(self, s):
-        """Check if a string can be converted to int."""
-        try:
-            int(s)
-            return True
-        except ValueError:
-            return False
-
     def Health(self, request, context):
         """
         Returns a health check message.
@@ -89,9 +81,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 
                 # Convert numeric values to appropriate types
                 if self._is_float(value):
-                    value = float(value)
-                elif self._is_int(value):
-                    value = int(value)
+                    if float(value).is_integer():
+                        value = int(value)
+                    else:
+                        value = float(value)
                 elif value.lower() in ["true", "false"]:
                     value = value.lower() == "true"
 
diff --git a/backend/python/mlx/backend.py b/backend/python/mlx/backend.py
index 84024b387..a7326685e 100644
--- a/backend/python/mlx/backend.py
+++ b/backend/python/mlx/backend.py
@@ -38,14 +38,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         except ValueError:
             return False
 
-    def _is_int(self, s):
-        """Check if a string can be converted to int."""
-        try:
-            int(s)
-            return True
-        except ValueError:
-            return False
-
     def Health(self, request, context):
         """
         Returns a health check message.
@@ -87,9 +79,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 
                 # Convert numeric values to appropriate types
                 if self._is_float(value):
-                    value = float(value)
-                elif self._is_int(value):
-                    value = int(value)
+                    if float(value).is_integer():
+                        value = int(value)
+                    else:
+                        value = float(value)
                 elif value.lower() in ["true", "false"]:
                     value = value.lower() == "true"
 
diff --git a/core/backend/video.go b/core/backend/video.go
index b5a4dbc04..a7a39bf24 100644
--- a/core/backend/video.go
+++ b/core/backend/video.go
@@ -7,7 +7,7 @@ import (
 	model "github.com/mudler/LocalAI/pkg/model"
 )
 
-func VideoGeneration(height, width int32, prompt, startImage, endImage, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() error, error) {
+func VideoGeneration(height, width int32, prompt, negativePrompt, startImage, endImage, dst string, numFrames, fps, seed int32, cfgScale float32, step int32, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() error, error) {
 
 	opts := ModelOptions(modelConfig, appConfig)
 	inferenceModel, err := loader.Load(
@@ -22,12 +22,18 @@ func VideoGeneration(height, width int32, prompt, startImage, endImage, dst stri
 		_, err := inferenceModel.GenerateVideo(
 			appConfig.Context,
 			&proto.GenerateVideoRequest{
-				Height:     height,
-				Width:      width,
-				Prompt:     prompt,
-				StartImage: startImage,
-				EndImage:   endImage,
-				Dst:        dst,
+				Height:         height,
+				Width:          width,
+				Prompt:         prompt,
+				NegativePrompt: negativePrompt,
+				StartImage:     startImage,
+				EndImage:       endImage,
+				NumFrames:      numFrames,
+				Fps:            fps,
+				Seed:           seed,
+				CfgScale:       cfgScale,
+				Step:           step,
+				Dst:            dst,
 			})
 		return err
 	}
diff --git a/core/http/endpoints/localai/video.go b/core/http/endpoints/localai/video.go
index df01ce316..68b8ec011 100644
--- a/core/http/endpoints/localai/video.go
+++ b/core/http/endpoints/localai/video.go
@@ -61,7 +61,7 @@ func downloadFile(url string) (string, error) {
 */
 // VideoEndpoint
 // @Summary Creates a video given a prompt.
-// @Param request body schema.OpenAIRequest true "query params"
+// @Param request body schema.VideoRequest true "query params"
 // @Success 200 {object} schema.OpenAIResponse "Response"
 // @Router /video [post]
 func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
@@ -166,7 +166,23 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
 
 	baseURL := c.BaseURL()
 
-	fn, err := backend.VideoGeneration(height, width, input.Prompt, src, input.EndImage, output, ml, *config, appConfig)
+	fn, err := backend.VideoGeneration(
+		height,
+		width,
+		input.Prompt,
+		input.NegativePrompt,
+		src,
+		input.EndImage,
+		output,
+		input.NumFrames,
+		input.FPS,
+		input.Seed,
+		input.CFGScale,
+		input.Step,
+		ml,
+		*config,
+		appConfig,
+	)
 	if err != nil {
 		return err
 	}
diff --git a/core/schema/localai.go b/core/schema/localai.go
index 5949e743d..1ea0c1965 100644
--- a/core/schema/localai.go
+++ b/core/schema/localai.go
@@ -28,6 +28,7 @@ type GalleryResponse struct {
 type VideoRequest struct {
 	BasicModelRequest
 	Prompt         string  `json:"prompt" yaml:"prompt"`
+	NegativePrompt string  `json:"negative_prompt" yaml:"negative_prompt"`
 	StartImage     string  `json:"start_image" yaml:"start_image"`
 	EndImage       string  `json:"end_image" yaml:"end_image"`
 	Width          int32   `json:"width" yaml:"width"`
@@ -36,6 +37,7 @@ type VideoRequest struct {
 	FPS            int32   `json:"fps" yaml:"fps"`
 	Seed           int32   `json:"seed" yaml:"seed"`
 	CFGScale       float32 `json:"cfg_scale" yaml:"cfg_scale"`
+	Step           int32   `json:"step" yaml:"step"`
 	ResponseFormat string  `json:"response_format" yaml:"response_format"`
 }
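---

Usage sketch (an illustration, not part of the patch): the new fields can be
exercised end-to-end through the POST /video route added above, against a
diffusers model loaded with PipelineType "WanPipeline" or
"WanImageToVideoPipeline". The JSON field names come from schema.VideoRequest
in core/schema/localai.go; the address and the model name "wan22" are
assumptions, and the numeric values mirror the backend defaults applied when a
field is zero.

    # Hypothetical client for the /video endpoint (host/port and model name assumed)
    import json
    import urllib.request

    payload = {
        "model": "wan22",  # assumed name of a locally configured diffusers model
        "prompt": "a red panda walking through a bamboo forest",
        "negative_prompt": "blurry, low quality",
        "width": 1280,     # backend falls back to 1280 when 0
        "height": 720,     # backend falls back to 720 when 0
        "num_frames": 81,  # backend falls back to 81 when 0
        "fps": 16,         # backend falls back to 16 when 0
        "seed": 42,
        "cfg_scale": 4.0,  # backend falls back to 4.0 when 0
        "step": 40,        # inference steps; backend falls back to 40 when 0
    }
    req = urllib.request.Request(
        "http://localhost:8080/video",  # assumed LocalAI address
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        print(resp.read().decode("utf-8"))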