
Test code with transformers

#1
by manueldeprada - opened

For anyone wondering how to test this out on transformers, I'll save you some effort:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, KernelConfig
from transformers.integrations.hub_kernels import load_and_register_attn_kernel
import time


# Register the Metal flash-SDPA kernel from the Hub as an attention implementation
load_and_register_attn_kernel("kernels-community/metal-flash-sdpa")

model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M", 
    device_map="auto", 
)
# Force the kernel on the config directly (the usual attn_implementation flag doesn't work, see note 1 below)
model.config._attn_implementation_internal = "kernels-community/metal-flash-sdpa"
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M", padding_side="left")

prompt = ["Hello, how are you?", "Wow, that's a big number!"]
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(prompt, return_tensors="pt", padding=True).to("mps" if torch.backends.mps.is_available() else "cpu")
PARALLEL_REQUESTS = 5

# Warmup generation
_ = model.generate(**inputs, max_new_tokens=50, num_return_sequences=PARALLEL_REQUESTS, do_sample=True)
...
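
(The rest of the script is elided above. A minimal continuation for timing and inspecting the output could look like the sketch below; this is illustrative, not the original code.)

# Illustrative continuation: time a real generation after the warmup and decode it
start = time.time()
outputs = model.generate(**inputs, max_new_tokens=50, num_return_sequences=PARALLEL_REQUESTS, do_sample=True)
elapsed = time.time() - start

for text in tokenizer.batch_decode(outputs, skip_special_tokens=True):
    print(text)
print(f"{outputs.shape[0]} sequences of {outputs.shape[1]} tokens each (incl. prompt/padding) in {elapsed:.2f}s")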

The first run will fail. When it does, go to ~/.cache/huggingface/hub/models--kernels-community--metal-flash-sdpa/snapshots/.../build/torch27-metal-aarch64-darwin/metal_flash_sdpa/__init__.py and add a dummy:

def flash_attn_func(*args, **kwargs):
    # Stub so the import check passes; the varlen path is used for padded batches, so this is never called
    raise NotImplementedError("flash_attn_func is not implemented in this module.")

A few notes:

  1. Neither the attn_implementation flag nor set_attn_implementation works, because this code: https://github.com/huggingface/transformers/blob/a127710b3a91b3d323b27be8e06ddf507b009b87/src/transformers/modeling_utils.py#L2375-L2386 changes the implementation from kernels-community/metal-flash-sdpa to kernels-community/vllm-flash-attn3 (see the snippet right after this list). Is this intentional? @AntonV
  2. You need at least 2 sequences in the batch, so that padding actually exists. This is because kernels-community/metal-flash-sdpa only provides flash_attn_varlen_func, not flash_attn_func. That is also why metal_flash_sdpa/__init__.py needs to be patched (so that https://github.com/huggingface/transformers/blob/a127710b3a91b3d323b27be8e06ddf507b009b87/src/transformers/modeling_flash_attention_utils.py#L97 does not complain).
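
For reference, these are the two call styles note 1 refers to; per the linked modeling_utils.py code, both currently end up remapped to kernels-community/vllm-flash-attn3 instead of the Metal kernel (sketch only):

# Sketch of note 1: both of these currently get silently remapped to kernels-community/vllm-flash-attn3
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M",
    device_map="auto",
    attn_implementation="kernels-community/metal-flash-sdpa",
)
# ...or, equivalently:
model.set_attn_implementation("kernels-community/metal-flash-sdpa")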

Maybe there's an easier route to test it, lmk!

I think https://github.com/huggingface/transformers/pull/41427 will be interesting for you @manueldeprada

Re your notes:

  1. This was unintentionally introduced by some other PRs. It should only fall back when we request base fa versions, not kernels.
  2. Note that with your patch it will only work for padded inputs (or anything else that circumvents the base fa functions), e.g. continuous batching (CB), any padded input, padding-free training.

The linked PR should resolve all your issues (avoiding manual patches). However, imo we should also support the base fa function (not only varlen) cc @danieldk @EricB
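
If it helps, a dense flash_attn_func could in principle be expressed on top of the varlen entry point. The sketch below is illustrative only; it assumes the varlen signature matches upstream flash-attn, and the actual metal_flash_sdpa API may differ:

import torch

def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, **kwargs):
    # Illustrative dense wrapper over a varlen-only kernel.
    # flash_attn_varlen_func is assumed to be the kernel's existing varlen entry point, in scope.
    # q, k, v: (batch, seqlen, nheads, headdim) -> flatten to (total_tokens, nheads, headdim)
    batch, seqlen_q = q.shape[0], q.shape[1]
    seqlen_k = k.shape[1]
    cu_seqlens_q = torch.arange(0, (batch + 1) * seqlen_q, seqlen_q, dtype=torch.int32, device=q.device)
    cu_seqlens_k = torch.arange(0, (batch + 1) * seqlen_k, seqlen_k, dtype=torch.int32, device=k.device)
    out = flash_attn_varlen_func(
        q.reshape(batch * seqlen_q, *q.shape[2:]),
        k.reshape(batch * seqlen_k, *k.shape[2:]),
        v.reshape(batch * seqlen_k, *v.shape[2:]),
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=seqlen_q,
        max_seqlen_k=seqlen_k,
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
        **kwargs,
    )
    # varlen output is (total_tokens, nheads, headdim); restore the batch dimension
    return out.reshape(batch, seqlen_q, *out.shape[1:])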

Thanks for the quick fix!! You're the FA master now ♥️😃
