feat(stablediffusion-ggml): add lora support (#7542)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
committed by GitHub
parent 2bd6faaff5
commit abfb0ff8fe
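This change lets prompts reference LoRA adapters inline via "<lora:name:strength>" tags, which are resolved against the configured LoRA directory and stripped before generation. For example (the adapter name here is illustrative), the prompt "a portrait photo <lora:detail_tweaker:0.8>" would resolve "detail_tweaker" to "detail_tweaker.safetensors" in the LoRA directory, apply it at multiplier 0.8, and generate from the cleaned prompt "a portrait photo".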
@@ -8,7 +8,9 @@
#include <time.h>
#include <string>
#include <vector>
#include <map>
#include <filesystem>
#include <algorithm>
#include "gosd.h"

#define STB_IMAGE_IMPLEMENTATION
@@ -23,6 +25,7 @@
#define STB_IMAGE_RESIZE_STATIC
#include "stb_image_resize.h"
#include <stdlib.h>
#include <regex>

// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
const char* sample_method_str[] = {
@@ -133,6 +136,13 @@ static std::vector<sd_embedding_t> embedding_vec;
// Storage for embedding strings (needs to persist as long as embedding_vec references them)
static std::vector<std::string> embedding_strings;

// Storage for LoRAs (needs to persist for the lifetime of generation params)
static std::vector<sd_lora_t> lora_vec;
// Storage for LoRA strings (needs to persist as long as lora_vec references them)
static std::vector<std::string> lora_strings;
// Storage for lora_dir path
static std::string lora_dir_path;

// Build embeddings vector from directory, similar to upstream CLI
static void build_embedding_vec(const char* embedding_dir) {
    embedding_vec.clear();
@@ -186,6 +196,229 @@ static void build_embedding_vec(const char* embedding_dir) {
    fprintf(stderr, "Loaded %zu embeddings from %s\n", embedding_vec.size(), embedding_dir);
}

// Discover LoRA files in directory and build a map of name -> path
static std::map<std::string, std::string> discover_lora_files(const char* lora_dir) {
    std::map<std::string, std::string> lora_map;

    if (!lora_dir || strlen(lora_dir) == 0) {
        fprintf(stderr, "LoRA directory not specified\n");
        return lora_map;
    }

    if (!std::filesystem::exists(lora_dir) || !std::filesystem::is_directory(lora_dir)) {
        fprintf(stderr, "LoRA directory does not exist or is not a directory: %s\n", lora_dir);
        return lora_map;
    }

    static const std::vector<std::string> valid_ext = {".safetensors", ".ckpt", ".pt", ".gguf"};

    fprintf(stderr, "Discovering LoRA files in: %s\n", lora_dir);

    for (const auto& entry : std::filesystem::directory_iterator(lora_dir)) {
        if (!entry.is_regular_file()) {
            continue;
        }

        auto path = entry.path();
        std::string ext = path.extension().string();

        bool valid = false;
        for (const auto& e : valid_ext) {
            if (ext == e) {
                valid = true;
                break;
            }
        }
        if (!valid) {
            continue;
        }

        std::string name = path.stem().string(); // stem() already removes extension
        std::string full_path = path.string();

        // Store the name (without extension) -> full path mapping
        // This allows users to specify just the name in <lora:name:strength>
        lora_map[name] = full_path;

        fprintf(stderr, "Found LoRA file: %s -> %s\n", name.c_str(), full_path.c_str());
    }

    fprintf(stderr, "Discovered %zu LoRA files in %s\n", lora_map.size(), lora_dir);
    return lora_map;
}

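To see the discovery step in isolation, here is a minimal standalone sketch of the same pattern. The directory name is illustrative, and only ".safetensors" is matched for brevity; the real function above also accepts ".ckpt", ".pt", and ".gguf".

#include <cstdio>
#include <filesystem>
#include <map>
#include <string>

int main() {
    std::map<std::string, std::string> lora_map;
    const char* dir = "./loras"; // hypothetical directory
    if (std::filesystem::is_directory(dir)) {
        for (const auto& entry : std::filesystem::directory_iterator(dir)) {
            if (!entry.is_regular_file()) continue;
            if (entry.path().extension() != ".safetensors") continue;
            // stem() drops the extension, so a prompt tag can reference
            // "detail_tweaker.safetensors" as just "detail_tweaker".
            lora_map[entry.path().stem().string()] = entry.path().string();
        }
    }
    for (const auto& kv : lora_map) {
        printf("%s -> %s\n", kv.first.c_str(), kv.second.c_str());
    }
    return 0;
}
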
// Helper function to check if a path is absolute (matches upstream)
static bool is_absolute_path(const std::string& p) {
#ifdef _WIN32
    // Windows: C:/path or C:\path
    return p.size() > 1 && std::isalpha(static_cast<unsigned char>(p[0])) && p[1] == ':';
#else
    // Unix: /path
    return !p.empty() && p[0] == '/';
#endif
}

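For example, on a Unix build this treats "/models/lora.safetensors" as absolute and "lora.safetensors" as relative; on Windows it keys off a drive-letter prefix such as "C:/" or "C:\".
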
// Parse LoRAs from prompt string (e.g., "<lora:name:1.0>" or "<lora:name>")
// Returns a vector of LoRA info and the cleaned prompt with LoRA tags removed
// Matches upstream implementation more closely
static std::pair<std::vector<sd_lora_t>, std::string> parse_loras_from_prompt(const std::string& prompt, const char* lora_dir) {
    std::vector<sd_lora_t> loras;
    std::string cleaned_prompt = prompt;

    if (!lora_dir || strlen(lora_dir) == 0) {
        fprintf(stderr, "LoRA directory not set, cannot parse LoRAs from prompt\n");
        return {loras, cleaned_prompt};
    }

    // Discover LoRA files for name-based lookup
    std::map<std::string, std::string> discovered_lora_map = discover_lora_files(lora_dir);

    // Map to accumulate multipliers for the same LoRA (matches upstream)
    std::map<std::string, float> lora_map;
    std::map<std::string, float> high_noise_lora_map;

    static const std::regex re(R"(<lora:([^:>]+):([^>]+)>)");
    static const std::vector<std::string> valid_ext = {".pt", ".safetensors", ".gguf"};
    std::smatch m;

    std::string tmp = prompt;

    fprintf(stderr, "Parsing LoRAs from prompt: %s\n", prompt.c_str());

    while (std::regex_search(tmp, m, re)) {
        std::string raw_path = m[1].str();
        const std::string raw_mul = m[2].str();

        float mul = 0.f;
        try {
            mul = std::stof(raw_mul);
        } catch (...) {
            tmp = m.suffix().str();
            cleaned_prompt = std::regex_replace(cleaned_prompt, re, "", std::regex_constants::format_first_only);
            fprintf(stderr, "Invalid LoRA multiplier '%s', skipping\n", raw_mul.c_str());
            continue;
        }

        bool is_high_noise = false;
        static const std::string prefix = "|high_noise|";
        if (raw_path.rfind(prefix, 0) == 0) {
            raw_path.erase(0, prefix.size());
            is_high_noise = true;
        }

        std::filesystem::path final_path;
        if (is_absolute_path(raw_path)) {
            final_path = raw_path;
        } else {
            // Try name-based lookup first
            auto it = discovered_lora_map.find(raw_path);
            if (it != discovered_lora_map.end()) {
                final_path = it->second;
            } else {
                // Try case-insensitive lookup
                bool found = false;
                for (const auto& pair : discovered_lora_map) {
                    std::string lower_name = raw_path;
                    std::string lower_key = pair.first;
                    std::transform(lower_name.begin(), lower_name.end(), lower_name.begin(), ::tolower);
                    std::transform(lower_key.begin(), lower_key.end(), lower_key.begin(), ::tolower);
                    if (lower_name == lower_key) {
                        final_path = pair.second;
                        found = true;
                        break;
                    }
                }
                if (!found) {
                    // Try as relative path in lora_dir
                    final_path = std::filesystem::path(lora_dir) / raw_path;
                }
            }
        }

        // Try adding extensions if file doesn't exist
        if (!std::filesystem::exists(final_path)) {
            bool found = false;
            for (const auto& ext : valid_ext) {
                std::filesystem::path try_path = final_path;
                try_path += ext;
                if (std::filesystem::exists(try_path)) {
                    final_path = try_path;
                    found = true;
                    break;
                }
            }
            if (!found) {
                fprintf(stderr, "WARNING: LoRA file not found: %s\n", final_path.lexically_normal().string().c_str());
                tmp = m.suffix().str();
                cleaned_prompt = std::regex_replace(cleaned_prompt, re, "", std::regex_constants::format_first_only);
                continue;
            }
        }

        // Normalize path (matches upstream)
        const std::string key = final_path.lexically_normal().string();

        // Accumulate multiplier if same LoRA appears multiple times (matches upstream)
        if (is_high_noise) {
            high_noise_lora_map[key] += mul;
        } else {
            lora_map[key] += mul;
        }

        fprintf(stderr, "Parsed LoRA: path='%s', multiplier=%.2f, is_high_noise=%s\n",
                key.c_str(), mul, is_high_noise ? "true" : "false");

        cleaned_prompt = std::regex_replace(cleaned_prompt, re, "", std::regex_constants::format_first_only);
        tmp = m.suffix().str();
    }

    // Build final LoRA vector from accumulated maps (matches upstream)
    // Store all path strings first to ensure they persist
    for (const auto& kv : lora_map) {
        lora_strings.push_back(kv.first);
    }
    for (const auto& kv : high_noise_lora_map) {
        lora_strings.push_back(kv.first);
    }

    // Now build the LoRA vector with pointers to the stored strings
    size_t string_idx = 0;
    for (const auto& kv : lora_map) {
        sd_lora_t item;
        item.is_high_noise = false;
        item.path = lora_strings[string_idx].c_str();
        item.multiplier = kv.second;
        loras.push_back(item);
        string_idx++;
    }

    for (const auto& kv : high_noise_lora_map) {
        sd_lora_t item;
        item.is_high_noise = true;
        item.path = lora_strings[string_idx].c_str();
        item.multiplier = kv.second;
        loras.push_back(item);
        string_idx++;
    }

    // Clean up extra spaces
    std::regex space_regex(R"(\s+)");
    cleaned_prompt = std::regex_replace(cleaned_prompt, space_regex, " ");
    // Trim leading/trailing spaces
    size_t first = cleaned_prompt.find_first_not_of(" \t");
    if (first != std::string::npos) {
        cleaned_prompt.erase(0, first);
    }
    size_t last = cleaned_prompt.find_last_not_of(" \t");
    if (last != std::string::npos) {
        cleaned_prompt.erase(last + 1);
    }

    fprintf(stderr, "Parsed %zu LoRA(s) from prompt. Cleaned prompt: %s\n", loras.size(), cleaned_prompt.c_str());

    return {loras, cleaned_prompt};
}

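The core of the parser is the tag regex plus multiplier accumulation. Below is a self-contained sketch of just that core, with the file-existence and path-resolution steps omitted; the prompt and LoRA names are illustrative.

#include <cstdio>
#include <map>
#include <regex>
#include <string>

int main() {
    // Two tags for the same LoRA: multipliers accumulate to 0.8.
    std::string prompt = "a cat <lora:detail:0.6> on a roof <lora:detail:0.2>";
    static const std::regex re(R"(<lora:([^:>]+):([^>]+)>)");

    std::map<std::string, float> mults;
    std::string tmp = prompt;
    std::smatch m;
    while (std::regex_search(tmp, m, re)) {
        mults[m[1].str()] += std::stof(m[2].str());
        tmp = m.suffix().str();
    }

    // Strip the tags, as the real parser does (it additionally
    // collapses the leftover whitespace).
    std::string cleaned = std::regex_replace(prompt, re, "");
    printf("cleaned prompt: '%s'\n", cleaned.c_str());
    for (const auto& kv : mults) {
        printf("lora '%s' -> multiplier %.2f\n", kv.first.c_str(), kv.second);
    }
    return 0;
}
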
// Copied from the upstream CLI
static void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
    //SDParams* params = (SDParams*)data;
@@ -304,11 +537,17 @@ int load_model(const char *model, char *model_path, char* options[], int threads
            std::filesystem::path lora_path(optval);
            std::filesystem::path full_lora_path = model_path_str / lora_path;
            lora_dir = strdup(full_lora_path.string().c_str());
            fprintf(stderr, "Lora dir resolved to: %s\n", lora_dir);
            lora_dir_path = full_lora_path.string();
            fprintf(stderr, "LoRA dir resolved to: %s\n", lora_dir);
        } else {
            lora_dir = strdup(optval);
            lora_dir_path = std::string(optval);
            fprintf(stderr, "No model path provided, using lora dir as-is: %s\n", lora_dir);
        }
        // Discover LoRAs immediately when directory is set
        if (lora_dir && strlen(lora_dir) > 0) {
            discover_lora_files(lora_dir);
        }
    }

    // New parsing
@@ -450,6 +689,14 @@ int load_model(const char *model, char *model_path, char* options[], int threads
    ctx_params.taesd_path = taesd_path;
    ctx_params.control_net_path = control_net_path;
    ctx_params.lora_model_dir = lora_dir;
    if (lora_dir && strlen(lora_dir) > 0) {
        lora_dir_path = std::string(lora_dir);
        fprintf(stderr, "LoRA model directory set to: %s\n", lora_dir);
        // Discover LoRAs at load time for logging
        discover_lora_files(lora_dir);
    } else {
        fprintf(stderr, "WARNING: LoRA model directory not set. LoRAs in prompts will not be loaded.\n");
    }
    // Set embeddings array and count
    ctx_params.embeddings = embedding_vec.empty() ? NULL : embedding_vec.data();
    ctx_params.embedding_count = static_cast<uint32_t>(embedding_vec.size());
@@ -546,9 +793,63 @@ sd_img_gen_params_t* sd_img_gen_params_new(void) {
    return params;
}

// Storage for cleaned prompt strings (needs to persist)
static std::string cleaned_prompt_storage;
static std::string cleaned_negative_prompt_storage;

void sd_img_gen_params_set_prompts(sd_img_gen_params_t *params, const char *prompt, const char *negative_prompt) {
    params->prompt = prompt;
    params->negative_prompt = negative_prompt;
    // Clear previous LoRA data
    lora_vec.clear();
    lora_strings.clear();

    // Parse LoRAs from prompt
    std::string prompt_str = prompt ? prompt : "";
    std::string negative_prompt_str = negative_prompt ? negative_prompt : "";

    // Get lora_dir from ctx_params if available, otherwise use stored path
    const char* lora_dir_to_use = ctx_params.lora_model_dir;
    if (!lora_dir_to_use || strlen(lora_dir_to_use) == 0) {
        lora_dir_to_use = lora_dir_path.empty() ? nullptr : lora_dir_path.c_str();
    }

    auto [loras, cleaned_prompt] = parse_loras_from_prompt(prompt_str, lora_dir_to_use);
    lora_vec = loras;
    cleaned_prompt_storage = cleaned_prompt;

    // Also check negative prompt for LoRAs (though this is less common)
    auto [neg_loras, cleaned_negative] = parse_loras_from_prompt(negative_prompt_str, lora_dir_to_use);
    // Merge negative prompt LoRAs (though typically not used)
    if (!neg_loras.empty()) {
        fprintf(stderr, "Note: Found %zu LoRAs in negative prompt (may not be supported)\n", neg_loras.size());
    }
    cleaned_negative_prompt_storage = cleaned_negative;

    // Set the cleaned prompts
    params->prompt = cleaned_prompt_storage.c_str();
    params->negative_prompt = cleaned_negative_prompt_storage.c_str();

    // Set LoRAs in params
    params->loras = lora_vec.empty() ? nullptr : lora_vec.data();
    params->lora_count = static_cast<uint32_t>(lora_vec.size());

    fprintf(stderr, "Set prompts with %zu LoRAs. Original prompt: %s\n", lora_vec.size(), prompt ? prompt : "(null)");
    fprintf(stderr, "Cleaned prompt: %s\n", cleaned_prompt_storage.c_str());

    // Debug: Verify LoRAs are set correctly
    if (params->loras && params->lora_count > 0) {
        fprintf(stderr, "DEBUG: LoRAs set in params structure:\n");
        for (uint32_t i = 0; i < params->lora_count; i++) {
            fprintf(stderr, " params->loras[%u]: path='%s' (ptr=%p), multiplier=%.2f, is_high_noise=%s\n",
                    i,
                    params->loras[i].path ? params->loras[i].path : "(null)",
                    (void*)params->loras[i].path,
                    params->loras[i].multiplier,
                    params->loras[i].is_high_noise ? "true" : "false");
        }
    } else {
        fprintf(stderr, "DEBUG: No LoRAs set in params structure (loras=%p, lora_count=%u)\n",
                (void*)params->loras, params->lora_count);
    }
}

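One design point worth calling out: sd_lora_t carries a raw const char* path, so the std::string objects backing those pointers must not move afterwards. That is why parse_loras_from_prompt fills lora_strings completely before taking any c_str() pointer. A small sketch of that two-phase pattern, with an illustrative stand-in for sd_lora_t:

#include <string>
#include <vector>

struct lora_ref { const char* path; float multiplier; }; // stand-in for sd_lora_t

std::vector<lora_ref> build_refs(const std::vector<std::string>& keys,
                                 std::vector<std::string>& storage) {
    // Phase 1: copy every key first. Interleaving push_back with pointer
    // taking could reallocate `storage` and dangle earlier pointers.
    for (const auto& k : keys) {
        storage.push_back(k);
    }
    // Phase 2: only now take stable c_str() pointers.
    std::vector<lora_ref> refs;
    for (const auto& s : storage) {
        refs.push_back({ s.c_str(), 1.0f });
    }
    return refs;
}

int main() {
    std::vector<std::string> storage; // persists, like the file-static lora_strings
    auto refs = build_refs({"detail.safetensors", "style.safetensors"}, storage);
    return refs.empty() ? 1 : 0;
}
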
void sd_img_gen_params_set_dimensions(sd_img_gen_params_t *params, int width, int height) {
@@ -740,6 +1041,20 @@ int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, cha
        }
    }

    // Log LoRA information
    if (p->loras && p->lora_count > 0) {
        fprintf(stderr, "Using %u LoRA(s) in generation:\n", p->lora_count);
        for (uint32_t i = 0; i < p->lora_count; i++) {
            fprintf(stderr, " LoRA[%u]: path='%s', multiplier=%.2f, is_high_noise=%s\n",
                    i,
                    p->loras[i].path ? p->loras[i].path : "(null)",
                    p->loras[i].multiplier,
                    p->loras[i].is_high_noise ? "true" : "false");
        }
    } else {
        fprintf(stderr, "No LoRAs specified for this generation\n");
    }

    fprintf(stderr, "Generating image with params: \nctx\n---\n%s\ngen\n---\n%s\n",
            sd_ctx_params_to_str(&ctx_params),
            sd_img_gen_params_to_str(p));
@@ -802,3 +1117,4 @@ int unload() {
    free_sd_ctx(sd_c);
    return 0;
}