diff --git a/src/legogpt/models/llm.py b/src/legogpt/models/llm.py index 0f6e3c8..149bb2c 100644 --- a/src/legogpt/models/llm.py +++ b/src/legogpt/models/llm.py @@ -18,7 +18,6 @@ class LLM: model_name, torch_dtype=torch.bfloat16, device_map=device, - attn_implementation='flash_attention_2', ).to(device) self.kv_cache = None