Upgrade to lightning att2
Browse files- lightning_attention2.py +540 -0
- modeling_transnormer.py +4 -3
lightning_attention2.py
ADDED
|
@@ -0,0 +1,540 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 OpenNLPLab
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# coding=utf-8
|
| 16 |
+
import torch
|
| 17 |
+
import triton
|
| 18 |
+
import triton.language as tl
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@triton.jit
|
| 22 |
+
def _fwd_kernel(
|
| 23 |
+
Q,
|
| 24 |
+
K,
|
| 25 |
+
V,
|
| 26 |
+
Out,
|
| 27 |
+
S,
|
| 28 |
+
stride_qz,
|
| 29 |
+
stride_qh,
|
| 30 |
+
stride_qm,
|
| 31 |
+
stride_qk,
|
| 32 |
+
stride_kz,
|
| 33 |
+
stride_kh,
|
| 34 |
+
stride_kn,
|
| 35 |
+
stride_kk,
|
| 36 |
+
stride_vz,
|
| 37 |
+
stride_vh,
|
| 38 |
+
stride_vn,
|
| 39 |
+
stride_ve,
|
| 40 |
+
stride_oz,
|
| 41 |
+
stride_oh,
|
| 42 |
+
stride_om,
|
| 43 |
+
stride_oe,
|
| 44 |
+
stride_sh,
|
| 45 |
+
Z,
|
| 46 |
+
H,
|
| 47 |
+
N_CTX,
|
| 48 |
+
BLOCK_M: tl.constexpr,
|
| 49 |
+
BLOCK_DMODEL_QK: tl.constexpr,
|
| 50 |
+
BLOCK_N: tl.constexpr,
|
| 51 |
+
BLOCK_DMODEL_V: tl.constexpr,
|
| 52 |
+
IS_CAUSAL: tl.constexpr,
|
| 53 |
+
USE_DECAY: tl.constexpr,
|
| 54 |
+
):
|
| 55 |
+
start_m = tl.program_id(0)
|
| 56 |
+
off_hz = tl.program_id(1)
|
| 57 |
+
off_h = off_hz % H
|
| 58 |
+
# initialize offsets
|
| 59 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 60 |
+
offs_n = tl.arange(0, BLOCK_N)
|
| 61 |
+
offs_k = tl.arange(0, BLOCK_DMODEL_QK)
|
| 62 |
+
offs_e = tl.arange(0, BLOCK_DMODEL_V)
|
| 63 |
+
# get current offset of q k v
|
| 64 |
+
off_q = (off_hz * stride_qh + offs_m[:, None] * stride_qm
|
| 65 |
+
+ offs_k[None, :] * stride_qk)
|
| 66 |
+
off_k = (off_hz * stride_kh + offs_n[:, None] * stride_kn
|
| 67 |
+
+ offs_k[None, :] * stride_kk)
|
| 68 |
+
off_v = (off_hz * stride_vh + offs_n[:, None] * stride_vn
|
| 69 |
+
+ offs_e[None, :] * stride_ve)
|
| 70 |
+
off_o = (off_hz * stride_oh + offs_m[:, None] * stride_om
|
| 71 |
+
+ offs_e[None, :] * stride_oe)
|
| 72 |
+
|
| 73 |
+
# Initialize pointers to Q, K, V
|
| 74 |
+
q_ptrs = Q + off_q
|
| 75 |
+
k_ptrs = K + off_k
|
| 76 |
+
v_ptrs = V + off_v
|
| 77 |
+
|
| 78 |
+
# initialize pointer to m and l
|
| 79 |
+
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_V], dtype=tl.float32)
|
| 80 |
+
# load q: it will stay in SRAM throughout
|
| 81 |
+
q = tl.load(q_ptrs, mask=offs_m[:, None] < N_CTX, other=0.0)
|
| 82 |
+
# loop over k, v and update accumulator
|
| 83 |
+
lo = 0
|
| 84 |
+
# print(start_m)
|
| 85 |
+
hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
|
| 86 |
+
for start_n in range(lo, hi, BLOCK_N):
|
| 87 |
+
# -- load k, v --
|
| 88 |
+
k = tl.load(
|
| 89 |
+
k_ptrs + start_n * stride_kn,
|
| 90 |
+
mask=(start_n + offs_n)[:, None] < N_CTX,
|
| 91 |
+
other=0.0,
|
| 92 |
+
)
|
| 93 |
+
v = tl.load(
|
| 94 |
+
v_ptrs + start_n * stride_vn,
|
| 95 |
+
mask=(start_n + offs_n)[:, None] < N_CTX,
|
| 96 |
+
other=0.0,
|
| 97 |
+
)
|
| 98 |
+
# -- compute qk ---
|
| 99 |
+
# qk = tl.dot(q, k)
|
| 100 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
| 101 |
+
# qk += tl.dot(q, k, trans_b=True)
|
| 102 |
+
qk += tl.dot(q, tl.trans(k))
|
| 103 |
+
if IS_CAUSAL:
|
| 104 |
+
index = offs_m[:, None] - (start_n + offs_n[None, :])
|
| 105 |
+
if USE_DECAY:
|
| 106 |
+
S_block_ptr = S + off_h * stride_sh
|
| 107 |
+
s = tl.load(S_block_ptr)
|
| 108 |
+
s_index = s * index
|
| 109 |
+
s_index = tl.where(s_index >= 0, -s_index, float("-inf"))
|
| 110 |
+
qk = tl.exp(s_index) * qk
|
| 111 |
+
else:
|
| 112 |
+
qk = tl.where(index >= 0, qk, 0)
|
| 113 |
+
acc += tl.dot(qk, v.to(qk.dtype))
|
| 114 |
+
|
| 115 |
+
out_ptrs = Out + off_o
|
| 116 |
+
tl.store(out_ptrs, acc.to(q.dtype), mask=offs_m[:, None] < N_CTX)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@triton.jit
|
| 120 |
+
def _bwd_kernel_kv(
|
| 121 |
+
Q,
|
| 122 |
+
K,
|
| 123 |
+
V,
|
| 124 |
+
S,
|
| 125 |
+
DO,
|
| 126 |
+
DQ,
|
| 127 |
+
DK,
|
| 128 |
+
DV,
|
| 129 |
+
stride_qz,
|
| 130 |
+
stride_qh,
|
| 131 |
+
stride_qm,
|
| 132 |
+
stride_qk,
|
| 133 |
+
stride_kz,
|
| 134 |
+
stride_kh,
|
| 135 |
+
stride_kn,
|
| 136 |
+
stride_kk,
|
| 137 |
+
stride_vz,
|
| 138 |
+
stride_vh,
|
| 139 |
+
stride_vn,
|
| 140 |
+
stride_ve,
|
| 141 |
+
stride_oz,
|
| 142 |
+
stride_oh,
|
| 143 |
+
stride_om,
|
| 144 |
+
stride_oe,
|
| 145 |
+
stride_sh,
|
| 146 |
+
Z,
|
| 147 |
+
H,
|
| 148 |
+
N_CTX,
|
| 149 |
+
num_block,
|
| 150 |
+
BLOCK_M: tl.constexpr,
|
| 151 |
+
BLOCK_DMODEL_QK: tl.constexpr,
|
| 152 |
+
BLOCK_N: tl.constexpr,
|
| 153 |
+
BLOCK_DMODEL_V: tl.constexpr,
|
| 154 |
+
CAUSAL: tl.constexpr,
|
| 155 |
+
USE_DECAY: tl.constexpr,
|
| 156 |
+
):
|
| 157 |
+
start_n = tl.program_id(0)
|
| 158 |
+
off_hz = tl.program_id(1)
|
| 159 |
+
|
| 160 |
+
off_z = off_hz // H
|
| 161 |
+
off_h = off_hz % H
|
| 162 |
+
# offset pointers for batch/head
|
| 163 |
+
Q += off_z * stride_qz + off_h * stride_qh
|
| 164 |
+
K += off_z * stride_kz + off_h * stride_kh
|
| 165 |
+
V += off_z * stride_vz + off_h * stride_vh
|
| 166 |
+
DO += off_z * stride_oz + off_h * stride_oh
|
| 167 |
+
DQ += off_z * stride_qz + off_h * stride_qh
|
| 168 |
+
DK += off_z * stride_kz + off_h * stride_kh
|
| 169 |
+
DV += off_z * stride_vz + off_h * stride_vh
|
| 170 |
+
|
| 171 |
+
# start of q
|
| 172 |
+
if CAUSAL:
|
| 173 |
+
lo = start_n * BLOCK_M
|
| 174 |
+
else:
|
| 175 |
+
lo = 0
|
| 176 |
+
# initialize row/col offsets
|
| 177 |
+
# seqlence offset
|
| 178 |
+
offs_qm = lo + tl.arange(0, BLOCK_M)
|
| 179 |
+
offs_kvn = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 180 |
+
# feature offset
|
| 181 |
+
offs_qkk = tl.arange(0, BLOCK_DMODEL_QK)
|
| 182 |
+
offs_ve = tl.arange(0, BLOCK_DMODEL_V)
|
| 183 |
+
# row block index
|
| 184 |
+
offs_m = tl.arange(0, BLOCK_M)
|
| 185 |
+
# initialize pointers to value-like data
|
| 186 |
+
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_qkk[None, :] * stride_qk)
|
| 187 |
+
k_ptrs = K + (offs_kvn[:, None] * stride_kn
|
| 188 |
+
+ offs_qkk[None, :] * stride_kk)
|
| 189 |
+
v_ptrs = V + (offs_kvn[:, None] * stride_vn + offs_ve[None, :] * stride_ve)
|
| 190 |
+
do_ptrs = DO + (offs_qm[:, None] * stride_om
|
| 191 |
+
+ offs_ve[None, :] * stride_oe)
|
| 192 |
+
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm
|
| 193 |
+
+ offs_qkk[None, :] * stride_qk)
|
| 194 |
+
# initialize dv amd dk
|
| 195 |
+
dv = tl.zeros([BLOCK_N, BLOCK_DMODEL_V], dtype=tl.float32)
|
| 196 |
+
dk = tl.zeros([BLOCK_N, BLOCK_DMODEL_QK], dtype=tl.float32)
|
| 197 |
+
# k and v stay in SRAM throughout
|
| 198 |
+
k = tl.load(k_ptrs, mask=offs_kvn[:, None] < N_CTX, other=0.0)
|
| 199 |
+
v = tl.load(v_ptrs, mask=offs_kvn[:, None] < N_CTX, other=0.0)
|
| 200 |
+
# loop over rows
|
| 201 |
+
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
|
| 202 |
+
offs_m_curr = start_m + offs_m
|
| 203 |
+
# load q, k, v, do on-chip
|
| 204 |
+
q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < N_CTX, other=0.0)
|
| 205 |
+
qk = tl.dot(q, tl.trans(k))
|
| 206 |
+
# qk = tl.dot(q, k, trans_b=True)
|
| 207 |
+
if CAUSAL:
|
| 208 |
+
index = offs_m_curr[:, None] - offs_kvn[None, :]
|
| 209 |
+
if USE_DECAY:
|
| 210 |
+
S_block_ptr = S + off_h * stride_sh
|
| 211 |
+
s = tl.load(S_block_ptr)
|
| 212 |
+
s_index = s * index
|
| 213 |
+
s_index = tl.where(s_index >= 0, -s_index, float("-inf"))
|
| 214 |
+
s = tl.exp(s_index)
|
| 215 |
+
qk = qk * s
|
| 216 |
+
else:
|
| 217 |
+
qk = tl.where(index >= 0, qk, 0)
|
| 218 |
+
|
| 219 |
+
p = qk
|
| 220 |
+
# compute dv
|
| 221 |
+
do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < N_CTX, other=0.0)
|
| 222 |
+
dv += tl.dot(tl.trans(p.to(do.dtype)), do)
|
| 223 |
+
dp = tl.dot(do, tl.trans(v).to(do.dtype))
|
| 224 |
+
if CAUSAL:
|
| 225 |
+
if USE_DECAY:
|
| 226 |
+
dp = dp * s
|
| 227 |
+
else:
|
| 228 |
+
dp = tl.where(index >= 0, dp, 0)
|
| 229 |
+
|
| 230 |
+
dk += tl.dot(tl.trans(dp.to(q.dtype)), q).to(tl.float32)
|
| 231 |
+
|
| 232 |
+
# increment pointers
|
| 233 |
+
q_ptrs += BLOCK_M * stride_qm
|
| 234 |
+
do_ptrs += BLOCK_M * stride_om
|
| 235 |
+
# write-back
|
| 236 |
+
dv_ptrs = DV + (offs_kvn[:, None] * stride_vn
|
| 237 |
+
+ offs_ve[None, :] * stride_ve)
|
| 238 |
+
dk_ptrs = DK + (offs_kvn[:, None] * stride_kn
|
| 239 |
+
+ offs_qkk[None, :] * stride_kk)
|
| 240 |
+
tl.store(dv_ptrs, dv, mask=offs_kvn[:, None] < N_CTX)
|
| 241 |
+
tl.store(dk_ptrs, dk, mask=offs_kvn[:, None] < N_CTX)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
@triton.jit
|
| 245 |
+
def _bwd_kernel_q(
|
| 246 |
+
Q,
|
| 247 |
+
K,
|
| 248 |
+
V,
|
| 249 |
+
S,
|
| 250 |
+
DO,
|
| 251 |
+
DQ,
|
| 252 |
+
DK,
|
| 253 |
+
DV,
|
| 254 |
+
stride_qz,
|
| 255 |
+
stride_qh,
|
| 256 |
+
stride_qm,
|
| 257 |
+
stride_qk,
|
| 258 |
+
stride_kz,
|
| 259 |
+
stride_kh,
|
| 260 |
+
stride_kn,
|
| 261 |
+
stride_kk,
|
| 262 |
+
stride_vz,
|
| 263 |
+
stride_vh,
|
| 264 |
+
stride_vn,
|
| 265 |
+
stride_ve,
|
| 266 |
+
stride_oz,
|
| 267 |
+
stride_oh,
|
| 268 |
+
stride_om,
|
| 269 |
+
stride_oe,
|
| 270 |
+
stride_sh,
|
| 271 |
+
Z,
|
| 272 |
+
H,
|
| 273 |
+
N_CTX,
|
| 274 |
+
num_block,
|
| 275 |
+
BLOCK_M: tl.constexpr,
|
| 276 |
+
BLOCK_DMODEL_QK: tl.constexpr,
|
| 277 |
+
BLOCK_N: tl.constexpr,
|
| 278 |
+
BLOCK_DMODEL_V: tl.constexpr,
|
| 279 |
+
CAUSAL: tl.constexpr,
|
| 280 |
+
USE_DECAY: tl.constexpr,
|
| 281 |
+
):
|
| 282 |
+
start_m = tl.program_id(0)
|
| 283 |
+
off_hz = tl.program_id(1)
|
| 284 |
+
off_z = off_hz // H
|
| 285 |
+
off_h = off_hz % H
|
| 286 |
+
# offset pointers for batch/head
|
| 287 |
+
K += off_z * stride_kz + off_h * stride_kh
|
| 288 |
+
V += off_z * stride_vz + off_h * stride_vh
|
| 289 |
+
DO += off_z * stride_oz + off_h * stride_oh
|
| 290 |
+
DQ += off_z * stride_qz + off_h * stride_qh
|
| 291 |
+
# feature offset
|
| 292 |
+
offs_qkk = tl.arange(0, BLOCK_DMODEL_QK)
|
| 293 |
+
offs_ve = tl.arange(0, BLOCK_DMODEL_V)
|
| 294 |
+
# row block index
|
| 295 |
+
offs_m = tl.arange(0, BLOCK_M)
|
| 296 |
+
# row block index
|
| 297 |
+
offs_qm = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 298 |
+
# do
|
| 299 |
+
do_ptrs = DO + (offs_qm[:, None] * stride_om
|
| 300 |
+
+ offs_ve[None, :] * stride_oe)
|
| 301 |
+
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm
|
| 302 |
+
+ offs_qkk[None, :] * stride_qk)
|
| 303 |
+
|
| 304 |
+
do = tl.load(do_ptrs, mask=offs_qm[:, None] < N_CTX, other=0.0)
|
| 305 |
+
|
| 306 |
+
dq = tl.zeros([BLOCK_M, BLOCK_DMODEL_QK], dtype=tl.float32)
|
| 307 |
+
lo = 0
|
| 308 |
+
hi = (start_m + 1) * BLOCK_M if CAUSAL else N_CTX
|
| 309 |
+
|
| 310 |
+
offs_m_curr = start_m * BLOCK_M + offs_m
|
| 311 |
+
|
| 312 |
+
for start_n in range(0, num_block):
|
| 313 |
+
offs_kvn = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 314 |
+
k_ptrs = K + (offs_kvn[:, None] * stride_kn
|
| 315 |
+
+ offs_qkk[None, :] * stride_kk)
|
| 316 |
+
v_ptrs = V + (offs_kvn[:, None] * stride_vn
|
| 317 |
+
+ offs_ve[None, :] * stride_ve)
|
| 318 |
+
# k and v stay in SRAM throughout
|
| 319 |
+
k = tl.load(k_ptrs, mask=offs_kvn[:, None] < N_CTX, other=0.0)
|
| 320 |
+
v = tl.load(v_ptrs, mask=offs_kvn[:, None] < N_CTX, other=0.0)
|
| 321 |
+
# dp = do vT
|
| 322 |
+
dp = tl.dot(do, tl.trans(v).to(do.dtype))
|
| 323 |
+
if CAUSAL:
|
| 324 |
+
index = offs_m_curr[:, None] - offs_kvn[None, :]
|
| 325 |
+
if USE_DECAY:
|
| 326 |
+
S_block_ptr = S + off_h * stride_sh
|
| 327 |
+
s = tl.load(S_block_ptr)
|
| 328 |
+
s_index = s * index
|
| 329 |
+
s_index = tl.where(s_index >= 0, -s_index, float("-inf"))
|
| 330 |
+
s = tl.exp(s_index)
|
| 331 |
+
dp = dp * s
|
| 332 |
+
else:
|
| 333 |
+
dp = tl.where(index >= 0, dp, 0)
|
| 334 |
+
# dq = dq + dp k
|
| 335 |
+
dq += tl.dot(dp.to(k.dtype), k)
|
| 336 |
+
|
| 337 |
+
tl.store(dq_ptrs, dq, mask=offs_qm[:, None] < N_CTX)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
class _attention(torch.autograd.Function):
|
| 341 |
+
|
| 342 |
+
@staticmethod
|
| 343 |
+
def forward(ctx, q, k, v, causal, s):
|
| 344 |
+
q = q.contiguous()
|
| 345 |
+
k = k.contiguous()
|
| 346 |
+
v = v.contiguous()
|
| 347 |
+
s = s.contiguous()
|
| 348 |
+
# only support for Ampere now
|
| 349 |
+
capability = torch.cuda.get_device_capability()
|
| 350 |
+
if capability[0] < 8:
|
| 351 |
+
raise RuntimeError(
|
| 352 |
+
"Lightning attention currently only supported for compute capability >= 80"
|
| 353 |
+
)
|
| 354 |
+
# shape constraints
|
| 355 |
+
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
| 356 |
+
# right
|
| 357 |
+
o = torch.empty(
|
| 358 |
+
(q.shape[0], q.shape[1], q.shape[2], v.shape[-1]),
|
| 359 |
+
dtype=q.dtype,
|
| 360 |
+
device=q.device,
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
BLOCK_M = 128
|
| 364 |
+
BLOCK_N = 64
|
| 365 |
+
num_warps = 4 if Lk <= 64 else 8
|
| 366 |
+
num_stages = 1
|
| 367 |
+
|
| 368 |
+
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
|
| 369 |
+
use_decay = s.shape[0] > 0
|
| 370 |
+
_fwd_kernel[grid](
|
| 371 |
+
q,
|
| 372 |
+
k,
|
| 373 |
+
v,
|
| 374 |
+
o,
|
| 375 |
+
s,
|
| 376 |
+
q.stride(0),
|
| 377 |
+
q.stride(1),
|
| 378 |
+
q.stride(2),
|
| 379 |
+
q.stride(3),
|
| 380 |
+
k.stride(0),
|
| 381 |
+
k.stride(1),
|
| 382 |
+
k.stride(2),
|
| 383 |
+
k.stride(3),
|
| 384 |
+
v.stride(0),
|
| 385 |
+
v.stride(1),
|
| 386 |
+
v.stride(2),
|
| 387 |
+
v.stride(3),
|
| 388 |
+
o.stride(0),
|
| 389 |
+
o.stride(1),
|
| 390 |
+
o.stride(2),
|
| 391 |
+
o.stride(3),
|
| 392 |
+
s.stride(0),
|
| 393 |
+
q.shape[0],
|
| 394 |
+
q.shape[1],
|
| 395 |
+
q.shape[2],
|
| 396 |
+
BLOCK_M=BLOCK_M,
|
| 397 |
+
BLOCK_DMODEL_QK=Lk,
|
| 398 |
+
BLOCK_N=BLOCK_N,
|
| 399 |
+
BLOCK_DMODEL_V=Lv,
|
| 400 |
+
IS_CAUSAL=causal,
|
| 401 |
+
USE_DECAY=use_decay,
|
| 402 |
+
num_warps=num_warps,
|
| 403 |
+
num_stages=num_stages,
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
ctx.save_for_backward(q, k, v, s)
|
| 407 |
+
ctx.grid = grid
|
| 408 |
+
ctx.BLOCK_M = BLOCK_M
|
| 409 |
+
ctx.BLOCK_DMODEL_QK = Lk
|
| 410 |
+
ctx.BLOCK_N = BLOCK_N
|
| 411 |
+
ctx.BLOCK_DMODEL_V = Lv
|
| 412 |
+
ctx.causal = causal
|
| 413 |
+
ctx.use_decay = use_decay
|
| 414 |
+
return o
|
| 415 |
+
|
| 416 |
+
@staticmethod
|
| 417 |
+
def backward(ctx, do):
|
| 418 |
+
q, k, v, s = ctx.saved_tensors
|
| 419 |
+
BLOCK_M = 32
|
| 420 |
+
BLOCK_N = 32
|
| 421 |
+
num_warps = 4
|
| 422 |
+
num_stages = 1
|
| 423 |
+
|
| 424 |
+
do = do.contiguous()
|
| 425 |
+
dq = torch.zeros_like(q, dtype=torch.float32)
|
| 426 |
+
dk = torch.empty_like(k)
|
| 427 |
+
dv = torch.empty_like(v)
|
| 428 |
+
|
| 429 |
+
grid_kv = (triton.cdiv(k.shape[2],
|
| 430 |
+
BLOCK_N), k.shape[0] * k.shape[1], 1)
|
| 431 |
+
_bwd_kernel_kv[grid_kv](
|
| 432 |
+
q,
|
| 433 |
+
k,
|
| 434 |
+
v,
|
| 435 |
+
s,
|
| 436 |
+
do,
|
| 437 |
+
dq,
|
| 438 |
+
dk,
|
| 439 |
+
dv,
|
| 440 |
+
q.stride(0),
|
| 441 |
+
q.stride(1),
|
| 442 |
+
q.stride(2),
|
| 443 |
+
q.stride(3),
|
| 444 |
+
k.stride(0),
|
| 445 |
+
k.stride(1),
|
| 446 |
+
k.stride(2),
|
| 447 |
+
k.stride(3),
|
| 448 |
+
v.stride(0),
|
| 449 |
+
v.stride(1),
|
| 450 |
+
v.stride(2),
|
| 451 |
+
v.stride(3),
|
| 452 |
+
do.stride(0),
|
| 453 |
+
do.stride(1),
|
| 454 |
+
do.stride(2),
|
| 455 |
+
do.stride(3),
|
| 456 |
+
s.stride(0),
|
| 457 |
+
q.shape[0],
|
| 458 |
+
q.shape[1],
|
| 459 |
+
q.shape[2],
|
| 460 |
+
grid_kv[0],
|
| 461 |
+
BLOCK_M=BLOCK_M,
|
| 462 |
+
BLOCK_DMODEL_QK=ctx.BLOCK_DMODEL_QK,
|
| 463 |
+
BLOCK_N=BLOCK_N,
|
| 464 |
+
BLOCK_DMODEL_V=ctx.BLOCK_DMODEL_V,
|
| 465 |
+
CAUSAL=ctx.causal,
|
| 466 |
+
USE_DECAY=ctx.use_decay,
|
| 467 |
+
num_warps=num_warps,
|
| 468 |
+
num_stages=num_stages,
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
grid_q = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
|
| 472 |
+
|
| 473 |
+
_bwd_kernel_q[grid_q](
|
| 474 |
+
q,
|
| 475 |
+
k,
|
| 476 |
+
v,
|
| 477 |
+
s,
|
| 478 |
+
do,
|
| 479 |
+
dq,
|
| 480 |
+
dk,
|
| 481 |
+
dv,
|
| 482 |
+
q.stride(0),
|
| 483 |
+
q.stride(1),
|
| 484 |
+
q.stride(2),
|
| 485 |
+
q.stride(3),
|
| 486 |
+
k.stride(0),
|
| 487 |
+
k.stride(1),
|
| 488 |
+
k.stride(2),
|
| 489 |
+
k.stride(3),
|
| 490 |
+
v.stride(0),
|
| 491 |
+
v.stride(1),
|
| 492 |
+
v.stride(2),
|
| 493 |
+
v.stride(3),
|
| 494 |
+
do.stride(0),
|
| 495 |
+
do.stride(1),
|
| 496 |
+
do.stride(2),
|
| 497 |
+
do.stride(3),
|
| 498 |
+
s.stride(0),
|
| 499 |
+
q.shape[0],
|
| 500 |
+
q.shape[1],
|
| 501 |
+
q.shape[2],
|
| 502 |
+
grid_q[0],
|
| 503 |
+
BLOCK_M=BLOCK_M,
|
| 504 |
+
BLOCK_DMODEL_QK=ctx.BLOCK_DMODEL_QK,
|
| 505 |
+
BLOCK_N=BLOCK_N,
|
| 506 |
+
BLOCK_DMODEL_V=ctx.BLOCK_DMODEL_V,
|
| 507 |
+
CAUSAL=ctx.causal,
|
| 508 |
+
USE_DECAY=ctx.use_decay,
|
| 509 |
+
num_warps=num_warps,
|
| 510 |
+
num_stages=num_stages,
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
return dq.to(q.dtype), dk, dv, None, None
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
attention = _attention.apply
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def lightning_attention(q, k, v, causal, ed):
|
| 520 |
+
d = q.shape[-1]
|
| 521 |
+
e = v.shape[-1]
|
| 522 |
+
# arr = f(d)
|
| 523 |
+
if d >= 128:
|
| 524 |
+
m = 128
|
| 525 |
+
else:
|
| 526 |
+
m = 64
|
| 527 |
+
arr = [m * i for i in range(d // m + 1)]
|
| 528 |
+
if arr[-1] != d:
|
| 529 |
+
arr.append(d)
|
| 530 |
+
n = len(arr)
|
| 531 |
+
output = 0
|
| 532 |
+
for i in range(n - 1):
|
| 533 |
+
s = arr[i]
|
| 534 |
+
e = arr[i + 1]
|
| 535 |
+
q1 = q[..., s:e]
|
| 536 |
+
k1 = k[..., s:e]
|
| 537 |
+
o = attention(q1, k1, v, causal, ed)
|
| 538 |
+
output = output + o
|
| 539 |
+
|
| 540 |
+
return output
|
modeling_transnormer.py
CHANGED
|
@@ -63,7 +63,7 @@ BLOCK = 256
|
|
| 63 |
|
| 64 |
if use_triton:
|
| 65 |
try:
|
| 66 |
-
from .
|
| 67 |
|
| 68 |
has_lightning_attention = True
|
| 69 |
except (ImportError, ModuleNotFoundError):
|
|
@@ -345,8 +345,9 @@ class NormLinearAttention(nn.Module):
|
|
| 345 |
k[:, :, i:i + 1],
|
| 346 |
v[:, :, i:i + 1],
|
| 347 |
)
|
| 348 |
-
qkv = torch.einsum("... n e, ... e d -> ... n d",
|
| 349 |
-
|
|
|
|
| 350 |
output.append(qkv)
|
| 351 |
output = torch.concat(output, dim=-2)
|
| 352 |
|
|
|
|
| 63 |
|
| 64 |
if use_triton:
|
| 65 |
try:
|
| 66 |
+
from .lightning_attention2 import lightning_attention
|
| 67 |
|
| 68 |
has_lightning_attention = True
|
| 69 |
except (ImportError, ModuleNotFoundError):
|
|
|
|
| 345 |
k[:, :, i:i + 1],
|
| 346 |
v[:, :, i:i + 1],
|
| 347 |
)
|
| 348 |
+
qkv = torch.einsum("... n e, ... e d -> ... n d", q[:, :,
|
| 349 |
+
i:i + 1],
|
| 350 |
+
kv.to(q.dtype))
|
| 351 |
output.append(qkv)
|
| 352 |
output = torch.concat(output, dim=-2)
|
| 353 |
|