feat(flash_attention): set auto for flash_attention in llama.cpp (#6168)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Committed by GitHub
parent dbdf2908ad
commit 739573e41b
@@ -242,7 +242,7 @@ message ModelOptions {
   string Type = 49;

-  bool FlashAttention = 56;
+  string FlashAttention = 56;
   bool NoKVOffload = 57;

   string ModelPath = 59;
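With the proto field changed from bool to string, a gRPC client now passes the flash-attention mode as text ("auto", "on"/"enabled", "off"/"disabled") rather than a flag. A minimal sketch of setting the new field from Go, assuming the generated bindings are imported as pb (the import path below is an assumption, not taken from this diff):

package main

import (
	"fmt"

	// Assumed import path for the generated protobuf bindings;
	// adjust to wherever your build places the backend proto package.
	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)

func main() {
	// FlashAttention is now a string field; see the params_parse hunk
	// further down for the values the llama.cpp backend recognizes.
	opts := &pb.ModelOptions{
		FlashAttention: "auto",
	}
	fmt.Println(opts.FlashAttention)
}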
@@ -1,5 +1,5 @@
-LLAMA_VERSION?=3d16b29c3bb1ec816ac0e782f20d169097063919
+LLAMA_VERSION?=4d74393bcc956ccd7df68a6a06d1a0575cfa712c
 LLAMA_REPO?=https://github.com/ggerganov/llama.cpp

 CMAKE_ARGS?=
@@ -304,7 +304,15 @@ static void params_parse(const backend::ModelOptions* request,
     }
     params.use_mlock = request->mlock();
     params.use_mmap = request->mmap();
-    params.flash_attn = request->flashattention();
+
+    if (request->flashattention() == "on" || request->flashattention() == "enabled") {
+        params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
+    } else if (request->flashattention() == "off" || request->flashattention() == "disabled") {
+        params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+    } else if (request->flashattention() == "auto") {
+        params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
+    }
+
     params.no_kv_offload = request->nokvoffload();
     params.ctx_shift = false; // We control context-shifting in any case (and we disable it as it could just lead to infinite loops)
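The mapping above accepts two spellings for each explicit mode and "auto" for the new default; any other value has no else branch, so params.flash_attn_type is left at whatever default common_params carries. A small Go helper, purely illustrative and not part of the patch, that mirrors the same exact-match comparison:

package main

import "fmt"

// flashAttnMode mirrors the string handling in params_parse above:
// "on"/"enabled" select flash attention, "off"/"disabled" turn it off,
// "auto" lets llama.cpp decide, and anything else falls through so the
// backend default applies. Matching is exact, as in the C++ hunk.
func flashAttnMode(v string) string {
	switch v {
	case "on", "enabled":
		return "enabled"
	case "off", "disabled":
		return "disabled"
	case "auto":
		return "auto"
	default:
		return "backend default" // unrecognized: flash_attn_type left untouched
	}
}

func main() {
	for _, v := range []string{"on", "disabled", "auto", "maybe"} {
		fmt.Printf("%q -> %q\n", v, flashAttnMode(v))
	}
}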
@@ -78,6 +78,12 @@ func grpcModelOpts(c config.ModelConfig) *pb.ModelOptions {
 		b = c.Batch
 	}

+	flashAttention := "auto"
+
+	if c.FlashAttention != nil {
+		flashAttention = *c.FlashAttention
+	}
+
 	f16 := false
 	if c.F16 != nil {
 		f16 = *c.F16
@@ -166,7 +172,7 @@ func grpcModelOpts(c config.ModelConfig) *pb.ModelOptions {
 		LimitVideoPerPrompt: int32(c.LimitMMPerPrompt.LimitVideoPerPrompt),
 		LimitAudioPerPrompt: int32(c.LimitMMPerPrompt.LimitAudioPerPrompt),
 		MMProj:              c.MMProj,
-		FlashAttention:      c.FlashAttention,
+		FlashAttention:      flashAttention,
 		CacheTypeKey:        c.CacheTypeK,
 		CacheTypeValue:      c.CacheTypeV,
 		NoKVOffload:         c.NoKVOffloading,
@@ -164,10 +164,10 @@ type LLMConfig struct {
 	LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt" json:"limit_mm_per_prompt"` // vLLM
 	MMProj           string           `yaml:"mmproj" json:"mmproj"`

-	FlashAttention bool   `yaml:"flash_attention" json:"flash_attention"`
-	NoKVOffloading bool   `yaml:"no_kv_offloading" json:"no_kv_offloading"`
-	CacheTypeK     string `yaml:"cache_type_k" json:"cache_type_k"`
-	CacheTypeV     string `yaml:"cache_type_v" json:"cache_type_v"`
+	FlashAttention *string `yaml:"flash_attention" json:"flash_attention"`
+	NoKVOffloading bool    `yaml:"no_kv_offloading" json:"no_kv_offloading"`
+	CacheTypeK     string  `yaml:"cache_type_k" json:"cache_type_k"`
+	CacheTypeV     string  `yaml:"cache_type_v" json:"cache_type_v"`

 	RopeScaling string `yaml:"rope_scaling" json:"rope_scaling"`
 	ModelType   string `yaml:"type" json:"type"`
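Because the config field is now *string instead of bool, an absent flash_attention key can be told apart from an explicit value: a nil pointer means "not set" and grpcModelOpts falls back to "auto", while a user-supplied "off" (or any other string) is forwarded as-is to the backend. A cut-down, self-contained sketch of that behaviour, using gopkg.in/yaml.v3 for illustration (the YAML library and the one-field struct are assumptions for this example, not the actual LLMConfig or LocalAI's decoder):

package main

import (
	"fmt"

	"gopkg.in/yaml.v3" // illustrative choice; LocalAI's actual YAML decoder may differ
)

// cfg stands in for LLMConfig with only the field this patch touches.
type cfg struct {
	FlashAttention *string `yaml:"flash_attention"`
}

// resolve applies the same defaulting as grpcModelOpts: nil pointer -> "auto".
func resolve(c cfg) string {
	if c.FlashAttention != nil {
		return *c.FlashAttention
	}
	return "auto"
}

func main() {
	for _, doc := range []string{
		"{}",                     // key absent: defaults to "auto"
		`flash_attention: "off"`, // explicit value: passed through
		`flash_attention: "on"`,  // explicit value: passed through
	} {
		var c cfg
		if err := yaml.Unmarshal([]byte(doc), &c); err != nil {
			panic(err)
		}
		fmt.Printf("%-24s -> %s\n", doc, resolve(c))
	}
}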