feat(diffusers): add support for wan2.2 (#6153)
* feat(diffusers): add support for wan2.2
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
* chore(ci): use ttl.sh for PRs
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
* Add ftfy deps
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
* Revert "chore(ci): use ttl.sh for PRs"
This reverts commit c9fc3ecf28.
* Simplify
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
* chore: do not pin torch/torchvision on cuda12
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
---------
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Commit 9621edb4c5 (parent 7ce92f0646), committed via GitHub.
@@ -18,7 +18,7 @@ import backend_pb2_grpc
 import grpc
 
 from diffusers import SanaPipeline, StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, \
-    EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline
+    EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline, AutoencoderKLWan, WanPipeline, WanImageToVideoPipeline
 from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline, Lumina2Text2ImgPipeline
 from diffusers.pipelines.stable_diffusion import safety_checker
 from diffusers.utils import load_image, export_to_video
@@ -72,13 +72,6 @@ def is_float(s):
     except ValueError:
         return False
 
-def is_int(s):
-    try:
-        int(s)
-        return True
-    except ValueError:
-        return False
-
 # The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39
 # Credits to https://github.com/neggles
 # See https://github.com/huggingface/diffusers/issues/4167 for more details on sched mapping from A1111
@@ -184,9 +177,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             key, value = opt.split(":")
             # if value is a number, convert it to the appropriate type
             if is_float(value):
                 value = float(value)
-            elif is_int(value):
-                value = int(value)
+                if value.is_integer():
+                    value = int(value)
+                else:
+                    value = float(value)
             self.options[key] = value
 
         # From options, extract if present "torch_dtype" and set it to the appropriate type
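Read in isolation, the option coercion above is easy to misread, so here is a minimal standalone sketch of the same logic (is_float as defined earlier in this file; the redundant else branch is dropped):

def is_float(s):
    try:
        float(s)
        return True
    except ValueError:
        return False

def parse_options(opts):
    # "name:value" strings become typed entries in an options dict.
    options = {}
    for opt in opts:
        if ":" not in opt:
            continue
        key, value = opt.split(":")
        if is_float(value):
            value = float(value)
            if value.is_integer():
                value = int(value)  # "81" -> 81 and "4.0" -> 4
        options[key] = value
    return options

print(parse_options(["guidance_scale_2:3.5", "num_frames:81", "scheduler:euler_a"]))
# {'guidance_scale_2': 3.5, 'num_frames': 81, 'scheduler': 'euler_a'}

Note the side effect: an integral float like "4.0" collapses to the int 4; only non-integral values such as "3.5" stay floats, and non-numeric values pass through as strings.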
@@ -334,6 +328,32 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                                                  torch_dtype=torch.bfloat16)
                 self.pipe.vae.to(torch.bfloat16)
                 self.pipe.text_encoder.to(torch.bfloat16)
+            elif request.PipelineType == "WanPipeline":
+                # WAN2.2 pipeline requires special VAE handling
+                vae = AutoencoderKLWan.from_pretrained(
+                    request.Model,
+                    subfolder="vae",
+                    torch_dtype=torch.float32
+                )
+                self.pipe = WanPipeline.from_pretrained(
+                    request.Model,
+                    vae=vae,
+                    torch_dtype=torchType
+                )
+                self.txt2vid = True  # WAN2.2 is a text-to-video pipeline
+            elif request.PipelineType == "WanImageToVideoPipeline":
+                # WAN2.2 image-to-video pipeline
+                vae = AutoencoderKLWan.from_pretrained(
+                    request.Model,
+                    subfolder="vae",
+                    torch_dtype=torch.float32
+                )
+                self.pipe = WanImageToVideoPipeline.from_pretrained(
+                    request.Model,
+                    vae=vae,
+                    torch_dtype=torchType
+                )
+                self.img2vid = True  # WAN2.2 image-to-video pipeline
+
         if CLIPSKIP and request.CLIPSkip != 0:
             self.clip_skip = request.CLIPSkip
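For reference, the split VAE/pipeline loading above works standalone in diffusers as well. A minimal sketch, assuming a Diffusers-format WAN2.2 checkpoint id such as Wan-AI/Wan2.2-T2V-A14B-Diffusers and a CUDA device (the checkpoint name, device, and bfloat16 choice are assumptions, not part of this patch); the frame count, guidance scale, step count, and fps mirror the backend defaults introduced below:

import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.utils import export_to_video

# Assumed example checkpoint; any Diffusers-format WAN2.2 repo should work.
model_id = "Wan-AI/Wan2.2-T2V-A14B-Diffusers"

# Load the VAE separately in float32, exactly as the patch does.
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")

frames = pipe(
    prompt="A cat walks on the grass, realistic",
    height=480,
    width=832,
    num_frames=81,           # backend default below
    guidance_scale=4.0,      # backend default below
    num_inference_steps=40,  # backend default below
).frames[0]
export_to_video(frames, "output.mp4", fps=16)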
@@ -575,6 +595,96 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 
         return backend_pb2.Result(message="Media generated", success=True)
 
+    def GenerateVideo(self, request, context):
+        try:
+            prompt = request.prompt
+            if not prompt:
+                return backend_pb2.Result(success=False, message="No prompt provided for video generation")
+
+            # Set default values from request or use defaults
+            num_frames = request.num_frames if request.num_frames > 0 else 81
+            fps = request.fps if request.fps > 0 else 16
+            cfg_scale = request.cfg_scale if request.cfg_scale > 0 else 4.0
+            num_inference_steps = request.step if request.step > 0 else 40
+
+            # Prepare generation parameters
+            kwargs = {
+                "prompt": prompt,
+                "negative_prompt": request.negative_prompt if request.negative_prompt else "",
+                "height": request.height if request.height > 0 else 720,
+                "width": request.width if request.width > 0 else 1280,
+                "num_frames": num_frames,
+                "guidance_scale": cfg_scale,
+                "num_inference_steps": num_inference_steps,
+            }
+
+            # Add custom options from self.options (including guidance_scale_2 if specified)
+            kwargs.update(self.options)
+
+            # Set seed if provided
+            if request.seed > 0:
+                kwargs["generator"] = torch.Generator(device=self.device).manual_seed(request.seed)
+
+            # Handle start and end images for video generation
+            if request.start_image:
+                kwargs["start_image"] = load_image(request.start_image)
+            if request.end_image:
+                kwargs["end_image"] = load_image(request.end_image)
+
+            print(f"Generating video with {kwargs=}", file=sys.stderr)
+
+            # Generate video frames based on pipeline type
+            if self.PipelineType == "WanPipeline":
+                # WAN2.2 text-to-video generation
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]  # WAN2.2 returns frames in this format
+            elif self.PipelineType == "WanImageToVideoPipeline":
+                # WAN2.2 image-to-video generation
+                if request.start_image:
+                    # Load and resize the input image according to WAN2.2 requirements
+                    image = load_image(request.start_image)
+                    # Use request dimensions or defaults, but respect WAN2.2 constraints
+                    request_height = request.height if request.height > 0 else 480
+                    request_width = request.width if request.width > 0 else 832
+                    max_area = request_height * request_width
+                    aspect_ratio = image.height / image.width
+                    mod_value = self.pipe.vae_scale_factor_spatial * self.pipe.transformer.config.patch_size[1]
+                    height = round((max_area * aspect_ratio) ** 0.5 / mod_value) * mod_value
+                    width = round((max_area / aspect_ratio) ** 0.5 / mod_value) * mod_value
+                    image = image.resize((width, height))
+                    kwargs["image"] = image
+                    kwargs["height"] = height
+                    kwargs["width"] = width
+
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]
+            elif self.img2vid:
+                # Generic image-to-video generation
+                if request.start_image:
+                    image = load_image(request.start_image)
+                    image = image.resize((request.width if request.width > 0 else 1024,
+                                          request.height if request.height > 0 else 576))
+                    kwargs["image"] = image
+
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]
+            elif self.txt2vid:
+                # Generic text-to-video generation
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]
+            else:
+                return backend_pb2.Result(success=False, message=f"Pipeline {self.PipelineType} does not support video generation")
+
+            # Export video
+            export_to_video(frames, request.dst, fps=fps)
+
+            return backend_pb2.Result(message="Video generated successfully", success=True)
+
+        except Exception as err:
+            print(f"Error generating video: {err}", file=sys.stderr)
+            traceback.print_exc()
+            return backend_pb2.Result(success=False, message=f"Error generating video: {err}")
+
 
 def serve(address):
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
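The mod_value rounding in the image-to-video branch snaps the output size to the transformer's spatial granularity while approximately preserving the requested pixel area and the source aspect ratio. A worked example, assuming vae_scale_factor_spatial = 8 and patch_size[1] = 2 (plausible WAN-family values, giving mod_value = 16; both are assumptions) for a 1280x720 start image and the 832x480 defaults:

# Hypothetical inputs: 1280x720 start image, default 832x480 request.
max_area = 480 * 832       # 399360 target pixels
aspect_ratio = 720 / 1280  # 0.5625
mod_value = 8 * 2          # assumed vae_scale_factor_spatial * patch_size[1]

height = round((max_area * aspect_ratio) ** 0.5 / mod_value) * mod_value  # round(29.62) * 16 = 480
width = round((max_area / aspect_ratio) ** 0.5 / mod_value) * mod_value   # round(52.66) * 16 = 848

print(width, height)  # 848 480: close to the source aspect ratio, both multiples of 16

The resulting 848x480 is not exactly 399360 pixels; rounding each side independently to a multiple of 16 can overshoot the target area slightly.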
@@ -8,4 +8,5 @@ compel
 peft
 sentencepiece
 torch==2.7.1
 optimum-quanto
+ftfy
@@ -1,11 +1,12 @@
 --extra-index-url https://download.pytorch.org/whl/cu118
+torch==2.7.1+cu118
+torchvision==0.22.1+cu118
 git+https://github.com/huggingface/diffusers
 opencv-python
 transformers
-torchvision==0.22.1
 accelerate
 compel
 peft
 sentencepiece
-torch==2.7.1
 optimum-quanto
+ftfy
@@ -1,10 +1,11 @@
-torch==2.7.1
-torchvision==0.22.1
 --extra-index-url https://download.pytorch.org/whl/cu121
 git+https://github.com/huggingface/diffusers
 opencv-python
 transformers
+torchvision
 accelerate
 compel
 peft
 sentencepiece
 optimum-quanto
+torch
+ftfy
@@ -8,4 +8,5 @@ accelerate
 compel
 peft
 sentencepiece
 optimum-quanto
+ftfy
@@ -12,4 +12,5 @@ accelerate
 compel
 peft
 sentencepiece
 optimum-quanto
+ftfy
@@ -8,4 +8,5 @@ peft
 optimum-quanto
 numpy<2
 sentencepiece
 torchvision
+ftfy
@@ -7,4 +7,5 @@ accelerate
 compel
 peft
 sentencepiece
 optimum-quanto
+ftfy