mirror of
https://github.com/mudler/LocalAI.git
synced 2025-12-30 22:20:20 -06:00
feat(diffusers): implement dynamic pipeline loader to remove per-pipeline conditionals (#7365)
* Initial plan Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Add dynamic loader for diffusers pipelines and refactor backend.py Co-authored-by: mudler <2420543+mudler@users.noreply.github.com> Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Fix pipeline discovery error handling and test mock issue Co-authored-by: mudler <2420543+mudler@users.noreply.github.com> Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Address code review feedback: direct imports, better error handling, improved tests Co-authored-by: mudler <2420543+mudler@users.noreply.github.com> Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Address remaining code review feedback: specific exceptions, registry access, test imports Co-authored-by: mudler <2420543+mudler@users.noreply.github.com> Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Add defensive fallback for DiffusionPipeline registry access Co-authored-by: mudler <2420543+mudler@users.noreply.github.com> Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Actually use dynamic pipeline loading for all pipelines in backend Co-authored-by: mudler <2420543+mudler@users.noreply.github.com> Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Use dynamic loader consistently for all pipelines including AutoPipelineForText2Image Co-authored-by: mudler <2420543+mudler@users.noreply.github.com> Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Move dynamic loader tests into test.py for CI compatibility Co-authored-by: mudler <2420543+mudler@users.noreply.github.com> Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Extend dynamic loader to discover any diffusers class type, not just DiffusionPipeline Co-authored-by: mudler <2420543+mudler@users.noreply.github.com> Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Add AutoPipeline classes to pipeline registry for default model loading Co-authored-by: mudler <2420543+mudler@users.noreply.github.com> Signed-off-by: Ettore Di Giacinto 
<mudler@localai.io> * fix(python): set pyvenv python home Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * do pyenv update during start Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Minor changes Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: mudler <2420543+mudler@users.noreply.github.com> Co-authored-by: Ettore Di Giacinto <mudler@localai.io> Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
@@ -237,7 +237,14 @@ function getBuildProfile() {
|
||||
# Make the venv relocatable:
|
||||
# - rewrite venv/bin/python{,3} to relative symlinks into $(_portable_dir)
|
||||
# - normalize entrypoint shebangs to /usr/bin/env python3
|
||||
# - optionally update pyvenv.cfg to point to the portable Python directory (only at runtime)
|
||||
# Usage: _makeVenvPortable [--update-pyvenv-cfg]
|
||||
_makeVenvPortable() {
|
||||
local update_pyvenv_cfg=false
|
||||
if [ "${1:-}" = "--update-pyvenv-cfg" ]; then
|
||||
update_pyvenv_cfg=true
|
||||
fi
|
||||
|
||||
local venv_dir="${EDIR}/venv"
|
||||
local vbin="${venv_dir}/bin"
|
||||
|
||||
@@ -255,7 +262,39 @@ _makeVenvPortable() {
|
||||
ln -s "${rel_py}" "${vbin}/python3"
|
||||
ln -s "python3" "${vbin}/python"
|
||||
|
||||
# 2) Rewrite shebangs of entry points to use env, so the venv is relocatable
|
||||
# 2) Update pyvenv.cfg to point to the portable Python directory (only at runtime)
|
||||
# Use absolute path resolved at runtime so it works when the venv is copied
|
||||
if [ "$update_pyvenv_cfg" = "true" ]; then
|
||||
local pyvenv_cfg="${venv_dir}/pyvenv.cfg"
|
||||
if [ -f "${pyvenv_cfg}" ]; then
|
||||
local portable_dir="$(_portable_dir)"
|
||||
# Resolve to absolute path - this ensures it works when the backend is copied
|
||||
# Only resolve if the directory exists (it should if ensurePortablePython was called)
|
||||
if [ -d "${portable_dir}" ]; then
|
||||
portable_dir="$(cd "${portable_dir}" && pwd)"
|
||||
else
|
||||
# Fallback to relative path if directory doesn't exist yet
|
||||
portable_dir="../python"
|
||||
fi
|
||||
local sed_i=(sed -i)
|
||||
# macOS/BSD sed needs a backup suffix; GNU sed doesn't. Make it portable:
|
||||
if sed --version >/dev/null 2>&1; then
|
||||
sed_i=(sed -i)
|
||||
else
|
||||
sed_i=(sed -i '')
|
||||
fi
|
||||
# Update the home field in pyvenv.cfg
|
||||
# Handle both absolute paths (starting with /) and relative paths
|
||||
if grep -q "^home = " "${pyvenv_cfg}"; then
|
||||
"${sed_i[@]}" "s|^home = .*|home = ${portable_dir}|" "${pyvenv_cfg}"
|
||||
else
|
||||
# If home field doesn't exist, add it
|
||||
echo "home = ${portable_dir}" >> "${pyvenv_cfg}"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
# 3) Rewrite shebangs of entry points to use env, so the venv is relocatable
|
||||
# Only touch text files that start with #! and reference the current venv.
|
||||
local ve_abs="${vbin}/python"
|
||||
local sed_i=(sed -i)
|
||||
@@ -316,6 +355,7 @@ function ensureVenv() {
|
||||
fi
|
||||
fi
|
||||
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
|
||||
# During install, only update symlinks and shebangs, not pyvenv.cfg
|
||||
_makeVenvPortable
|
||||
fi
|
||||
fi
|
||||
@@ -420,6 +460,11 @@ function installRequirements() {
|
||||
# - ${BACKEND_NAME}.py
|
||||
function startBackend() {
|
||||
ensureVenv
|
||||
# Update pyvenv.cfg before running to ensure paths are correct for current location
|
||||
# This is critical when the backend position is dynamic (e.g., copied from container)
|
||||
if [ "x${PORTABLE_PYTHON}" == "xtrue" ] || [ -x "$(_portable_python)" ]; then
|
||||
_makeVenvPortable --update-pyvenv-cfg
|
||||
fi
|
||||
if [ ! -z "${BACKEND_FILE:-}" ]; then
|
||||
exec "${EDIR}/venv/bin/python" "${BACKEND_FILE}" "$@"
|
||||
elif [ -e "${MY_DIR}/server.py" ]; then
|
||||
|
||||
@@ -1,5 +1,136 @@
|
||||
# Creating a separate environment for the diffusers project
|
||||
# LocalAI Diffusers Backend
|
||||
|
||||
This backend provides gRPC access to Hugging Face diffusers pipelines with dynamic pipeline loading.
|
||||
|
||||
## Creating a separate environment for the diffusers project
|
||||
|
||||
```
|
||||
make diffusers
|
||||
```
|
||||
```
|
||||
|
||||
## Dynamic Pipeline Loader
|
||||
|
||||
The diffusers backend includes a dynamic pipeline loader (`diffusers_dynamic_loader.py`) that automatically discovers and loads diffusers pipelines at runtime. This eliminates the need for per-pipeline conditional statements - new pipelines added to diffusers become available automatically without code changes.
|
||||
|
||||
### How It Works
|
||||
|
||||
1. **Pipeline Discovery**: On first use, the loader scans the `diffusers` package to find all classes that inherit from `DiffusionPipeline`.
|
||||
|
||||
2. **Registry Caching**: Discovery results are cached for the lifetime of the process to avoid repeated scanning.
|
||||
|
||||
3. **Task Aliases**: The loader automatically derives task aliases from class names (e.g., "text-to-image", "image-to-image", "inpainting") without hardcoding.
|
||||
|
||||
4. **Multiple Resolution Methods**: Pipelines can be resolved by:
|
||||
- Exact class name (e.g., `StableDiffusionPipeline`)
|
||||
- Task alias (e.g., `text-to-image`, `img2img`)
|
||||
- Model ID (uses HuggingFace Hub to infer pipeline type)
|
||||
|
||||
### Usage Examples
|
||||
|
||||
```python
|
||||
import torch
from diffusers_dynamic_loader import (
|
||||
load_diffusers_pipeline,
|
||||
get_available_pipelines,
|
||||
get_available_tasks,
|
||||
resolve_pipeline_class,
|
||||
discover_diffusers_classes,
|
||||
get_available_classes,
|
||||
)
|
||||
|
||||
# List all available pipelines
|
||||
pipelines = get_available_pipelines()
|
||||
print(f"Available pipelines: {pipelines[:10]}...")
|
||||
|
||||
# List all task aliases
|
||||
tasks = get_available_tasks()
|
||||
print(f"Available tasks: {tasks}")
|
||||
|
||||
# Resolve a pipeline class by name
|
||||
cls = resolve_pipeline_class(class_name="StableDiffusionPipeline")
|
||||
|
||||
# Resolve by task alias
|
||||
cls = resolve_pipeline_class(task="stable-diffusion")
|
||||
|
||||
# Load and instantiate a pipeline
|
||||
pipe = load_diffusers_pipeline(
|
||||
class_name="StableDiffusionPipeline",
|
||||
model_id="runwayml/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
# Load from single file
|
||||
pipe = load_diffusers_pipeline(
|
||||
class_name="StableDiffusionPipeline",
|
||||
model_id="/path/to/model.safetensors",
|
||||
from_single_file=True,
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
# Discover other diffusers classes (schedulers, models, etc.)
|
||||
schedulers = discover_diffusers_classes("SchedulerMixin")
|
||||
print(f"Available schedulers: {list(schedulers.keys())[:5]}...")
|
||||
|
||||
# Get list of available scheduler classes
|
||||
scheduler_list = get_available_classes("SchedulerMixin")
|
||||
```
|
||||
|
||||
### Generic Class Discovery
|
||||
|
||||
The dynamic loader can discover not just pipelines but any class type from diffusers:
|
||||
|
||||
```python
|
||||
# Discover all scheduler classes
|
||||
schedulers = discover_diffusers_classes("SchedulerMixin")
|
||||
|
||||
# Discover all model classes
|
||||
models = discover_diffusers_classes("ModelMixin")
|
||||
|
||||
# Get a sorted list of available classes
|
||||
scheduler_names = get_available_classes("SchedulerMixin")
|
||||
```
|
||||
|
||||
### Special Pipeline Handling
|
||||
|
||||
Most pipelines are loaded dynamically through `load_diffusers_pipeline()`. Only pipelines requiring truly custom initialization logic are handled explicitly:
|
||||
|
||||
- `FluxTransformer2DModel`: Requires quantization and custom transformer loading (cannot use dynamic loader)
|
||||
- `WanPipeline` / `WanImageToVideoPipeline`: Uses dynamic loader with special VAE (float32 dtype)
|
||||
- `SanaPipeline`: Uses dynamic loader with post-load dtype conversion for VAE/text encoder
|
||||
- `StableVideoDiffusionPipeline`: Uses dynamic loader with CPU offload handling
|
||||
- `VideoDiffusionPipeline`: Alias that loads `DiffusionPipeline` through the dynamic loader and sets the text-to-video flag
|
||||
|
||||
All other pipelines (StableDiffusionPipeline, StableDiffusionXLPipeline, FluxPipeline, etc.) are loaded purely through the dynamic loader.
|
||||
|
||||
### Error Handling
|
||||
|
||||
When a pipeline cannot be resolved, the loader provides helpful error messages listing available pipelines and tasks:
|
||||
|
||||
```
|
||||
ValueError: Unknown pipeline class 'NonExistentPipeline'.
|
||||
Available pipelines: AnimateDiffPipeline, AnimateDiffVideoToVideoPipeline, ...
|
||||
```
|
||||
|
||||
## Environment Variables
|
||||
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `COMPEL` | `0` | Enable Compel for prompt weighting |
|
||||
| `XPU` | `0` | Enable Intel XPU support |
|
||||
| `CLIPSKIP` | `1` | Enable CLIP skip support |
|
||||
| `SAFETENSORS` | `1` | Use safetensors format |
|
||||
| `CHUNK_SIZE` | `8` | Decode chunk size for video |
|
||||
| `FPS` | `7` | Video frames per second |
|
||||
| `DISABLE_CPU_OFFLOAD` | `0` | Disable CPU offload |
|
||||
| `FRAMES` | `64` | Number of video frames |
|
||||
| `BFL_REPO` | `ChuckMcSneed/FLUX.1-dev` | Flux base repo |
|
||||
| `PYTHON_GRPC_MAX_WORKERS` | `1` | Max gRPC workers |
|
||||
|
||||
## Running Tests
|
||||
|
||||
```bash
|
||||
./test.sh
|
||||
```
|
||||
|
||||
The test suite includes:
|
||||
- Unit tests for the dynamic loader (included in `test.py`)
|
||||
- Integration tests for the gRPC backend (`test.py`)
|
||||
@@ -1,4 +1,10 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
LocalAI Diffusers Backend
|
||||
|
||||
This backend provides gRPC access to diffusers pipelines with dynamic pipeline loading.
|
||||
New pipelines added to diffusers become available automatically without code changes.
|
||||
"""
|
||||
from concurrent import futures
|
||||
import traceback
|
||||
import argparse
|
||||
@@ -17,14 +23,22 @@ import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
|
||||
from diffusers import SanaPipeline, StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, \
|
||||
EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline, AutoencoderKLWan, WanPipeline, WanImageToVideoPipeline
|
||||
from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline, Lumina2Text2ImgPipeline
|
||||
# Import dynamic loader for pipeline discovery
|
||||
from diffusers_dynamic_loader import (
|
||||
get_pipeline_registry,
|
||||
resolve_pipeline_class,
|
||||
get_available_pipelines,
|
||||
load_diffusers_pipeline,
|
||||
)
|
||||
|
||||
# Import specific items still needed for special cases and safety checker
|
||||
from diffusers import DiffusionPipeline, ControlNetModel
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKLWan
|
||||
from diffusers.pipelines.stable_diffusion import safety_checker
|
||||
from diffusers.utils import load_image, export_to_video
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
from optimum.quanto import freeze, qfloat8, quantize
|
||||
from transformers import CLIPTextModel, T5EncoderModel
|
||||
from transformers import T5EncoderModel
|
||||
from safetensors.torch import load_file
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -158,6 +172,165 @@ def get_scheduler(name: str, config: dict = {}):
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
def _load_pipeline(self, request, modelFile, fromSingleFile, torchType, variant):
    """
    Build the diffusers pipeline described by a load request.

    A handful of pipeline types need hand-written set-up (quantization,
    a dedicated VAE, post-load dtype casts, video flags) and are handled
    first; everything else is instantiated by name through
    load_diffusers_pipeline().

    Args:
        request: The gRPC request; PipelineType, IMG2IMG, Model and
            LowVRAM are read from it
        modelFile: Path to the model file (used for single-file loading)
        fromSingleFile: Whether the generic path loads from a single file
            instead of from a pretrained repo/directory
        torchType: The torch dtype passed to the loader
        variant: Model variant (e.g., "fp16")

    Returns:
        The loaded pipeline instance

    Raises:
        ValueError: If the generic dynamic load fails; the message lists
            some of the available pipeline names.
    """
    requested = request.PipelineType

    # An img2img request without an explicit type gets the default img2img pipeline
    if request.IMG2IMG and requested == "":
        requested = "StableDiffusionImg2ImgPipeline"

    # ----------------------------------------------------------------
    # Pipelines that need custom initialization
    # ----------------------------------------------------------------

    # Flux transformer checkpoint: quantize transformer and T5 encoder,
    # then attach both to a FluxPipeline created without them.
    if requested == "FluxTransformer2DModel":
        flux_dtype = torch.bfloat16
        base_repo = os.environ.get("BFL_REPO", "ChuckMcSneed/FLUX.1-dev")

        flux_transformer = FluxTransformer2DModel.from_single_file(modelFile, torch_dtype=flux_dtype)
        quantize(flux_transformer, weights=qfloat8)
        freeze(flux_transformer)

        t5_encoder = T5EncoderModel.from_pretrained(base_repo, subfolder="text_encoder_2", torch_dtype=flux_dtype)
        quantize(t5_encoder, weights=qfloat8)
        freeze(t5_encoder)

        loaded = FluxPipeline.from_pretrained(base_repo, transformer=None, text_encoder_2=None, torch_dtype=flux_dtype)
        loaded.transformer = flux_transformer
        loaded.text_encoder_2 = t5_encoder

        if request.LowVRAM:
            loaded.enable_model_cpu_offload()
        return loaded

    # Wan pipelines: the VAE is loaded separately in float32; the flag
    # named in the table is raised on self once the pipeline is loaded.
    wan_flags = {
        "WanPipeline": "txt2vid",
        "WanImageToVideoPipeline": "img2vid",
    }
    if requested in wan_flags:
        wan_vae = AutoencoderKLWan.from_pretrained(
            request.Model,
            subfolder="vae",
            torch_dtype=torch.float32
        )
        loaded = load_diffusers_pipeline(
            class_name=requested,
            model_id=request.Model,
            vae=wan_vae,
            torch_dtype=torchType
        )
        setattr(self, wan_flags[requested], True)
        return loaded

    # Sana: bf16 variant, with VAE and text encoder cast to bfloat16 after load
    if requested == "SanaPipeline":
        loaded = load_diffusers_pipeline(
            class_name="SanaPipeline",
            model_id=request.Model,
            variant="bf16",
            torch_dtype=torch.bfloat16
        )
        loaded.vae.to(torch.bfloat16)
        loaded.text_encoder.to(torch.bfloat16)
        return loaded

    # VideoDiffusionPipeline: plain DiffusionPipeline with the txt2vid flag raised first
    if requested == "VideoDiffusionPipeline":
        self.txt2vid = True
        return load_diffusers_pipeline(
            class_name="DiffusionPipeline",
            model_id=request.Model,
            torch_dtype=torchType
        )

    # StableVideoDiffusionPipeline: img2vid flag plus CPU offload unless disabled
    if requested == "StableVideoDiffusionPipeline":
        self.img2vid = True
        loaded = load_diffusers_pipeline(
            class_name="StableVideoDiffusionPipeline",
            model_id=request.Model,
            torch_dtype=torchType,
            variant=variant
        )
        if not DISABLE_CPU_OFFLOAD:
            loaded.enable_model_cpu_offload()
        return loaded

    # ----------------------------------------------------------------
    # Generic path: instantiate the pipeline by class name
    # ----------------------------------------------------------------

    load_kwargs = {"torch_dtype": torchType}
    if not fromSingleFile:
        # variant and use_safetensors are only passed when not loading from a single file
        if variant:
            load_kwargs["variant"] = variant
        load_kwargs["use_safetensors"] = SAFETENSORS

    # No explicit type means AutoPipelineForText2Image
    target_class = requested or "AutoPipelineForText2Image"
    source = modelFile if fromSingleFile else request.Model

    try:
        loaded = load_diffusers_pipeline(
            class_name=target_class,
            model_id=source,
            from_single_file=fromSingleFile,
            **load_kwargs
        )
    except Exception as err:
        # Re-raise with a hint about which pipeline names are available
        known = get_available_pipelines()
        raise ValueError(
            f"Failed to load pipeline '{target_class}': {err}\n"
            f"Available pipelines: {', '.join(known[:30])}..."
        ) from err

    # CPU offload only when requested and the pipeline offers it
    if request.LowVRAM and hasattr(loaded, 'enable_model_cpu_offload'):
        loaded.enable_model_cpu_offload()

    return loaded
|
||||
|
||||
def Health(self, request, context):
    """Answer a Health request with a Reply whose message is the bytes b"OK"."""
    ok_payload = "OK".encode("utf-8")
    return backend_pb2.Reply(message=ok_payload)
|
||||
|
||||
@@ -231,139 +404,16 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local
|
||||
self.img2vid = False
|
||||
self.txt2vid = False
|
||||
## img2img
|
||||
if (request.PipelineType == "StableDiffusionImg2ImgPipeline") or (request.IMG2IMG and request.PipelineType == ""):
|
||||
if fromSingleFile:
|
||||
self.pipe = StableDiffusionImg2ImgPipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType)
|
||||
else:
|
||||
self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType)
|
||||
|
||||
elif request.PipelineType == "StableDiffusionDepth2ImgPipeline":
|
||||
self.pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType)
|
||||
## img2vid
|
||||
elif request.PipelineType == "StableVideoDiffusionPipeline":
|
||||
self.img2vid = True
|
||||
self.pipe = StableVideoDiffusionPipeline.from_pretrained(
|
||||
request.Model, torch_dtype=torchType, variant=variant
|
||||
)
|
||||
if not DISABLE_CPU_OFFLOAD:
|
||||
self.pipe.enable_model_cpu_offload()
|
||||
## text2img
|
||||
elif request.PipelineType == "AutoPipelineForText2Image" or request.PipelineType == "":
|
||||
self.pipe = AutoPipelineForText2Image.from_pretrained(request.Model,
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=SAFETENSORS,
|
||||
variant=variant)
|
||||
elif request.PipelineType == "StableDiffusionPipeline":
|
||||
if fromSingleFile:
|
||||
self.pipe = StableDiffusionPipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType)
|
||||
else:
|
||||
self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType)
|
||||
elif request.PipelineType == "DiffusionPipeline":
|
||||
self.pipe = DiffusionPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType)
|
||||
elif request.PipelineType == "QwenImageEditPipeline":
|
||||
self.pipe = QwenImageEditPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType)
|
||||
elif request.PipelineType == "VideoDiffusionPipeline":
|
||||
self.txt2vid = True
|
||||
self.pipe = DiffusionPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType)
|
||||
elif request.PipelineType == "StableDiffusionXLPipeline":
|
||||
if fromSingleFile:
|
||||
self.pipe = StableDiffusionXLPipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=True)
|
||||
else:
|
||||
self.pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
request.Model,
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=True,
|
||||
variant=variant)
|
||||
elif request.PipelineType == "StableDiffusion3Pipeline":
|
||||
if fromSingleFile:
|
||||
self.pipe = StableDiffusion3Pipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=True)
|
||||
else:
|
||||
self.pipe = StableDiffusion3Pipeline.from_pretrained(
|
||||
request.Model,
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=True,
|
||||
variant=variant)
|
||||
elif request.PipelineType == "FluxPipeline":
|
||||
if fromSingleFile:
|
||||
self.pipe = FluxPipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=True)
|
||||
else:
|
||||
self.pipe = FluxPipeline.from_pretrained(
|
||||
request.Model,
|
||||
torch_dtype=torch.bfloat16)
|
||||
if request.LowVRAM:
|
||||
self.pipe.enable_model_cpu_offload()
|
||||
elif request.PipelineType == "FluxTransformer2DModel":
|
||||
dtype = torch.bfloat16
|
||||
# specify from environment or default to "ChuckMcSneed/FLUX.1-dev"
|
||||
bfl_repo = os.environ.get("BFL_REPO", "ChuckMcSneed/FLUX.1-dev")
|
||||
|
||||
transformer = FluxTransformer2DModel.from_single_file(modelFile, torch_dtype=dtype)
|
||||
quantize(transformer, weights=qfloat8)
|
||||
freeze(transformer)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
|
||||
quantize(text_encoder_2, weights=qfloat8)
|
||||
freeze(text_encoder_2)
|
||||
|
||||
self.pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
|
||||
self.pipe.transformer = transformer
|
||||
self.pipe.text_encoder_2 = text_encoder_2
|
||||
|
||||
if request.LowVRAM:
|
||||
self.pipe.enable_model_cpu_offload()
|
||||
elif request.PipelineType == "Lumina2Text2ImgPipeline":
|
||||
self.pipe = Lumina2Text2ImgPipeline.from_pretrained(
|
||||
request.Model,
|
||||
torch_dtype=torch.bfloat16)
|
||||
if request.LowVRAM:
|
||||
self.pipe.enable_model_cpu_offload()
|
||||
elif request.PipelineType == "SanaPipeline":
|
||||
self.pipe = SanaPipeline.from_pretrained(
|
||||
request.Model,
|
||||
variant="bf16",
|
||||
torch_dtype=torch.bfloat16)
|
||||
self.pipe.vae.to(torch.bfloat16)
|
||||
self.pipe.text_encoder.to(torch.bfloat16)
|
||||
elif request.PipelineType == "WanPipeline":
|
||||
# WAN2.2 pipeline requires special VAE handling
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
request.Model,
|
||||
subfolder="vae",
|
||||
torch_dtype=torch.float32
|
||||
)
|
||||
self.pipe = WanPipeline.from_pretrained(
|
||||
request.Model,
|
||||
vae=vae,
|
||||
torch_dtype=torchType
|
||||
)
|
||||
self.txt2vid = True # WAN2.2 is a text-to-video pipeline
|
||||
elif request.PipelineType == "WanImageToVideoPipeline":
|
||||
# WAN2.2 image-to-video pipeline
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
request.Model,
|
||||
subfolder="vae",
|
||||
torch_dtype=torch.float32
|
||||
)
|
||||
self.pipe = WanImageToVideoPipeline.from_pretrained(
|
||||
request.Model,
|
||||
vae=vae,
|
||||
torch_dtype=torchType
|
||||
)
|
||||
self.img2vid = True # WAN2.2 image-to-video pipeline
|
||||
# Load pipeline using dynamic loader
|
||||
# Special cases that require custom initialization are handled first
|
||||
self.pipe = self._load_pipeline(
|
||||
request=request,
|
||||
modelFile=modelFile,
|
||||
fromSingleFile=fromSingleFile,
|
||||
torchType=torchType,
|
||||
variant=variant
|
||||
)
|
||||
|
||||
if CLIPSKIP and request.CLIPSkip != 0:
|
||||
self.clip_skip = request.CLIPSkip
|
||||
@@ -501,10 +551,12 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
# create a dictionary of values for the parameters
|
||||
options = {
|
||||
"negative_prompt": request.negative_prompt,
|
||||
"num_inference_steps": steps,
|
||||
}
|
||||
|
||||
if hasattr(request, 'negative_prompt') and request.negative_prompt != "":
|
||||
options["negative_prompt"] = request.negative_prompt
|
||||
|
||||
# Handle image source: prioritize RefImages over request.src
|
||||
image_src = None
|
||||
if hasattr(request, 'ref_images') and request.ref_images and len(request.ref_images) > 0:
|
||||
@@ -528,17 +580,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if CLIPSKIP and self.clip_skip != 0:
|
||||
options["clip_skip"] = self.clip_skip
|
||||
|
||||
# Get the keys that we will build the args for our pipe for
|
||||
keys = options.keys()
|
||||
|
||||
if request.EnableParameters != "":
|
||||
keys = [key.strip() for key in request.EnableParameters.split(",")]
|
||||
|
||||
if request.EnableParameters == "none":
|
||||
keys = []
|
||||
|
||||
# create a dictionary of parameters by using the keys from EnableParameters and the values from defaults
|
||||
kwargs = {key: options.get(key) for key in keys if key in options}
|
||||
kwargs = {}
|
||||
|
||||
# populate kwargs from self.options.
|
||||
kwargs.update(self.options)
|
||||
|
||||
538
backend/python/diffusers/diffusers_dynamic_loader.py
Normal file
538
backend/python/diffusers/diffusers_dynamic_loader.py
Normal file
@@ -0,0 +1,538 @@
|
||||
"""
|
||||
Dynamic Diffusers Pipeline Loader
|
||||
|
||||
This module provides dynamic discovery and loading of diffusers pipelines at runtime,
|
||||
eliminating the need for per-pipeline conditional statements. New pipelines added to
|
||||
diffusers become available automatically without code changes.
|
||||
|
||||
The module also supports discovering other diffusers classes like schedulers, models,
|
||||
and other components, making it a generic solution for dynamic class loading.
|
||||
|
||||
Usage:
|
||||
from diffusers_dynamic_loader import load_diffusers_pipeline, get_available_pipelines
|
||||
|
||||
# Load by class name
|
||||
pipe = load_diffusers_pipeline(class_name="StableDiffusionPipeline", model_id="...", torch_dtype=torch.float16)
|
||||
|
||||
# Load by task alias
|
||||
pipe = load_diffusers_pipeline(task="text-to-image", model_id="...", torch_dtype=torch.float16)
|
||||
|
||||
# Load using model_id (infers from HuggingFace Hub if possible)
|
||||
pipe = load_diffusers_pipeline(model_id="runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
||||
|
||||
# Get list of available pipelines
|
||||
available = get_available_pipelines()
|
||||
|
||||
# Discover other diffusers classes (schedulers, models, etc.)
|
||||
schedulers = discover_diffusers_classes("SchedulerMixin")
|
||||
models = discover_diffusers_classes("ModelMixin")
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import re
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
|
||||
# Global cache for discovered pipelines - computed once per process
|
||||
_pipeline_registry: Optional[Dict[str, Type]] = None
|
||||
_task_aliases: Optional[Dict[str, List[str]]] = None
|
||||
|
||||
# Global cache for other discovered class types
|
||||
_class_registries: Dict[str, Dict[str, Type]] = {}
|
||||
|
||||
|
||||
def _camel_to_kebab(name: str) -> str:
|
||||
"""
|
||||
Convert CamelCase to kebab-case.
|
||||
|
||||
Examples:
|
||||
StableDiffusionPipeline -> stable-diffusion-pipeline
|
||||
StableDiffusionXLImg2ImgPipeline -> stable-diffusion-xl-img-2-img-pipeline
|
||||
"""
|
||||
# Insert hyphen before uppercase letters (but not at the start)
|
||||
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1-\2', name)
|
||||
# Insert hyphen before uppercase letters following lowercase letters or numbers
|
||||
s2 = re.sub('([a-z0-9])([A-Z])', r'\1-\2', s1)
|
||||
return s2.lower()
|
||||
|
||||
|
||||
def _extract_task_keywords(class_name: str) -> List[str]:
|
||||
"""
|
||||
Extract task-related keywords from a pipeline class name.
|
||||
|
||||
This function derives useful task aliases from the class name without
|
||||
hardcoding per-pipeline branches.
|
||||
|
||||
Returns a list of potential task aliases for this pipeline.
|
||||
"""
|
||||
aliases = []
|
||||
name_lower = class_name.lower()
|
||||
|
||||
# Direct task mappings based on common patterns in class names
|
||||
task_patterns = {
|
||||
'text2image': ['text-to-image', 'txt2img', 'text2image'],
|
||||
'texttoimage': ['text-to-image', 'txt2img', 'text2image'],
|
||||
'txt2img': ['text-to-image', 'txt2img', 'text2image'],
|
||||
'img2img': ['image-to-image', 'img2img', 'image2image'],
|
||||
'image2image': ['image-to-image', 'img2img', 'image2image'],
|
||||
'imagetoimage': ['image-to-image', 'img2img', 'image2image'],
|
||||
'img2video': ['image-to-video', 'img2vid', 'img2video'],
|
||||
'imagetovideo': ['image-to-video', 'img2vid', 'img2video'],
|
||||
'text2video': ['text-to-video', 'txt2vid', 'text2video'],
|
||||
'texttovideo': ['text-to-video', 'txt2vid', 'text2video'],
|
||||
'inpaint': ['inpainting', 'inpaint'],
|
||||
'depth2img': ['depth-to-image', 'depth2img'],
|
||||
'depthtoimage': ['depth-to-image', 'depth2img'],
|
||||
'controlnet': ['controlnet', 'control-net'],
|
||||
'upscale': ['upscaling', 'upscale', 'super-resolution'],
|
||||
'superresolution': ['upscaling', 'upscale', 'super-resolution'],
|
||||
}
|
||||
|
||||
# Check for each pattern in the class name
|
||||
for pattern, task_aliases in task_patterns.items():
|
||||
if pattern in name_lower:
|
||||
aliases.extend(task_aliases)
|
||||
|
||||
# Also detect general pipeline types from the class name structure
|
||||
# E.g., StableDiffusionPipeline -> stable-diffusion, flux -> flux
|
||||
# Remove "Pipeline" suffix and convert to kebab case
|
||||
if class_name.endswith('Pipeline'):
|
||||
base_name = class_name[:-8] # Remove "Pipeline"
|
||||
kebab_name = _camel_to_kebab(base_name)
|
||||
aliases.append(kebab_name)
|
||||
|
||||
# Extract model family name (e.g., "stable-diffusion" from "stable-diffusion-xl-img-2-img")
|
||||
parts = kebab_name.split('-')
|
||||
if len(parts) >= 2:
|
||||
# Try the first two words as a family name
|
||||
family = '-'.join(parts[:2])
|
||||
if family not in aliases:
|
||||
aliases.append(family)
|
||||
|
||||
# If no specific task pattern matched but class contains "Pipeline", add "text-to-image" as default
|
||||
# since most diffusion pipelines support text-to-image generation
|
||||
if 'text-to-image' not in aliases and 'image-to-image' not in aliases:
|
||||
# Only add for pipelines that seem to be generation pipelines (not schedulers, etc.)
|
||||
if 'pipeline' in name_lower and not any(x in name_lower for x in ['scheduler', 'processor', 'encoder']):
|
||||
# Don't automatically add - let it be explicit
|
||||
pass
|
||||
|
||||
return list(set(aliases)) # Remove duplicates
|
||||
|
||||
|
||||
def discover_diffusers_classes(
    base_class_name: str,
    include_base: bool = True
) -> Dict[str, Type]:
    """
    Discover all subclasses of a given base class from diffusers.

    The base class is looked up by name on the top-level ``diffusers`` module
    first and then on the ``schedulers``, ``models`` and ``pipelines``
    submodules. Every attribute exposed by ``dir(diffusers)`` that is a class
    and a subclass of that base is collected.

    The complete result (base class included) is cached per ``base_class_name``.
    The cache is independent of ``include_base``: asking for the registry
    without the base class returns a filtered copy, so the flag is honoured on
    every call and not only on the first one.

    Args:
        base_class_name: Name of the base class to search for subclasses
                        (e.g., "DiffusionPipeline", "SchedulerMixin", "ModelMixin")
        include_base: Whether to include the base class itself in results

    Returns:
        Dict mapping class names to class objects. With ``include_base=True``
        this is the cached dict itself (repeated calls return the same object);
        with ``include_base=False`` it is a new dict on every call.

    Raises:
        ValueError: If ``base_class_name`` cannot be found in diffusers.

    Examples:
        # Discover all pipeline classes
        pipelines = discover_diffusers_classes("DiffusionPipeline")

        # Discover all scheduler classes
        schedulers = discover_diffusers_classes("SchedulerMixin")

        # Discover all model classes
        models = discover_diffusers_classes("ModelMixin")

        # Discover AutoPipeline classes
        auto_pipelines = discover_diffusers_classes("AutoPipelineForText2Image")
    """
    full_registry = _class_registries.get(base_class_name)

    if full_registry is None:
        import diffusers

        # Try to get the base class from diffusers
        base_class = None
        try:
            base_class = getattr(diffusers, base_class_name)
        except AttributeError:
            # Try to find in submodules
            for submodule in ['schedulers', 'models', 'pipelines']:
                try:
                    module = importlib.import_module(f'diffusers.{submodule}')
                    if hasattr(module, base_class_name):
                        base_class = getattr(module, base_class_name)
                        break
                except ImportError:
                    continue

        if base_class is None:
            raise ValueError(f"Could not find base class '{base_class_name}' in diffusers")

        # Always build the complete registry; filtering happens on the way out.
        full_registry = {base_class_name: base_class}

        # Scan diffusers module for subclasses
        for attr_name in dir(diffusers):
            try:
                attr = getattr(diffusers, attr_name)
                if isinstance(attr, type) and issubclass(attr, base_class):
                    full_registry[attr_name] = attr
            except (ImportError, AttributeError, TypeError, RuntimeError):
                # Attribute could not be resolved or compared; skip it.
                continue

        # Cache the results
        _class_registries[base_class_name] = full_registry

    if include_base:
        return full_registry

    # Exclude every entry that *is* the base class, whatever name it is
    # exported under, without touching the cached dict.
    base_class = full_registry.get(base_class_name)
    return {
        name: cls for name, cls in full_registry.items()
        if cls is not base_class
    }
|
||||
|
||||
|
||||
def get_available_classes(base_class_name: str) -> List[str]:
    """
    List the names of every class discovered for a base class, in sorted order.

    Args:
        base_class_name: Name of the base class (e.g., "SchedulerMixin")

    Returns:
        Sorted list of discovered class names
    """
    discovered = discover_diffusers_classes(base_class_name)
    return sorted(discovered)
|
||||
|
||||
|
||||
def _discover_pipelines() -> Tuple[Dict[str, Type], Dict[str, List[str]]]:
    """
    Discover all subclasses of DiffusionPipeline from diffusers.

    This function uses the generic discover_diffusers_classes() internally
    and adds pipeline-specific task alias generation. It also includes
    AutoPipeline classes which are special utility classes for automatic
    pipeline selection.

    The registry returned here is a copy of the generic discovery result, so
    adding the AutoPipeline classes does not alter what
    discover_diffusers_classes("DiffusionPipeline") returns afterwards.

    Returns:
        A tuple of (pipeline_registry, task_aliases) where:
        - pipeline_registry: Dict mapping class names to class objects
        - task_aliases: Dict mapping task aliases to lists of class names
    """
    # Copy the discovery result: the dict returned by the generic function is
    # its cache entry, and the AutoPipeline classes added below are not
    # DiffusionPipeline subclasses.
    pipeline_registry = dict(
        discover_diffusers_classes("DiffusionPipeline", include_base=True)
    )

    # Also add AutoPipeline classes - these are special utility classes that are
    # NOT subclasses of DiffusionPipeline but are commonly used
    import diffusers
    auto_pipeline_classes = (
        "AutoPipelineForText2Image",
        "AutoPipelineForImage2Image",
        "AutoPipelineForInpainting",
    )
    for cls_name in auto_pipeline_classes:
        # A missing attribute means the class is not available in this
        # version of diffusers; skip it.
        cls = getattr(diffusers, cls_name, None)
        if cls is not None:
            pipeline_registry[cls_name] = cls

    # Generate task aliases for pipelines
    task_aliases: Dict[str, List[str]] = {}
    for attr_name in pipeline_registry:
        if attr_name == "DiffusionPipeline":
            continue  # Skip base class for alias generation

        for alias in _extract_task_keywords(attr_name):
            class_names = task_aliases.setdefault(alias, [])
            if attr_name not in class_names:
                class_names.append(attr_name)

    return pipeline_registry, task_aliases
|
||||
|
||||
|
||||
def get_pipeline_registry() -> Dict[str, Type]:
    """
    Return the module-level pipeline registry, building it on first use.

    The registry maps pipeline class names to class objects. When it has not
    been built yet, _discover_pipelines() is run once and both the registry
    and the task alias map are stored in the module globals.
    """
    global _pipeline_registry, _task_aliases
    if _pipeline_registry is not None:
        return _pipeline_registry

    discovered, alias_map = _discover_pipelines()
    _pipeline_registry = discovered
    _task_aliases = alias_map
    return _pipeline_registry
|
||||
|
||||
|
||||
def get_task_aliases() -> Dict[str, List[str]]:
    """
    Return the module-level task alias map, building it on first use.

    The map goes from a task alias (e.g., "text-to-image") to the list of
    pipeline class names registered under that alias. When it has not been
    built yet, _discover_pipelines() is run once and both the registry and
    the alias map are stored in the module globals.
    """
    global _pipeline_registry, _task_aliases
    if _task_aliases is not None:
        return _task_aliases

    discovered, alias_map = _discover_pipelines()
    _pipeline_registry = discovered
    _task_aliases = alias_map
    return _task_aliases
|
||||
|
||||
|
||||
def get_available_pipelines() -> List[str]:
    """
    List the names of all pipeline classes in the registry, sorted.

    Returns:
        Sorted list of the class names held by get_pipeline_registry().
    """
    pipeline_names = get_pipeline_registry()
    return sorted(pipeline_names)
|
||||
|
||||
|
||||
def get_available_tasks() -> List[str]:
    """
    List every known task alias, sorted.

    Returns:
        Sorted list of the keys of get_task_aliases()
        (e.g., ["text-to-image", "image-to-image", ...])
    """
    alias_map = get_task_aliases()
    return sorted(alias_map)
|
||||
|
||||
|
||||
def resolve_pipeline_class(
    class_name: Optional[str] = None,
    task: Optional[str] = None,
    model_id: Optional[str] = None
) -> Type:
    """
    Pick a pipeline class from a class name, a task alias, or a model id.

    The first non-empty argument, in this order, decides the outcome:
        1. class_name: exact registry key, then a case-insensitive match
        2. task: exact alias (lower-cased, '_' replaced by '-'), then the
           first alias that contains the task or is contained in it
        3. model_id: the pipeline tag reported by the HuggingFace Hub; if
           that yields nothing, the registry's DiffusionPipeline entry

    Args:
        class_name: Exact pipeline class name (e.g., "StableDiffusionPipeline")
        task: Task alias (e.g., "text-to-image", "img2img")
        model_id: HuggingFace model ID (e.g., "runwayml/stable-diffusion-v1-5")

    Returns:
        The resolved pipeline class.

    Raises:
        ValueError: If class_name or task is given but matches nothing, or
            if none of the three arguments is provided.
    """
    known_pipelines = get_pipeline_registry()
    alias_map = get_task_aliases()

    # 1. An explicit class name wins over everything else.
    if class_name:
        try:
            return known_pipelines[class_name]
        except KeyError:
            pass

        # No exact key: compare ignoring case.
        wanted = class_name.lower()
        for candidate_name, candidate in known_pipelines.items():
            if candidate_name.lower() == wanted:
                return candidate

        shown_pipelines = ', '.join(sorted(known_pipelines)[:20])
        raise ValueError(
            f"Unknown pipeline class '{class_name}'. "
            f"Available pipelines: {shown_pipelines}..."
        )

    # 2. Resolve through the task alias map.
    if task:
        normalized_task = task.lower().replace('_', '-')

        # The first class registered for an exact alias is used.
        exact_classes = alias_map.get(normalized_task)
        if exact_classes:
            return known_pipelines[exact_classes[0]]

        # Otherwise accept a substring relation in either direction.
        for alias, class_names in alias_map.items():
            if class_names and (normalized_task in alias or alias in normalized_task):
                return known_pipelines[class_names[0]]

        shown_tasks = ', '.join(sorted(alias_map)[:20])
        raise ValueError(
            f"Unknown task '{task}'. "
            f"Available tasks: {shown_tasks}..."
        )

    # 3. Ask the HuggingFace Hub what kind of model this is.
    if model_id:
        try:
            from huggingface_hub import model_info
            hub_info = model_info(model_id)

            # Tag exposed directly on the model info object
            hub_tag = getattr(hub_info, 'pipeline_tag', None)
            if hub_tag:
                hub_tag = hub_tag.lower().replace('_', '-')
                if hub_tag in alias_map:
                    tagged_classes = alias_map[hub_tag]
                    if tagged_classes:
                        return known_pipelines[tagged_classes[0]]

            # Tag declared in the model card metadata
            card_data = getattr(hub_info, 'cardData', None)
            if card_data and 'pipeline_tag' in card_data:
                card_tag = card_data['pipeline_tag'].lower().replace('_', '-')
                if card_tag in alias_map:
                    tagged_classes = alias_map[card_tag]
                    if tagged_classes:
                        return known_pipelines[tagged_classes[0]]

        except (ImportError, KeyError, AttributeError, ValueError, OSError):
            # Best effort only; any of these sends us to the fallback below:
            # - ImportError: huggingface_hub not available
            # - KeyError: missing keys in model card
            # - AttributeError: missing attributes on model info
            # - ValueError: invalid model data
            # - OSError: network or file access issues
            pass

        # Fallback: DiffusionPipeline, whose from_pretrained auto-detects the
        # concrete pipeline. The direct import only backs up a registry that
        # lacks the entry.
        from diffusers import DiffusionPipeline
        return known_pipelines.get('DiffusionPipeline', DiffusionPipeline)

    shown_pipelines = ', '.join(sorted(known_pipelines)[:20])
    shown_tasks = ', '.join(sorted(alias_map)[:20])
    raise ValueError(
        "Must provide at least one of: class_name, task, or model_id. "
        f"Available pipelines: {shown_pipelines}... "
        f"Available tasks: {shown_tasks}..."
    )
|
||||
|
||||
|
||||
def load_diffusers_pipeline(
    class_name: Optional[str] = None,
    task: Optional[str] = None,
    model_id: Optional[str] = None,
    from_single_file: bool = False,
    **kwargs
) -> Any:
    """
    Resolve a diffusers pipeline class and instantiate it.

    The class is chosen by resolve_pipeline_class() from class_name, task and
    model_id; the pipeline is then created with from_pretrained(), or with
    from_single_file() when requested.

    Args:
        class_name: Exact pipeline class name (e.g., "StableDiffusionPipeline")
        task: Task alias (e.g., "text-to-image", "img2img")
        model_id: HuggingFace model ID or local path
        from_single_file: If True, use from_single_file() instead of from_pretrained()
        **kwargs: Additional arguments passed to from_pretrained() or from_single_file()

    Returns:
        An instantiated pipeline object.

    Raises:
        ValueError: If no pipeline could be resolved, or if model_id is None.
        RuntimeError: If loading fails for any reason. This includes the case
            where from_single_file is requested but the resolved class has no
            from_single_file() method; the original error is chained as the
            cause.

    Examples:
        # Load by class name
        pipe = load_diffusers_pipeline(
            class_name="StableDiffusionPipeline",
            model_id="runwayml/stable-diffusion-v1-5",
            torch_dtype=torch.float16
        )

        # Load by task
        pipe = load_diffusers_pipeline(
            task="text-to-image",
            model_id="runwayml/stable-diffusion-v1-5",
            torch_dtype=torch.float16
        )

        # Load from single file
        pipe = load_diffusers_pipeline(
            class_name="StableDiffusionPipeline",
            model_id="/path/to/model.safetensors",
            from_single_file=True,
            torch_dtype=torch.float16
        )
    """
    pipeline_class = resolve_pipeline_class(
        class_name=class_name,
        task=task,
        model_id=model_id
    )

    # A class alone is not enough: there must be something to load from.
    if model_id is None:
        raise ValueError("model_id is required to load a pipeline")

    try:
        if not from_single_file:
            return pipeline_class.from_pretrained(model_id, **kwargs)

        # Single-file loading is only offered by some pipeline classes.
        if not hasattr(pipeline_class, 'from_single_file'):
            raise ValueError(
                f"Pipeline class {pipeline_class.__name__} does not support from_single_file(). "
                "Use from_pretrained() instead."
            )
        return pipeline_class.from_single_file(model_id, **kwargs)

    except Exception as err:
        # Re-raise with the class, the model and the known pipelines in the message.
        known_names = get_available_pipelines()
        raise RuntimeError(
            f"Failed to load pipeline '{pipeline_class.__name__}' from '{model_id}': {err}\n"
            f"Available pipelines: {', '.join(known_names[:20])}..."
        ) from err
|
||||
|
||||
|
||||
def get_pipeline_info(class_name: str) -> Dict[str, Any]:
    """
    Describe one registered pipeline class.

    Args:
        class_name: The pipeline class name; must be an exact registry key

    Returns:
        Dictionary with the keys:
        - name: the class name as given
        - aliases: every task alias whose class list contains this class
        - supports_single_file: True when the class has a from_single_file attribute
        - docstring: first 200 characters of the class docstring, or None
          when the class has no docstring

    Raises:
        ValueError: If class_name is not in the pipeline registry.
    """
    known_pipelines = get_pipeline_registry()
    alias_map = get_task_aliases()

    if class_name not in known_pipelines:
        raise ValueError(f"Unknown pipeline: {class_name}")

    pipeline_cls = known_pipelines[class_name]
    class_doc = pipeline_cls.__doc__

    # Collect the aliases that list this class.
    matching_aliases = [
        alias for alias, class_names in alias_map.items()
        if class_name in class_names
    ]

    return {
        'name': class_name,
        'aliases': matching_aliases,
        'supports_single_file': hasattr(pipeline_cls, 'from_single_file'),
        'docstring': class_doc[:200] if class_doc else None
    }
|
||||
@@ -1,15 +1,26 @@
|
||||
"""
|
||||
A test script to test the gRPC service
|
||||
A test script to test the gRPC service and dynamic loader
|
||||
"""
|
||||
import unittest
|
||||
import subprocess
|
||||
import time
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import grpc
|
||||
# Import dynamic loader for testing (these don't need gRPC)
|
||||
import diffusers_dynamic_loader as loader
|
||||
from diffusers import DiffusionPipeline, StableDiffusionPipeline
|
||||
|
||||
# Try to import gRPC modules - may not be available during unit testing
|
||||
try:
|
||||
import grpc
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
GRPC_AVAILABLE = True
|
||||
except ImportError:
|
||||
GRPC_AVAILABLE = False
|
||||
|
||||
|
||||
@unittest.skipUnless(GRPC_AVAILABLE, "gRPC modules not available")
|
||||
class TestBackendServicer(unittest.TestCase):
|
||||
"""
|
||||
TestBackendServicer is the class that tests the gRPC service
|
||||
@@ -82,3 +93,222 @@ class TestBackendServicer(unittest.TestCase):
|
||||
self.fail("Image gen service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
|
||||
class TestDiffusersDynamicLoader(unittest.TestCase):
    """Checks discovery, alias extraction and class resolution in the dynamic loader."""

    @classmethod
    def setUpClass(cls):
        """Drop the cached pipeline registry and alias map before the suite runs."""
        loader._pipeline_registry = loader._task_aliases = None

    def test_camel_to_kebab_conversion(self):
        """CamelCase names are turned into kebab-case."""
        expectations = {
            "StableDiffusionPipeline": "stable-diffusion-pipeline",
            "StableDiffusionXLPipeline": "stable-diffusion-xl-pipeline",
            "FluxPipeline": "flux-pipeline",
            "DiffusionPipeline": "diffusion-pipeline",
        }
        for input_val, expected in expectations.items():
            with self.subTest(input=input_val):
                self.assertEqual(loader._camel_to_kebab(input_val), expected)

    def test_extract_task_keywords(self):
        """Class names yield the expected family and task aliases."""
        expected_by_class = (
            # family name from a plain pipeline
            ("StableDiffusionPipeline", ("stable-diffusion",)),
            # img2img
            ("StableDiffusionImg2ImgPipeline", ("image-to-image", "img2img")),
            # inpainting
            ("StableDiffusionInpaintPipeline", ("inpainting", "inpaint")),
            # depth2img
            ("StableDiffusionDepth2ImgPipeline", ("depth-to-image",)),
        )
        for pipeline_name, expected_aliases in expected_by_class:
            found = loader._extract_task_keywords(pipeline_name)
            for expected_alias in expected_aliases:
                self.assertIn(expected_alias, found)

    def test_discover_pipelines_finds_known_classes(self):
        """The registry is non-empty and holds well-known pipeline classes."""
        registry = loader.get_pipeline_registry()

        self.assertGreater(len(registry), 0, "Pipeline registry should not be empty")

        for pipeline_name in ("StableDiffusionPipeline", "DiffusionPipeline"):
            with self.subTest(pipeline=pipeline_name):
                self.assertIn(
                    pipeline_name,
                    registry,
                    f"Expected to find {pipeline_name} in registry"
                )

    def test_discover_pipelines_caches_results(self):
        """Two registry lookups hand back the very same object."""
        first = loader.get_pipeline_registry()
        second = loader.get_pipeline_registry()

        self.assertIs(first, second, "Registry should be cached")

    def test_get_available_pipelines(self):
        """The pipeline listing is a sorted list containing known classes."""
        names = loader.get_available_pipelines()

        self.assertIsInstance(names, list)
        self.assertIn("StableDiffusionPipeline", names)
        self.assertIn("DiffusionPipeline", names)
        self.assertEqual(names, sorted(names))

    def test_get_available_tasks(self):
        """The task listing is a sorted list."""
        task_names = loader.get_available_tasks()

        self.assertIsInstance(task_names, list)
        self.assertEqual(task_names, sorted(task_names))

    def test_resolve_pipeline_class_by_name(self):
        """An exact class name resolves to that class."""
        resolved = loader.resolve_pipeline_class(class_name="StableDiffusionPipeline")
        self.assertEqual(resolved, StableDiffusionPipeline)

    def test_resolve_pipeline_class_by_name_case_insensitive(self):
        """A lower-cased class name resolves to the same class."""
        exact = loader.resolve_pipeline_class(class_name="StableDiffusionPipeline")
        lowered = loader.resolve_pipeline_class(class_name="stablediffusionpipeline")
        self.assertEqual(exact, lowered)

    def test_resolve_pipeline_class_by_task(self):
        """A task alias resolves to some class, when that alias exists."""
        known_aliases = loader.get_task_aliases()

        # Only exercised when discovery produced this alias.
        if "stable-diffusion" in known_aliases:
            resolved = loader.resolve_pipeline_class(task="stable-diffusion")
            self.assertIsNotNone(resolved)

    def test_resolve_pipeline_class_unknown_name_raises(self):
        """An unknown class name raises ValueError listing available pipelines."""
        with self.assertRaises(ValueError) as raised:
            loader.resolve_pipeline_class(class_name="NonExistentPipeline")

        message = str(raised.exception)
        self.assertIn("Unknown pipeline class", message)
        self.assertIn("Available pipelines", message)

    def test_resolve_pipeline_class_unknown_task_raises(self):
        """An unknown task raises ValueError listing available tasks."""
        with self.assertRaises(ValueError) as raised:
            loader.resolve_pipeline_class(task="nonexistent-task-xyz")

        message = str(raised.exception)
        self.assertIn("Unknown task", message)
        self.assertIn("Available tasks", message)

    def test_resolve_pipeline_class_no_params_raises(self):
        """Calling the resolver without arguments raises ValueError."""
        with self.assertRaises(ValueError) as raised:
            loader.resolve_pipeline_class()

        self.assertIn("Must provide at least one of", str(raised.exception))

    def test_get_pipeline_info(self):
        """Pipeline info carries the name, an alias list and a boolean flag."""
        details = loader.get_pipeline_info("StableDiffusionPipeline")

        self.assertEqual(details['name'], "StableDiffusionPipeline")
        self.assertIsInstance(details['aliases'], list)
        self.assertIsInstance(details['supports_single_file'], bool)

    def test_get_pipeline_info_unknown_raises(self):
        """Requesting info for an unknown pipeline raises ValueError."""
        with self.assertRaises(ValueError) as raised:
            loader.get_pipeline_info("NonExistentPipeline")

        self.assertIn("Unknown pipeline", str(raised.exception))

    def test_discover_diffusers_classes_pipelines(self):
        """Generic discovery for DiffusionPipeline returns a dict of known classes."""
        found = loader.discover_diffusers_classes("DiffusionPipeline")

        self.assertIsInstance(found, dict)
        self.assertIn("DiffusionPipeline", found)
        self.assertIn("StableDiffusionPipeline", found)

    def test_discover_diffusers_classes_caches_results(self):
        """Two generic discovery calls hand back the very same object."""
        first = loader.discover_diffusers_classes("DiffusionPipeline")
        second = loader.discover_diffusers_classes("DiffusionPipeline")

        self.assertIs(first, second)

    def test_discover_diffusers_classes_exclude_base(self):
        """Subclasses are still reported when the base class is excluded."""
        found = loader.discover_diffusers_classes("DiffusionPipeline", include_base=False)

        self.assertIn("StableDiffusionPipeline", found)

    def test_get_available_classes(self):
        """The class listing for a base class is sorted and holds known classes."""
        names = loader.get_available_classes("DiffusionPipeline")

        self.assertIsInstance(names, list)
        self.assertEqual(names, sorted(names))
        self.assertIn("StableDiffusionPipeline", names)
|
||||
|
||||
|
||||
class TestDiffusersDynamicLoaderWithMocks(unittest.TestCase):
    """Edge-case tests for the loader (neither test in this class uses a mock object)."""

    def test_load_pipeline_requires_model_id(self):
        """Loading with a class name but no model_id raises ValueError."""
        with self.assertRaises(ValueError) as raised:
            loader.load_diffusers_pipeline(class_name="StableDiffusionPipeline")

        self.assertIn("model_id is required", str(raised.exception))

    def test_resolve_with_model_id_uses_diffusion_pipeline_fallback(self):
        """Resolving from a model_id alone ends up at DiffusionPipeline."""
        # Only model_id is supplied, so resolution goes through the hub
        # lookup path. For this id the expected outcome is the generic
        # DiffusionPipeline fallback.
        resolved = loader.resolve_pipeline_class(model_id="some/nonexistent/model")
        self.assertEqual(resolved, DiffusionPipeline)
|
||||
|
||||
Reference in New Issue
Block a user