From 960e51e5277734a80fe2a0f3e23dfb9a3d1257dd Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 24 Aug 2025 22:03:08 +0200 Subject: [PATCH] chore(diffusers): support both src and reference_images in diffusers (#6135) Signed-off-by: Ettore Di Giacinto --- backend/python/diffusers/backend.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/backend/python/diffusers/backend.py b/backend/python/diffusers/backend.py index ef5f1b5c0..2c7ef2b24 100755 --- a/backend/python/diffusers/backend.py +++ b/backend/python/diffusers/backend.py @@ -475,11 +475,24 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): "num_inference_steps": steps, } - if request.src != "" and not self.controlnet and not self.img2vid: - image = Image.open(request.src) + # Handle image source: prioritize RefImages over request.src + image_src = None + if hasattr(request, 'ref_images') and request.ref_images and len(request.ref_images) > 0: + # Use the first reference image if available + image_src = request.ref_images[0] + print(f"Using reference image: {image_src}", file=sys.stderr) + elif request.src != "": + # Fall back to request.src if no ref_images + image_src = request.src + print(f"Using source image: {image_src}", file=sys.stderr) + else: + print("No image source provided", file=sys.stderr) + + if image_src and not self.controlnet and not self.img2vid: + image = Image.open(image_src) options["image"] = image - elif self.controlnet and request.src: - pose_image = load_image(request.src) + elif self.controlnet and image_src: + pose_image = load_image(image_src) options["image"] = pose_image if CLIPSKIP and self.clip_skip != 0: @@ -521,7 +534,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): if self.img2vid: # Load the conditioning image - image = load_image(request.src) + if image_src: + image = load_image(image_src) + else: + # Fallback to request.src for img2vid if no ref_images + image = load_image(request.src) image = image.resize((1024, 576)) generator = torch.manual_seed(request.seed)