diff --git a/backend/go/stablediffusion-ggml/gosd.cpp b/backend/go/stablediffusion-ggml/gosd.cpp index f989ee921..ce00ad11e 100644 --- a/backend/go/stablediffusion-ggml/gosd.cpp +++ b/backend/go/stablediffusion-ggml/gosd.cpp @@ -8,7 +8,9 @@ #include #include #include +#include #include +#include #include "gosd.h" #define STB_IMAGE_IMPLEMENTATION @@ -23,6 +25,7 @@ #define STB_IMAGE_RESIZE_STATIC #include "stb_image_resize.h" #include +#include // Names of the sampler method, same order as enum sample_method in stable-diffusion.h const char* sample_method_str[] = { @@ -133,6 +136,13 @@ static std::vector embedding_vec; // Storage for embedding strings (needs to persist as long as embedding_vec references them) static std::vector embedding_strings; +// Storage for LoRAs (needs to persist for the lifetime of generation params) +static std::vector lora_vec; +// Storage for LoRA strings (needs to persist as long as lora_vec references them) +static std::vector lora_strings; +// Storage for lora_dir path +static std::string lora_dir_path; + // Build embeddings vector from directory, similar to upstream CLI static void build_embedding_vec(const char* embedding_dir) { embedding_vec.clear(); @@ -186,6 +196,229 @@ static void build_embedding_vec(const char* embedding_dir) { fprintf(stderr, "Loaded %zu embeddings from %s\n", embedding_vec.size(), embedding_dir); } +// Discover LoRA files in directory and build a map of name -> path +static std::map discover_lora_files(const char* lora_dir) { + std::map lora_map; + + if (!lora_dir || strlen(lora_dir) == 0) { + fprintf(stderr, "LoRA directory not specified\n"); + return lora_map; + } + + if (!std::filesystem::exists(lora_dir) || !std::filesystem::is_directory(lora_dir)) { + fprintf(stderr, "LoRA directory does not exist or is not a directory: %s\n", lora_dir); + return lora_map; + } + + static const std::vector valid_ext = {".safetensors", ".ckpt", ".pt", ".gguf"}; + + fprintf(stderr, "Discovering LoRA files in: %s\n", lora_dir); + + for (const auto& entry : std::filesystem::directory_iterator(lora_dir)) { + if (!entry.is_regular_file()) { + continue; + } + + auto path = entry.path(); + std::string ext = path.extension().string(); + + bool valid = false; + for (const auto& e : valid_ext) { + if (ext == e) { + valid = true; + break; + } + } + if (!valid) { + continue; + } + + std::string name = path.stem().string(); // stem() already removes extension + std::string full_path = path.string(); + + // Store the name (without extension) -> full path mapping + // This allows users to specify just the name in + lora_map[name] = full_path; + + fprintf(stderr, "Found LoRA file: %s -> %s\n", name.c_str(), full_path.c_str()); + } + + fprintf(stderr, "Discovered %zu LoRA files in %s\n", lora_map.size(), lora_dir); + return lora_map; +} + +// Helper function to check if a path is absolute (matches upstream) +static bool is_absolute_path(const std::string& p) { +#ifdef _WIN32 + // Windows: C:/path or C:\path + return p.size() > 1 && std::isalpha(static_cast(p[0])) && p[1] == ':'; +#else + // Unix: /path + return !p.empty() && p[0] == '/'; +#endif +} + +// Parse LoRAs from prompt string (e.g., "" or "") +// Returns a vector of LoRA info and the cleaned prompt with LoRA tags removed +// Matches upstream implementation more closely +static std::pair, std::string> parse_loras_from_prompt(const std::string& prompt, const char* lora_dir) { + std::vector loras; + std::string cleaned_prompt = prompt; + + if (!lora_dir || strlen(lora_dir) == 0) { + fprintf(stderr, "LoRA directory not set, cannot parse LoRAs from prompt\n"); + return {loras, cleaned_prompt}; + } + + // Discover LoRA files for name-based lookup + std::map discovered_lora_map = discover_lora_files(lora_dir); + + // Map to accumulate multipliers for the same LoRA (matches upstream) + std::map lora_map; + std::map high_noise_lora_map; + + static const std::regex re(R"(]+):([^>]+)>)"); + static const std::vector valid_ext = {".pt", ".safetensors", ".gguf"}; + std::smatch m; + + std::string tmp = prompt; + + fprintf(stderr, "Parsing LoRAs from prompt: %s\n", prompt.c_str()); + + while (std::regex_search(tmp, m, re)) { + std::string raw_path = m[1].str(); + const std::string raw_mul = m[2].str(); + + float mul = 0.f; + try { + mul = std::stof(raw_mul); + } catch (...) { + tmp = m.suffix().str(); + cleaned_prompt = std::regex_replace(cleaned_prompt, re, "", std::regex_constants::format_first_only); + fprintf(stderr, "Invalid LoRA multiplier '%s', skipping\n", raw_mul.c_str()); + continue; + } + + bool is_high_noise = false; + static const std::string prefix = "|high_noise|"; + if (raw_path.rfind(prefix, 0) == 0) { + raw_path.erase(0, prefix.size()); + is_high_noise = true; + } + + std::filesystem::path final_path; + if (is_absolute_path(raw_path)) { + final_path = raw_path; + } else { + // Try name-based lookup first + auto it = discovered_lora_map.find(raw_path); + if (it != discovered_lora_map.end()) { + final_path = it->second; + } else { + // Try case-insensitive lookup + bool found = false; + for (const auto& pair : discovered_lora_map) { + std::string lower_name = raw_path; + std::string lower_key = pair.first; + std::transform(lower_name.begin(), lower_name.end(), lower_name.begin(), ::tolower); + std::transform(lower_key.begin(), lower_key.end(), lower_key.begin(), ::tolower); + if (lower_name == lower_key) { + final_path = pair.second; + found = true; + break; + } + } + if (!found) { + // Try as relative path in lora_dir + final_path = std::filesystem::path(lora_dir) / raw_path; + } + } + } + + // Try adding extensions if file doesn't exist + if (!std::filesystem::exists(final_path)) { + bool found = false; + for (const auto& ext : valid_ext) { + std::filesystem::path try_path = final_path; + try_path += ext; + if (std::filesystem::exists(try_path)) { + final_path = try_path; + found = true; + break; + } + } + if (!found) { + fprintf(stderr, "WARNING: LoRA file not found: %s\n", final_path.lexically_normal().string().c_str()); + tmp = m.suffix().str(); + cleaned_prompt = std::regex_replace(cleaned_prompt, re, "", std::regex_constants::format_first_only); + continue; + } + } + + // Normalize path (matches upstream) + const std::string key = final_path.lexically_normal().string(); + + // Accumulate multiplier if same LoRA appears multiple times (matches upstream) + if (is_high_noise) { + high_noise_lora_map[key] += mul; + } else { + lora_map[key] += mul; + } + + fprintf(stderr, "Parsed LoRA: path='%s', multiplier=%.2f, is_high_noise=%s\n", + key.c_str(), mul, is_high_noise ? "true" : "false"); + + cleaned_prompt = std::regex_replace(cleaned_prompt, re, "", std::regex_constants::format_first_only); + tmp = m.suffix().str(); + } + + // Build final LoRA vector from accumulated maps (matches upstream) + // Store all path strings first to ensure they persist + for (const auto& kv : lora_map) { + lora_strings.push_back(kv.first); + } + for (const auto& kv : high_noise_lora_map) { + lora_strings.push_back(kv.first); + } + + // Now build the LoRA vector with pointers to the stored strings + size_t string_idx = 0; + for (const auto& kv : lora_map) { + sd_lora_t item; + item.is_high_noise = false; + item.path = lora_strings[string_idx].c_str(); + item.multiplier = kv.second; + loras.push_back(item); + string_idx++; + } + + for (const auto& kv : high_noise_lora_map) { + sd_lora_t item; + item.is_high_noise = true; + item.path = lora_strings[string_idx].c_str(); + item.multiplier = kv.second; + loras.push_back(item); + string_idx++; + } + + // Clean up extra spaces + std::regex space_regex(R"(\s+)"); + cleaned_prompt = std::regex_replace(cleaned_prompt, space_regex, " "); + // Trim leading/trailing spaces + size_t first = cleaned_prompt.find_first_not_of(" \t"); + if (first != std::string::npos) { + cleaned_prompt.erase(0, first); + } + size_t last = cleaned_prompt.find_last_not_of(" \t"); + if (last != std::string::npos) { + cleaned_prompt.erase(last + 1); + } + + fprintf(stderr, "Parsed %zu LoRA(s) from prompt. Cleaned prompt: %s\n", loras.size(), cleaned_prompt.c_str()); + + return {loras, cleaned_prompt}; +} + // Copied from the upstream CLI static void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) { //SDParams* params = (SDParams*)data; @@ -304,11 +537,17 @@ int load_model(const char *model, char *model_path, char* options[], int threads std::filesystem::path lora_path(optval); std::filesystem::path full_lora_path = model_path_str / lora_path; lora_dir = strdup(full_lora_path.string().c_str()); - fprintf(stderr, "Lora dir resolved to: %s\n", lora_dir); + lora_dir_path = full_lora_path.string(); + fprintf(stderr, "LoRA dir resolved to: %s\n", lora_dir); } else { lora_dir = strdup(optval); + lora_dir_path = std::string(optval); fprintf(stderr, "No model path provided, using lora dir as-is: %s\n", lora_dir); } + // Discover LoRAs immediately when directory is set + if (lora_dir && strlen(lora_dir) > 0) { + discover_lora_files(lora_dir); + } } // New parsing @@ -450,6 +689,14 @@ int load_model(const char *model, char *model_path, char* options[], int threads ctx_params.taesd_path = taesd_path; ctx_params.control_net_path = control_net_path; ctx_params.lora_model_dir = lora_dir; + if (lora_dir && strlen(lora_dir) > 0) { + lora_dir_path = std::string(lora_dir); + fprintf(stderr, "LoRA model directory set to: %s\n", lora_dir); + // Discover LoRAs at load time for logging + discover_lora_files(lora_dir); + } else { + fprintf(stderr, "WARNING: LoRA model directory not set. LoRAs in prompts will not be loaded.\n"); + } // Set embeddings array and count ctx_params.embeddings = embedding_vec.empty() ? NULL : embedding_vec.data(); ctx_params.embedding_count = static_cast(embedding_vec.size()); @@ -546,9 +793,63 @@ sd_img_gen_params_t* sd_img_gen_params_new(void) { return params; } +// Storage for cleaned prompt strings (needs to persist) +static std::string cleaned_prompt_storage; +static std::string cleaned_negative_prompt_storage; + void sd_img_gen_params_set_prompts(sd_img_gen_params_t *params, const char *prompt, const char *negative_prompt) { - params->prompt = prompt; - params->negative_prompt = negative_prompt; + // Clear previous LoRA data + lora_vec.clear(); + lora_strings.clear(); + + // Parse LoRAs from prompt + std::string prompt_str = prompt ? prompt : ""; + std::string negative_prompt_str = negative_prompt ? negative_prompt : ""; + + // Get lora_dir from ctx_params if available, otherwise use stored path + const char* lora_dir_to_use = ctx_params.lora_model_dir; + if (!lora_dir_to_use || strlen(lora_dir_to_use) == 0) { + lora_dir_to_use = lora_dir_path.empty() ? nullptr : lora_dir_path.c_str(); + } + + auto [loras, cleaned_prompt] = parse_loras_from_prompt(prompt_str, lora_dir_to_use); + lora_vec = loras; + cleaned_prompt_storage = cleaned_prompt; + + // Also check negative prompt for LoRAs (though this is less common) + auto [neg_loras, cleaned_negative] = parse_loras_from_prompt(negative_prompt_str, lora_dir_to_use); + // Merge negative prompt LoRAs (though typically not used) + if (!neg_loras.empty()) { + fprintf(stderr, "Note: Found %zu LoRAs in negative prompt (may not be supported)\n", neg_loras.size()); + } + cleaned_negative_prompt_storage = cleaned_negative; + + // Set the cleaned prompts + params->prompt = cleaned_prompt_storage.c_str(); + params->negative_prompt = cleaned_negative_prompt_storage.c_str(); + + // Set LoRAs in params + params->loras = lora_vec.empty() ? nullptr : lora_vec.data(); + params->lora_count = static_cast(lora_vec.size()); + + fprintf(stderr, "Set prompts with %zu LoRAs. Original prompt: %s\n", lora_vec.size(), prompt ? prompt : "(null)"); + fprintf(stderr, "Cleaned prompt: %s\n", cleaned_prompt_storage.c_str()); + + // Debug: Verify LoRAs are set correctly + if (params->loras && params->lora_count > 0) { + fprintf(stderr, "DEBUG: LoRAs set in params structure:\n"); + for (uint32_t i = 0; i < params->lora_count; i++) { + fprintf(stderr, " params->loras[%u]: path='%s' (ptr=%p), multiplier=%.2f, is_high_noise=%s\n", + i, + params->loras[i].path ? params->loras[i].path : "(null)", + (void*)params->loras[i].path, + params->loras[i].multiplier, + params->loras[i].is_high_noise ? "true" : "false"); + } + } else { + fprintf(stderr, "DEBUG: No LoRAs set in params structure (loras=%p, lora_count=%u)\n", + (void*)params->loras, params->lora_count); + } } void sd_img_gen_params_set_dimensions(sd_img_gen_params_t *params, int width, int height) { @@ -740,6 +1041,20 @@ int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, cha } } + // Log LoRA information + if (p->loras && p->lora_count > 0) { + fprintf(stderr, "Using %u LoRA(s) in generation:\n", p->lora_count); + for (uint32_t i = 0; i < p->lora_count; i++) { + fprintf(stderr, " LoRA[%u]: path='%s', multiplier=%.2f, is_high_noise=%s\n", + i, + p->loras[i].path ? p->loras[i].path : "(null)", + p->loras[i].multiplier, + p->loras[i].is_high_noise ? "true" : "false"); + } + } else { + fprintf(stderr, "No LoRAs specified for this generation\n"); + } + fprintf(stderr, "Generating image with params: \nctx\n---\n%s\ngen\n---\n%s\n", sd_ctx_params_to_str(&ctx_params), sd_img_gen_params_to_str(p)); @@ -802,3 +1117,4 @@ int unload() { free_sd_ctx(sd_c); return 0; } +