diff --git a/backend/go/stablediffusion-ggml/gosd.cpp b/backend/go/stablediffusion-ggml/gosd.cpp
index 768894470..e9d3cfa9d 100644
--- a/backend/go/stablediffusion-ggml/gosd.cpp
+++ b/backend/go/stablediffusion-ggml/gosd.cpp
@@ -1,4 +1,5 @@
 #include "stable-diffusion.h"
+#include
 #include
 
 #define GGML_MAX_NAME 128
@@ -21,6 +22,7 @@
 #define STB_IMAGE_RESIZE_IMPLEMENTATION
 #define STB_IMAGE_RESIZE_STATIC
 #include "stb_image_resize.h"
+#include
 
 // Names of the sampler method, same order as enum sample_method in stable-diffusion.h
 const char* sample_method_str[] = {
@@ -55,6 +57,73 @@ const char* schedulers[] = {
 static_assert(std::size(schedulers) == SCHEDULER_COUNT, "schedulers mismatch");
 
+// New enum string arrays
+const char* rng_type_str[] = {
+    "std_default",
+    "cuda",
+    "cpu",
+};
+static_assert(std::size(rng_type_str) == RNG_TYPE_COUNT, "rng type mismatch");
+
+const char* prediction_str[] = {
+    "default",
+    "epsilon",
+    "v",
+    "edm_v",
+    "sd3_flow",
+    "flux_flow",
+    "flux2_flow",
+};
+static_assert(std::size(prediction_str) == PREDICTION_COUNT, "prediction mismatch");
+
+const char* lora_apply_mode_str[] = {
+    "auto",
+    "immediately",
+    "at_runtime",
+};
+static_assert(std::size(lora_apply_mode_str) == LORA_APPLY_MODE_COUNT, "lora apply mode mismatch");
+
+constexpr const char* sd_type_str[] = {
+    "f32",     // 0
+    "f16",     // 1
+    "q4_0",    // 2
+    "q4_1",    // 3
+    nullptr,   // 4
+    nullptr,   // 5
+    "q5_0",    // 6
+    "q5_1",    // 7
+    "q8_0",    // 8
+    "q8_1",    // 9
+    "q2_k",    // 10
+    "q3_k",    // 11
+    "q4_k",    // 12
+    "q5_k",    // 13
+    "q6_k",    // 14
+    "q8_k",    // 15
+    "iq2_xxs", // 16
+    "iq2_xs",  // 17
+    "iq3_xxs", // 18
+    "iq1_s",   // 19
+    "iq4_nl",  // 20
+    "iq3_s",   // 21
+    "iq2_s",   // 22
+    "iq4_xs",  // 23
+    "i8",      // 24
+    "i16",     // 25
+    "i32",     // 26
+    "i64",     // 27
+    "f64",     // 28
+    "iq1_m",   // 29
+    "bf16",    // 30
+    nullptr, nullptr, nullptr, nullptr, // 31-34
+    "tq1_0",   // 35
+    "tq2_0",   // 36
+    nullptr, nullptr, // 37-38
+    "mxfp4"    // 39
+};
+static_assert(std::size(sd_type_str) == SD_TYPE_COUNT, "sd type mismatch");
+
+sd_ctx_params_t ctx_params;
 sd_ctx_t* sd_c;
 
 // Moved from the context (load time) to generation time params
 scheduler_t scheduler = SCHEDULER_COUNT;
@@ -99,7 +168,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
     const char *stableDiffusionModel = "";
     if (diff == 1 ) {
-        stableDiffusionModel = model;
+        stableDiffusionModel = strdup(model);
         model = "";
     }
@@ -110,8 +179,38 @@ int load_model(const char *model, char *model_path, char* options[], int threads
     const char *vae_path = "";
     const char *scheduler_str = "";
     const char *sampler = "";
+    const char *clip_vision_path = "";
+    const char *llm_path = "";
+    const char *llm_vision_path = "";
+    const char *diffusion_model_path = stableDiffusionModel;
+    const char *high_noise_diffusion_model_path = "";
+    const char *taesd_path = "";
+    const char *control_net_path = "";
+    const char *embedding_dir = "";
+    const char *photo_maker_path = "";
+    const char *tensor_type_rules = "";
     char *lora_dir = model_path;
-    bool lora_dir_allocated = false;
+
+    bool vae_decode_only = true;
+    int n_threads = threads;
+    enum sd_type_t wtype = SD_TYPE_COUNT;
+    enum rng_type_t rng_type = CUDA_RNG;
+    enum rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
+    enum prediction_t prediction = DEFAULT_PRED;
+    enum lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO;
+    bool offload_params_to_cpu = false;
+    bool keep_clip_on_cpu = false;
+    bool keep_control_net_on_cpu = false;
+    bool keep_vae_on_cpu = false;
+    bool diffusion_flash_attn = false;
+    bool tae_preview_only = false;
+    bool diffusion_conv_direct = false;
+    bool vae_conv_direct = false;
+    bool force_sdxl_vae_conv_scale = false;
+    bool chroma_use_dit_mask = true;
+    bool chroma_use_t5_mask = false;
+    int chroma_t5_mask_pad = 1;
+    float flow_shift = INFINITY;
 
     fprintf(stderr, "parsing options: %p\n", options);
@@ -124,16 +223,16 @@ int load_model(const char *model, char *model_path, char* options[], int threads
         }
         if (!strcmp(optname, "clip_l_path")) {
-            clip_l_path = optval;
+            clip_l_path = strdup(optval);
         }
         if (!strcmp(optname, "clip_g_path")) {
-            clip_g_path = optval;
+            clip_g_path = strdup(optval);
         }
         if (!strcmp(optname, "t5xxl_path")) {
-            t5xxl_path = optval;
+            t5xxl_path = strdup(optval);
         }
         if (!strcmp(optname, "vae_path")) {
-            vae_path = optval;
+            vae_path = strdup(optval);
         }
         if (!strcmp(optname, "scheduler")) {
             scheduler_str = optval;
@@ -148,43 +247,167 @@ int load_model(const char *model, char *model_path, char* options[], int threads
                 std::filesystem::path lora_path(optval);
                 std::filesystem::path full_lora_path = model_path_str / lora_path;
                 lora_dir = strdup(full_lora_path.string().c_str());
-                lora_dir_allocated = true;
                 fprintf(stderr, "Lora dir resolved to: %s\n", lora_dir);
             } else {
                 lora_dir = strdup(optval);
-                lora_dir_allocated = true;
                 fprintf(stderr, "No model path provided, using lora dir as-is: %s\n", lora_dir);
             }
         }
+
+        // New parsing
+        if (!strcmp(optname, "clip_vision_path")) clip_vision_path = strdup(optval);
+        if (!strcmp(optname, "llm_path")) llm_path = strdup(optval);
+        if (!strcmp(optname, "llm_vision_path")) llm_vision_path = strdup(optval);
+        if (!strcmp(optname, "diffusion_model_path")) diffusion_model_path = strdup(optval);
+        if (!strcmp(optname, "high_noise_diffusion_model_path")) high_noise_diffusion_model_path = strdup(optval);
+        if (!strcmp(optname, "taesd_path")) taesd_path = strdup(optval);
+        if (!strcmp(optname, "control_net_path")) control_net_path = strdup(optval);
+        if (!strcmp(optname, "embedding_dir")) embedding_dir = strdup(optval);
+        if (!strcmp(optname, "photo_maker_path")) photo_maker_path = strdup(optval);
+        if (!strcmp(optname, "tensor_type_rules")) tensor_type_rules = strdup(optval);
+
+        if (!strcmp(optname, "vae_decode_only")) vae_decode_only = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+        if (!strcmp(optname, "offload_params_to_cpu")) offload_params_to_cpu = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+        if (!strcmp(optname, "keep_clip_on_cpu")) keep_clip_on_cpu = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+        if (!strcmp(optname, "keep_control_net_on_cpu")) keep_control_net_on_cpu = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+        if (!strcmp(optname, "keep_vae_on_cpu")) keep_vae_on_cpu = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+        if (!strcmp(optname, "diffusion_flash_attn")) diffusion_flash_attn = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+        if (!strcmp(optname, "tae_preview_only")) tae_preview_only = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+        if (!strcmp(optname, "diffusion_conv_direct")) diffusion_conv_direct = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+        if (!strcmp(optname, "vae_conv_direct")) vae_conv_direct = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+        if (!strcmp(optname, "force_sdxl_vae_conv_scale")) force_sdxl_vae_conv_scale = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+        if (!strcmp(optname, "chroma_use_dit_mask")) chroma_use_dit_mask = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+        if (!strcmp(optname, "chroma_use_t5_mask")) chroma_use_t5_mask = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+
+        if (!strcmp(optname, "n_threads")) n_threads = atoi(optval);
+        if (!strcmp(optname, "chroma_t5_mask_pad")) chroma_t5_mask_pad = atoi(optval);
+
+        if (!strcmp(optname, "flow_shift")) flow_shift = atof(optval);
+
+        if (!strcmp(optname, "rng_type")) {
+            int found = -1;
+            for (int m = 0; m < RNG_TYPE_COUNT; m++) {
+                if (!strcmp(optval, rng_type_str[m])) {
+                    found = m;
+                    break;
+                }
+            }
+            if (found != -1) {
+                rng_type = (rng_type_t)found;
+                fprintf(stderr, "Found rng_type: %s\n", optval);
+            } else {
+                fprintf(stderr, "Invalid rng_type: %s, using default\n", optval);
+            }
+        }
+        if (!strcmp(optname, "sampler_rng_type")) {
+            int found = -1;
+            for (int m = 0; m < RNG_TYPE_COUNT; m++) {
+                if (!strcmp(optval, rng_type_str[m])) {
+                    found = m;
+                    break;
+                }
+            }
+            if (found != -1) {
+                sampler_rng_type = (rng_type_t)found;
+                fprintf(stderr, "Found sampler_rng_type: %s\n", optval);
+            } else {
+                fprintf(stderr, "Invalid sampler_rng_type: %s, using default\n", optval);
+            }
+        }
+        if (!strcmp(optname, "prediction")) {
+            int found = -1;
+            for (int m = 0; m < PREDICTION_COUNT; m++) {
+                if (!strcmp(optval, prediction_str[m])) {
+                    found = m;
+                    break;
+                }
+            }
+            if (found != -1) {
+                prediction = (prediction_t)found;
+                fprintf(stderr, "Found prediction: %s\n", optval);
+            } else {
+                fprintf(stderr, "Invalid prediction: %s, using default\n", optval);
+            }
+        }
+        if (!strcmp(optname, "lora_apply_mode")) {
+            int found = -1;
+            for (int m = 0; m < LORA_APPLY_MODE_COUNT; m++) {
+                if (!strcmp(optval, lora_apply_mode_str[m])) {
+                    found = m;
+                    break;
+                }
+            }
+            if (found != -1) {
+                lora_apply_mode = (lora_apply_mode_t)found;
+                fprintf(stderr, "Found lora_apply_mode: %s\n", optval);
+            } else {
+                fprintf(stderr, "Invalid lora_apply_mode: %s, using default\n", optval);
+            }
+        }
+        if (!strcmp(optname, "wtype")) {
+            int found = -1;
+            for (int m = 0; m < SD_TYPE_COUNT; m++) {
+                if (sd_type_str[m] && !strcmp(optval, sd_type_str[m])) {
+                    found = m;
+                    break;
+                }
+            }
+            if (found != -1) {
+                wtype = (sd_type_t)found;
+                fprintf(stderr, "Found wtype: %s\n", optval);
+            } else {
+                fprintf(stderr, "Invalid wtype: %s, using default\n", optval);
+            }
+        }
     }
 
     fprintf(stderr, "parsed options\n");
 
     fprintf (stderr, "Creating context\n");
-    sd_ctx_params_t ctx_params;
     sd_ctx_params_init(&ctx_params);
     ctx_params.model_path = model;
     ctx_params.clip_l_path = clip_l_path;
     ctx_params.clip_g_path = clip_g_path;
+    ctx_params.clip_vision_path = clip_vision_path;
     ctx_params.t5xxl_path = t5xxl_path;
-    ctx_params.diffusion_model_path = stableDiffusionModel;
+    ctx_params.llm_path = llm_path;
+    ctx_params.llm_vision_path = llm_vision_path;
+    ctx_params.diffusion_model_path = diffusion_model_path;
+    ctx_params.high_noise_diffusion_model_path = high_noise_diffusion_model_path;
     ctx_params.vae_path = vae_path;
-    ctx_params.taesd_path = "";
-    ctx_params.control_net_path = "";
+    ctx_params.taesd_path = taesd_path;
+    ctx_params.control_net_path = control_net_path;
     ctx_params.lora_model_dir = lora_dir;
-    ctx_params.embedding_dir = "";
-    ctx_params.vae_decode_only = false;
+    ctx_params.embedding_dir = embedding_dir;
+    ctx_params.photo_maker_path = photo_maker_path;
+    ctx_params.tensor_type_rules = tensor_type_rules;
+    ctx_params.vae_decode_only = vae_decode_only;
+    // XXX: Setting to true causes a segfault on the second run
     ctx_params.free_params_immediately = false;
-    ctx_params.n_threads = threads;
-    ctx_params.rng_type = STD_DEFAULT_RNG;
+    ctx_params.n_threads = n_threads;
+    ctx_params.rng_type = rng_type;
+    ctx_params.keep_clip_on_cpu = keep_clip_on_cpu;
+    if (wtype != SD_TYPE_COUNT) ctx_params.wtype = wtype;
+    if (sampler_rng_type != RNG_TYPE_COUNT) ctx_params.sampler_rng_type = sampler_rng_type;
+    if (prediction != PREDICTION_COUNT) ctx_params.prediction = prediction;
+    if (lora_apply_mode != LORA_APPLY_MODE_COUNT) ctx_params.lora_apply_mode = lora_apply_mode;
+    ctx_params.offload_params_to_cpu = offload_params_to_cpu;
+    ctx_params.keep_control_net_on_cpu = keep_control_net_on_cpu;
+    ctx_params.keep_vae_on_cpu = keep_vae_on_cpu;
+    ctx_params.diffusion_flash_attn = diffusion_flash_attn;
+    ctx_params.tae_preview_only = tae_preview_only;
+    ctx_params.diffusion_conv_direct = diffusion_conv_direct;
+    ctx_params.vae_conv_direct = vae_conv_direct;
+    ctx_params.force_sdxl_vae_conv_scale = force_sdxl_vae_conv_scale;
+    ctx_params.chroma_use_dit_mask = chroma_use_dit_mask;
+    ctx_params.chroma_use_t5_mask = chroma_use_t5_mask;
+    ctx_params.chroma_t5_mask_pad = chroma_t5_mask_pad;
+    ctx_params.flow_shift = flow_shift;
 
     sd_ctx_t* sd_ctx = new_sd_ctx(&ctx_params);
 
     if (sd_ctx == NULL) {
         fprintf (stderr, "failed loading model (generic error)\n");
-        // Clean up allocated memory
-        if (lora_dir_allocated && lora_dir) {
-            free(lora_dir);
-        }
+        // TODO: Clean up allocated memory
         return 1;
     }
     fprintf (stderr, "Created context: OK\n");
@@ -215,11 +438,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads
 
     sd_c = sd_ctx;
 
-    // Clean up allocated memory
-    if (lora_dir_allocated && lora_dir) {
-        free(lora_dir);
-    }
-
     return 0;
 }
 
@@ -248,6 +466,9 @@ sd_tiling_params_t* sd_img_gen_params_get_vae_tiling_params(sd_img_gen_params_t
 sd_img_gen_params_t* sd_img_gen_params_new(void) {
     sd_img_gen_params_t *params = (sd_img_gen_params_t *)std::malloc(sizeof(sd_img_gen_params_t));
     sd_img_gen_params_init(params);
+    sd_sample_params_init(&params->sample_params);
+    sd_easycache_params_init(&params->easycache);
+    params->control_strength = 0.9f;
     return params;
 }
 
@@ -265,7 +486,7 @@ void sd_img_gen_params_set_seed(sd_img_gen_params_t *params, int64_t seed) {
     params->seed = seed;
 }
 
-int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count) {
+int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char* ref_images[], int ref_images_count) {
 
     sd_image_t* results;
@@ -445,6 +666,10 @@ int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, cha
         }
     }
 
+    fprintf(stderr, "Generating image with params: \nctx\n---\n%s\ngen\n---\n%s\n",
+            sd_ctx_params_to_str(&ctx_params),
+            sd_img_gen_params_to_str(p));
+
     results = generate_image(sd_c, p);
 
     std::free(p);
@@ -477,9 +702,12 @@ int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, cha
     fprintf (stderr, "Channel: %d\n", results[0].channel);
     fprintf (stderr, "Data: %p\n", results[0].data);
 
-    stbi_write_png(dst, results[0].width, results[0].height, results[0].channel,
-                   results[0].data, 0, NULL);
-    fprintf (stderr, "Saved resulting image to '%s'\n", dst);
+    int ret = stbi_write_png(dst, results[0].width, results[0].height, results[0].channel,
+                             results[0].data, 0, NULL);
+    if (ret)
+        fprintf (stderr, "Saved resulting image to '%s'\n", dst);
+    else
+        fprintf(stderr, "Failed to write image to '%s'\n", dst);
 
     // Clean up
     free(results[0].data);
@@ -490,9 +718,10 @@ int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, cha
     for (auto buffer : ref_image_buffers) {
         if (buffer) free(buffer);
     }
-    fprintf (stderr, "gen_image is done: %s", dst);
+    fprintf (stderr, "gen_image is done: %s\n", dst);
+    fflush(stderr);
 
-    return 0;
+    return !ret;
 }
 
 int unload() {
diff --git a/backend/go/stablediffusion-ggml/gosd.go b/backend/go/stablediffusion-ggml/gosd.go
index d2bec357a..205f3f2d1 100644
--- a/backend/go/stablediffusion-ggml/gosd.go
+++ b/backend/go/stablediffusion-ggml/gosd.go
@@ -22,7 +22,7 @@ type SDGGML struct {
 var (
     LoadModel func(model, model_apth string, options []uintptr, threads int32, diff int) int
-    GenImage func(params uintptr, steps int, dst string, cfgScale float32, srcImage string, strength float32, maskImage string, refImages []string, refImagesCount int) int
+    GenImage func(params uintptr, steps int, dst string, cfgScale float32, srcImage string, strength float32, maskImage string, refImages []uintptr, refImagesCount int) int
 
     TilingParamsSetEnabled func(params uintptr, enabled bool)
     TilingParamsSetTileSizes func(params uintptr, tileSizeX int, tileSizeY int)
@@ -95,12 +95,12 @@ func (sd *SDGGML) Load(opts *pb.ModelOptions) error {
     sd.cfgScale = opts.CFGScale
 
     ret := LoadModel(modelFile, modelPathC, options, opts.Threads, diffusionModel)
+    runtime.KeepAlive(keepAlive)
+
+    fmt.Fprintf(os.Stderr, "LoadModel: %d\n", ret)
     if ret != 0 {
         return fmt.Errorf("could not load model")
     }
-    runtime.KeepAlive(keepAlive)
-
     return nil
 }
@@ -123,10 +123,15 @@ func (sd *SDGGML) GenerateImage(opts *pb.GenerateImageRequest) error {
         }
     }
 
+    // At the time of writing Purego doesn't recurse into slices and convert Go strings to pointers so we need to do that
+    var keepAlive []any
     refImagesCount := len(opts.RefImages)
-    refImages := make([]string, refImagesCount, refImagesCount+1)
-    copy(refImages, opts.RefImages)
-    *(*uintptr)(unsafe.Add(unsafe.Pointer(&refImages), refImagesCount)) = 0
+    refImages := make([]uintptr, refImagesCount, refImagesCount+1)
+    for i, ri := range opts.RefImages {
+        bytep := CString(ri)
+        refImages[i] = uintptr(unsafe.Pointer(bytep))
+        keepAlive = append(keepAlive, bytep)
+    }
 
     // Default strength for img2img (0.75 is a good default)
     strength := float32(0.75)
@@ -140,6 +145,8 @@ func (sd *SDGGML) GenerateImage(opts *pb.GenerateImageRequest) error {
     TilingParamsSetEnabled(vaep, false)
 
     ret := GenImage(p, int(opts.Step), dst, sd.cfgScale, srcImage, strength, maskImage, refImages, refImagesCount)
+    runtime.KeepAlive(keepAlive)
+
+    fmt.Fprintf(os.Stderr, "GenImage: %d\n", ret)
     if ret != 0 {
         return fmt.Errorf("inference failed")
     }
diff --git a/backend/go/stablediffusion-ggml/gosd.h b/backend/go/stablediffusion-ggml/gosd.h
index 823133c4f..8324a3ead 100644
--- a/backend/go/stablediffusion-ggml/gosd.h
+++ b/backend/go/stablediffusion-ggml/gosd.h
@@ -17,7 +17,7 @@ void sd_img_gen_params_set_dimensions(sd_img_gen_params_t *params, int width, in
 void sd_img_gen_params_set_seed(sd_img_gen_params_t *params, int64_t seed);
 
 int load_model(const char *model, char *model_path, char* options[], int threads, int diffusionModel);
-int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count);
+int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char* ref_images[], int ref_images_count);
 #ifdef __cplusplus
 }
 #endif
diff --git a/gallery/index.yaml b/gallery/index.yaml
index 610a8d70c..01cb206ee 100644
--- a/gallery/index.yaml
+++ b/gallery/index.yaml
@@ -20911,6 +20911,9 @@
   overrides:
     parameters:
       model: flux1-dev-Q2_K.gguf
+    options:
+      - scheduler:simple
+      - keep_clip_on_cpu:true
   files:
     - filename: "flux1-dev-Q2_K.gguf"
      sha256: "b8c464bc0f10076ef8f00ba040d220d90c7993f7c4245ae80227d857f65df105"
@@ -21078,6 +21081,32 @@
    - filename: t5xxl_fp16.safetensors
      sha256: 6e480b09fae049a72d2a8c5fbccb8d3e92febeb233bbe9dfe7256958a9167635
      uri: https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/t5xxl_fp16.safetensors
+- &zimage
+  name: Z-Image-Turbo
+  icon: https://z-image.ai/logo.png
+  license: apache-2.0
+  description: |
+    Z-Image is a powerful and highly efficient image generation model with 6B parameters. Currently there are three variants of which this is the Turbo edition.
+
+    🚀 Z-Image-Turbo – A distilled version of Z-Image that matches or exceeds leading competitors with only 8 NFEs (Number of Function Evaluations). It offers ⚡️sub-second inference latency⚡️ on enterprise-grade H800 GPUs and fits comfortably within 16G VRAM consumer devices. It excels in photorealistic image generation, bilingual text rendering (English & Chinese), and robust instruction adherence.
+  urls:
+    - https://github.com/Tongyi-MAI/Z-Image
+  tags:
+    - text-to-image
+    - z-image
+    - gpu
+  url: "github:mudler/LocalAI/gallery/z-image-ggml.yaml@master"
+  files:
+    - filename: Qwen3-4B.Q4_K_M.gguf
+      sha256: a37931937683a723ae737a0c6fc67dab7782fd8a1b9dea2ca445b7a1dbd5ca3a
+      uri: huggingface://MaziyarPanahi/Qwen3-4B-GGUF/Qwen3-4B.Q4_K_M.gguf
+    - filename: z_image_turbo-Q4_0.gguf
+      sha256: 14b375ab4f226bc5378f68f37e899ef3c2242b8541e61e2bc1aff40976086fbd
+      uri: https://huggingface.co/leejet/Z-Image-Turbo-GGUF/resolve/main/z_image_turbo-Q4_0.gguf
+    - filename: ae.safetensors
+      sha256: afc8e28272cd15db3919bacdb6918ce9c1ed22e96cb12c4d5ed0fba823529e38
+      uri: https://huggingface.co/ChuckMcSneed/FLUX.1-dev/resolve/main/ae.safetensors
+
 - &whisper
   url: "github:mudler/LocalAI/gallery/whisper-base.yaml@master" ## Whisper
   name: "whisper-1"
diff --git a/gallery/z-image-ggml.yaml b/gallery/z-image-ggml.yaml
new file mode 100644
index 000000000..3d1e73f18
--- /dev/null
+++ b/gallery/z-image-ggml.yaml
@@ -0,0 +1,15 @@
+---
+name: "Z-Image-GGML"
+
+config_file: |
+  backend: stablediffusion-ggml
+  cfg_scale: 1
+  name: z-image-test
+  options:
+    - diffusion_model
+    - llm_path:Qwen3-4B.Q4_K_M.gguf
+    - vae_path:ae.safetensors
+    - offload_params_to_cpu:true
+  parameters:
+    model: z_image_turbo-Q4_K.gguf
+    step: 25
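For reference, every option parsed in load_model above reaches the backend as a plain name:value string in the model config, exactly like the existing gallery entries (scheduler:simple, keep_clip_on_cpu:true, llm_path:...). The sketch below shows how several of the new knobs could be combined in one config; the model name and file names are hypothetical placeholders and the option values are purely illustrative, not defaults:

  backend: stablediffusion-ggml
  cfg_scale: 1
  name: my-sd-model                  # hypothetical model name
  options:
    - diffusion_model                # as in z-image-ggml.yaml above
    - scheduler:simple
    - rng_type:cuda                  # std_default, cuda or cpu (see rng_type_str)
    - sampler_rng_type:cpu
    - prediction:flux_flow           # one of the values in prediction_str
    - lora_apply_mode:auto           # auto, immediately or at_runtime
    - wtype:q8_0                     # weight type name from sd_type_str
    - offload_params_to_cpu:true     # booleans accept "true" or "1"
    - keep_clip_on_cpu:true
    - diffusion_flash_attn:true
    - chroma_t5_mask_pad:1           # parsed with atoi
    - flow_shift:3.0                 # parsed with atof
  parameters:
    model: my-model-file.gguf        # hypothetical file name
    step: 25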