Test code with transformers
#1
by manueldeprada
For anyone wondering how to test this out on transformers, I'll save you some effort:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, KernelConfig
from transformers.integrations.hub_kernels import load_and_register_attn_kernel
import time

# Register the Metal flash SDPA kernel from the Hub as an attention kernel
load_and_register_attn_kernel("kernels-community/metal-flash-sdpa")

model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M",
    device_map="auto",
)
# Force the kernel directly (see the notes below on why the attn_implementation flag doesn't stick)
model.config._attn_implementation_internal = "kernels-community/metal-flash-sdpa"

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

# Two prompts so the batch actually contains padding (see the notes below)
prompt = ["Hello, how are you?", "Wow, that's a big number!"]
inputs = tokenizer(prompt, return_tensors="pt", padding=True).to("mps" if torch.backends.mps.is_available() else "cpu")

PARALLEL_REQUESTS = 5

# Warmup generation
_ = model.generate(**inputs, max_new_tokens=50, num_return_sequences=PARALLEL_REQUESTS, do_sample=True)
...
```
It will fail; then go to ~/.cache/huggingface/hub/models--kernels-community--metal-flash-sdpa/snapshots/.../build/torch27-metal-aarch64-darwin/metal_flash_sdpa/__init__.py and add a dummy:
```python
def flash_attn_func(*args, **kwargs):
    raise NotImplementedError("flash_attn_func is not implemented in this module.")
```
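The `...` in the script just elides the rest; for completeness, here's a minimal benchmarking sketch of what could follow the warmup (an illustration only, adapt as needed — it reuses the `time` import and `PARALLEL_REQUESTS` from above):

```python
# Rough sketch: time one generation pass after the warmup and estimate tokens/s
start = time.perf_counter()
out = model.generate(**inputs, max_new_tokens=50, num_return_sequences=PARALLEL_REQUESTS, do_sample=True)
elapsed = time.perf_counter() - start

# Output shape is (batch_size * PARALLEL_REQUESTS, prompt_len + generated);
# this is a rough token count that ignores sequences finishing early
new_tokens = out.shape[0] * (out.shape[1] - inputs["input_ids"].shape[1])
print(f"~{new_tokens} new tokens in {elapsed:.2f}s (~{new_tokens / elapsed:.1f} tok/s)")
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```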
A few notes:
- Neither the attn_implementation flag nor set_attn_implementation works (see the sketch after these notes), since this code: https://github.com/huggingface/transformers/blob/a127710b3a91b3d323b27be8e06ddf507b009b87/src/transformers/modeling_utils.py#L2375-L2386 changes the impl from kernels-community/metal-flash-sdpa to kernels-community/vllm-flash-attn3. Is this intentional? @AntonV
- You need at least 2 sequences, so that padding exists in the batch. This is because kernels-community/metal-flash-sdpa only provides flash_attn_varlen_func, not flash_attn_func. That is also why metal_flash_sdpa/__init__.py needs to be patched (so that https://github.com/huggingface/transformers/blob/a127710b3a91b3d323b27be8e06ddf507b009b87/src/transformers/modeling_flash_attention_utils.py#L97 does not complain).
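For reference, this is what I would have expected to work instead of poking _attn_implementation_internal — a sketch of the two usual ways to select the implementation, both of which currently get rewritten to vllm-flash-attn3 by the code linked in the first note:

```python
# Sketch of the standard API (currently overridden, per the first note above)
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M",
    device_map="auto",
    attn_implementation="kernels-community/metal-flash-sdpa",
)
# or, on an already-loaded model:
model.set_attn_implementation("kernels-community/metal-flash-sdpa")
```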
Maybe there's an easier route to test it, lmk!
I think https://github.com/huggingface/transformers/pull/41427 will be interesting for you @manueldeprada
Re your notes:
- This was unintentionally introduced by some other PRs. It should only fall back when we request base fa versions, not kernels.
- Note that with your patch it will only work for padded inputs (or anything else that circumvents the base fa functions), e.g. continuous batching (CB), anything padded, padding-free training.
The linked PR should resolve all your issues (avoiding manual patches). However, imo we should also support the base fa function (not only varlen) cc @danieldk @EricB
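For context, a rough sketch of how the two interfaces differ and why the varlen one covers padded batches (parameter names follow the usual flash-attn convention; the metal kernel's exact signatures may differ):

```python
import torch

# flash_attn_func expects dense (batch, seqlen, heads, head_dim) tensors, while
# flash_attn_varlen_func expects packed (total_tokens, heads, head_dim) tensors
# plus cumulative sequence lengths derived from the attention mask:
attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # two left-padded sequences,
                               [1, 1, 1, 1, 1]])  # real lengths 3 and 5
seqlens = attention_mask.sum(dim=1)                                              # tensor([3, 5])
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)  # tensor([0, 3, 8])
max_seqlen = int(seqlens.max())                                                  # 5
# flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
#                        max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, causal=True)
```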
thanks for the quick fix!! you're the FA master now ♥️😃