diff --git a/backend/python/diffusers/backend.py b/backend/python/diffusers/backend.py index ef5f1b5c0..2c7ef2b24 100755 --- a/backend/python/diffusers/backend.py +++ b/backend/python/diffusers/backend.py @@ -475,11 +475,24 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): "num_inference_steps": steps, } - if request.src != "" and not self.controlnet and not self.img2vid: - image = Image.open(request.src) + # Handle image source: prioritize RefImages over request.src + image_src = None + if hasattr(request, 'ref_images') and request.ref_images and len(request.ref_images) > 0: + # Use the first reference image if available + image_src = request.ref_images[0] + print(f"Using reference image: {image_src}", file=sys.stderr) + elif request.src != "": + # Fall back to request.src if no ref_images + image_src = request.src + print(f"Using source image: {image_src}", file=sys.stderr) + else: + print("No image source provided", file=sys.stderr) + + if image_src and not self.controlnet and not self.img2vid: + image = Image.open(image_src) options["image"] = image - elif self.controlnet and request.src: - pose_image = load_image(request.src) + elif self.controlnet and image_src: + pose_image = load_image(image_src) options["image"] = pose_image if CLIPSKIP and self.clip_skip != 0: @@ -521,7 +534,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): if self.img2vid: # Load the conditioning image - image = load_image(request.src) + if image_src: + image = load_image(image_src) + else: + # Fallback to request.src for img2vid if no ref_images + image = load_image(request.src) image = image.resize((1024, 576)) generator = torch.manual_seed(request.seed)