|
|
|
|
|
|
|
|
|
|
|
from typing import Optional, Tuple |
|
|
|
|
|
import torch |
|
|
import triton |
|
|
import triton.language as tl |
|
|
|
|
|
from fla.ops.utils.op import exp, safe_exp |
|
|
from fla.utils import check_shared_mem, is_nvidia_hopper |
|
|
|
|
|
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] |
|
|
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] |
|
|
|
|
|
|
|
|
@triton.heuristics({ |
|
|
'USE_G': lambda args: args['g'] is not None, |
|
|
'USE_OFFSETS': lambda args: args['offsets'] is not None |
|
|
}) |
|
|
@triton.autotune( |
|
|
configs=[ |
|
|
triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) |
|
|
for BK in BKV_LIST |
|
|
for BV in BKV_LIST |
|
|
for num_warps in NUM_WARPS |
|
|
for num_stages in [2, 3, 4] |
|
|
], |
|
|
key=['H', 'K', 'V', 'BT'], |
|
|
) |
|
|
@triton.jit(do_not_specialize=['T']) |
|
|
def chunk_fwd_kernel_o( |
|
|
q, |
|
|
k, |
|
|
v, |
|
|
h, |
|
|
g, |
|
|
o, |
|
|
offsets, |
|
|
indices, |
|
|
scale, |
|
|
T, |
|
|
H: tl.constexpr, |
|
|
K: tl.constexpr, |
|
|
V: tl.constexpr, |
|
|
BT: tl.constexpr, |
|
|
BK: tl.constexpr, |
|
|
BV: tl.constexpr, |
|
|
USE_G: tl.constexpr, |
|
|
USE_OFFSETS: tl.constexpr, |
|
|
HEAD_FIRST: tl.constexpr |
|
|
): |
|
|
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) |
|
|
i_b, i_h = i_bh // H, i_bh % H |
|
|
|
|
|
if USE_OFFSETS: |
|
|
i_tg = i_t |
|
|
i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) |
|
|
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) |
|
|
T = eos - bos |
|
|
NT = tl.cdiv(T, BT) |
|
|
else: |
|
|
NT = tl.cdiv(T, BT) |
|
|
i_tg = i_b * NT + i_t |
|
|
bos, eos = i_b * T, i_b * T + T |
|
|
|
|
|
s_qk = K if HEAD_FIRST else H*K |
|
|
s_vo = V if HEAD_FIRST else H*V |
|
|
s_g = 1 if HEAD_FIRST else H |
|
|
|
|
|
q += (i_bh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K) |
|
|
k += (i_bh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K) |
|
|
v += (i_bh * T*V) if HEAD_FIRST else ((bos * H + i_h) * V) |
|
|
o += (i_bh * T*V) if HEAD_FIRST else ((bos * H + i_h) * V) |
|
|
h += ((i_bh * NT + i_t).to(tl.int64) * K*V) if HEAD_FIRST else ((i_tg * H + i_h).to(tl.int64) * K*V) |
|
|
|
|
|
b_o = tl.zeros([BT, BV], dtype=tl.float32) |
|
|
b_A = tl.zeros([BT, BT], dtype=tl.float32) |
|
|
|
|
|
for i_k in range(tl.cdiv(K, BK)): |
|
|
p_q = tl.make_block_ptr(q, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) |
|
|
p_k = tl.make_block_ptr(k, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) |
|
|
p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) |
|
|
|
|
|
b_q = tl.load(p_q, boundary_check=(0, 1)) |
|
|
|
|
|
b_k = tl.load(p_k, boundary_check=(0, 1)) |
|
|
|
|
|
b_h = tl.load(p_h, boundary_check=(0, 1)) |
|
|
|
|
|
|
|
|
b_o += tl.dot(b_q, b_h) |
|
|
|
|
|
b_A += tl.dot(b_q, b_k) |
|
|
|
|
|
if USE_G: |
|
|
g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h) |
|
|
p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,)) |
|
|
b_g = tl.load(p_g, boundary_check=(0,)) |
|
|
b_o = b_o * exp(b_g)[:, None] |
|
|
b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :]) |
|
|
|
|
|
o_i = tl.arange(0, BT) |
|
|
m_A = o_i[:, None] >= o_i[None, :] |
|
|
b_A = tl.where(m_A, b_A, 0) |
|
|
|
|
|
p_v = tl.make_block_ptr(v, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) |
|
|
p_o = tl.make_block_ptr(o, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) |
|
|
b_v = tl.load(p_v, boundary_check=(0, 1)) |
|
|
|
|
|
|
|
|
|
|
|
b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale |
|
|
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) |
|
|
|
|
|
|
|
|
@triton.heuristics({ |
|
|
'USE_OFFSETS': lambda args: args['offsets'] is not None, |
|
|
'USE_G': lambda args: args['g'] is not None, |
|
|
'USE_DW': lambda args: args['dw'] is not None |
|
|
}) |
|
|
@triton.autotune( |
|
|
configs=[ |
|
|
triton.Config({}, num_warps=num_warps, num_stages=num_stages) |
|
|
for num_warps in NUM_WARPS |
|
|
for num_stages in [2, 3, 4] |
|
|
], |
|
|
key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G', 'USE_DW'], |
|
|
) |
|
|
@triton.jit(do_not_specialize=['T']) |
|
|
def chunk_bwd_kernel_dqkwg( |
|
|
q, |
|
|
k, |
|
|
v, |
|
|
h, |
|
|
g, |
|
|
do, |
|
|
dh, |
|
|
dq, |
|
|
dk, |
|
|
dg, |
|
|
w, |
|
|
dv, |
|
|
dw, |
|
|
offsets, |
|
|
indices, |
|
|
scale, |
|
|
B: tl.constexpr, |
|
|
T, |
|
|
H: tl.constexpr, |
|
|
K: tl.constexpr, |
|
|
V: tl.constexpr, |
|
|
BT: tl.constexpr, |
|
|
BK: tl.constexpr, |
|
|
BV: tl.constexpr, |
|
|
USE_G: tl.constexpr, |
|
|
USE_DW: tl.constexpr, |
|
|
USE_OFFSETS: tl.constexpr, |
|
|
HEAD_FIRST: tl.constexpr |
|
|
): |
|
|
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) |
|
|
i_b, i_h = i_bh // H, i_bh % H |
|
|
if USE_G: |
|
|
dg += i_k * B * H * T |
|
|
if USE_OFFSETS: |
|
|
i_tg = i_t |
|
|
i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) |
|
|
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) |
|
|
T = eos - bos |
|
|
NT = tl.cdiv(T, BT) |
|
|
else: |
|
|
NT = tl.cdiv(T, BT) |
|
|
i_tg = i_b * NT + i_t |
|
|
bos, eos = i_b * T, i_b * T + T |
|
|
|
|
|
|
|
|
v += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V |
|
|
do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V |
|
|
h += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V |
|
|
dh += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V |
|
|
q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K |
|
|
k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K |
|
|
dq += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K |
|
|
dk += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K |
|
|
s_qk = K if HEAD_FIRST else H*K |
|
|
s_vo = V if HEAD_FIRST else H*V |
|
|
s_g = 1 if HEAD_FIRST else H |
|
|
|
|
|
|
|
|
if USE_DW: |
|
|
dw += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K |
|
|
dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V |
|
|
w += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K |
|
|
|
|
|
b_dq = tl.zeros([BT, BK], dtype=tl.float32) |
|
|
b_dk = tl.zeros([BT, BK], dtype=tl.float32) |
|
|
b_ds = tl.zeros([BT, BT], dtype=tl.float32) |
|
|
b_dg_last = tl.zeros([1,], dtype=tl.float32) if USE_G else None |
|
|
b_dw = tl.zeros([BT, BK], dtype=tl.float32) if USE_DW else None |
|
|
|
|
|
for i_v in range(tl.cdiv(V, BV)): |
|
|
p_v = tl.make_block_ptr(v, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) |
|
|
p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) |
|
|
p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) |
|
|
p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) |
|
|
|
|
|
b_v = tl.load(p_v, boundary_check=(0, 1)) |
|
|
b_do = tl.load(p_do, boundary_check=(0, 1)) |
|
|
|
|
|
b_h = tl.load(p_h, boundary_check=(0, 1)) |
|
|
b_dh = tl.load(p_dh, boundary_check=(0, 1)) |
|
|
if USE_G: |
|
|
b_dg_last += (tl.sum(b_h * b_dh)) |
|
|
|
|
|
b_ds += tl.dot(b_do, tl.trans(b_v)) |
|
|
|
|
|
b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) |
|
|
|
|
|
b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) |
|
|
if USE_DW: |
|
|
p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) |
|
|
b_dv = tl.load(p_dv, boundary_check=(0, 1)) |
|
|
b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype)) |
|
|
|
|
|
if USE_DW and not USE_G: |
|
|
p_dw = tl.make_block_ptr(dw, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) |
|
|
tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) |
|
|
|
|
|
tl.debug_barrier() |
|
|
o_i = tl.arange(0, BT) |
|
|
p_q = tl.make_block_ptr(q, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) |
|
|
p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) |
|
|
b_q = tl.load(p_q, boundary_check=(0, 1)) |
|
|
b_k = tl.load(p_k, boundary_check=(0, 1)) |
|
|
|
|
|
p_dq = tl.make_block_ptr(dq, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) |
|
|
p_dk = tl.make_block_ptr(dk, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) |
|
|
|
|
|
if USE_G: |
|
|
b_dg = tl.zeros([BT,], dtype=tl.float32) |
|
|
g += i_bh * T if HEAD_FIRST else bos * H + i_h |
|
|
dg += i_bh * T if HEAD_FIRST else bos * H + i_h |
|
|
p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,)) |
|
|
b_g = tl.load(p_g, boundary_check=(0,)) |
|
|
b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * s_g) |
|
|
b_dg_last *= exp(b_g_last) |
|
|
|
|
|
if USE_DW: |
|
|
p_w = tl.make_block_ptr(w, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) |
|
|
p_dw = tl.make_block_ptr(dw, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) |
|
|
b_w = tl.load(p_w, boundary_check=(0, 1)) |
|
|
b_dw = b_dw * exp(b_g)[:, None] |
|
|
tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) |
|
|
b_dg -= tl.sum(b_w * b_dw, axis=1) |
|
|
|
|
|
b_dq = b_dq * exp(b_g)[:, None] * scale |
|
|
b_dg += tl.sum(b_dq * b_q, axis=1) |
|
|
|
|
|
b_dk = b_dk * safe_exp(-b_g + b_g_last)[:, None] |
|
|
b_dg -= tl.sum(b_k * b_dk, axis=1) |
|
|
b_dg_last += tl.sum(b_dk * b_k) |
|
|
|
|
|
b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * safe_exp(b_g[:, None] - b_g[None, :]), 0) * scale |
|
|
b_ds2 = b_ds * tl.dot(b_q, tl.trans(b_k)) |
|
|
b_dg += tl.sum(b_ds2, axis=1) |
|
|
b_dg -= tl.sum(b_ds2, axis=0) |
|
|
|
|
|
b_ds = b_ds.to(b_k.dtype) |
|
|
|
|
|
b_dq += tl.dot(b_ds, b_k) |
|
|
b_dk += tl.dot(tl.trans(b_ds), b_q) |
|
|
p_dg = tl.make_block_ptr(dg, (T,), (s_g,), (i_t * BT,), (BT,), (0,)) |
|
|
|
|
|
|
|
|
b_dg = tl.where(o_i < min(BT, T-i_t*BT) - 1, b_dg, b_dg + b_dg_last) |
|
|
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) |
|
|
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) |
|
|
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) |
|
|
else: |
|
|
b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0) |
|
|
b_ds = b_ds.to(b_k.dtype) |
|
|
b_dq += tl.dot(b_ds, b_k) |
|
|
b_dk += tl.dot(tl.trans(b_ds), b_q) * scale |
|
|
b_dq *= scale |
|
|
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) |
|
|
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) |
|
|
|
|
|
|
|
|
@triton.heuristics({ |
|
|
'USE_OFFSETS': lambda args: args['offsets'] is not None, |
|
|
'USE_G': lambda args: args['g'] is not None, |
|
|
}) |
|
|
@triton.autotune( |
|
|
configs=[ |
|
|
triton.Config({}, num_warps=num_warps, num_stages=num_stages) |
|
|
for num_warps in [2, 4, 8] |
|
|
for num_stages in [2, 3, 4] |
|
|
], |
|
|
key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'], |
|
|
) |
|
|
@triton.jit(do_not_specialize=['T']) |
|
|
def chunk_bwd_kernel_dv( |
|
|
q, |
|
|
k, |
|
|
g, |
|
|
do, |
|
|
dv, |
|
|
dh, |
|
|
offsets, |
|
|
indices, |
|
|
scale, |
|
|
T, |
|
|
H: tl.constexpr, |
|
|
K: tl.constexpr, |
|
|
V: tl.constexpr, |
|
|
BT: tl.constexpr, |
|
|
BK: tl.constexpr, |
|
|
BV: tl.constexpr, |
|
|
USE_G: tl.constexpr, |
|
|
USE_OFFSETS: tl.constexpr, |
|
|
HEAD_FIRST: tl.constexpr |
|
|
): |
|
|
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) |
|
|
i_b, i_h = i_bh // H, i_bh % H |
|
|
if USE_OFFSETS: |
|
|
i_tg = i_t |
|
|
i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) |
|
|
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) |
|
|
T = eos - bos |
|
|
NT = tl.cdiv(T, BT) |
|
|
else: |
|
|
NT = tl.cdiv(T, BT) |
|
|
i_tg = i_b * NT + i_t |
|
|
bos, eos = i_b * T, i_b * T + T |
|
|
|
|
|
b_dv = tl.zeros([BT, BV], dtype=tl.float32) |
|
|
|
|
|
|
|
|
q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K |
|
|
k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K |
|
|
do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V |
|
|
dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V |
|
|
s_qk = K if HEAD_FIRST else H*K |
|
|
s_vo = V if HEAD_FIRST else H*V |
|
|
s_g = 1 if HEAD_FIRST else H |
|
|
dh += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V |
|
|
|
|
|
b_A = tl.zeros([BT, BT], dtype=tl.float32) |
|
|
for i_k in range(tl.cdiv(K, BK)): |
|
|
p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) |
|
|
p_q = tl.make_block_ptr(q, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) |
|
|
b_q = tl.load(p_q, boundary_check=(0, 1)) |
|
|
b_k = tl.load(p_k, boundary_check=(0, 1)) |
|
|
b_A += tl.dot(b_k, b_q) |
|
|
p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) |
|
|
b_dh = tl.load(p_dh, boundary_check=(0, 1)) |
|
|
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype)) |
|
|
|
|
|
if USE_G: |
|
|
g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h) |
|
|
p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,)) |
|
|
b_g = tl.load(p_g, boundary_check=(0,)) |
|
|
b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * s_g) |
|
|
b_dv *= safe_exp(-b_g + b_g_last)[:, None] |
|
|
|
|
|
mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]) |
|
|
if USE_G: |
|
|
b_A = tl.where(mask, b_A * safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty) |
|
|
else: |
|
|
b_A = tl.where(mask, b_A * scale, 0).to(do.dtype.element_ty) |
|
|
p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) |
|
|
p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) |
|
|
b_do = tl.load(p_do, boundary_check=(0, 1)) |
|
|
b_dv += tl.dot(b_A.to(b_do.dtype), b_do) |
|
|
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) |
|
|
|
|
|
|
|
|
@triton.heuristics({ |
|
|
'USE_G': lambda args: args['g'] is not None, |
|
|
'USE_OFFSETS': lambda args: args['offsets'] is not None, |
|
|
}) |
|
|
@triton.autotune( |
|
|
configs=[ |
|
|
triton.Config({}, num_warps=num_warps, num_stages=num_stages) |
|
|
for num_warps in NUM_WARPS |
|
|
for num_stages in [2, 3, 4] |
|
|
], |
|
|
key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'], |
|
|
) |
|
|
@triton.jit(do_not_specialize=['T']) |
|
|
def chunk_bwd_kernel_dv_local( |
|
|
q, |
|
|
k, |
|
|
g, |
|
|
do, |
|
|
dv, |
|
|
offsets, |
|
|
indices, |
|
|
scale, |
|
|
T, |
|
|
H: tl.constexpr, |
|
|
K: tl.constexpr, |
|
|
V: tl.constexpr, |
|
|
BT: tl.constexpr, |
|
|
BK: tl.constexpr, |
|
|
BV: tl.constexpr, |
|
|
USE_G: tl.constexpr, |
|
|
USE_OFFSETS: tl.constexpr, |
|
|
HEAD_FIRST: tl.constexpr |
|
|
): |
|
|
i_t, i_bh = tl.program_id(0), tl.program_id(1) |
|
|
i_b, i_h = i_bh // H, i_bh % H |
|
|
if USE_OFFSETS: |
|
|
i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) |
|
|
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) |
|
|
T = eos - bos |
|
|
else: |
|
|
bos, eos = i_b * T, i_b * T + T |
|
|
|
|
|
|
|
|
q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K |
|
|
k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K |
|
|
do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V |
|
|
dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V |
|
|
s_qk = K if HEAD_FIRST else H*K |
|
|
s_vo = V if HEAD_FIRST else H*V |
|
|
s_g = 1 if HEAD_FIRST else H |
|
|
|
|
|
b_A = tl.zeros([BT, BT], dtype=tl.float32) |
|
|
for i_k in range(tl.cdiv(K, BK)): |
|
|
p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) |
|
|
p_q = tl.make_block_ptr(q, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) |
|
|
b_q = tl.load(p_q, boundary_check=(0, 1)) |
|
|
b_k = tl.load(p_k, boundary_check=(0, 1)) |
|
|
b_A += tl.dot(b_k, b_q) |
|
|
|
|
|
if USE_G: |
|
|
g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h) |
|
|
p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,)) |
|
|
b_g = tl.load(p_g, boundary_check=(0,)) |
|
|
|
|
|
mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]) |
|
|
if USE_G: |
|
|
b_A = tl.where(mask, b_A * safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty) |
|
|
else: |
|
|
b_A = tl.where(mask, b_A * scale, 0).to(do.dtype.element_ty) |
|
|
|
|
|
for i_v in range(tl.cdiv(V, BV)): |
|
|
p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) |
|
|
p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) |
|
|
b_do = tl.load(p_do, boundary_check=(0, 1)) |
|
|
b_dv = tl.dot(b_A.to(b_do.dtype), b_do) |
|
|
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) |
|
|
|
|
|
|
|
|
def chunk_fwd_o( |
|
|
q: torch.Tensor, |
|
|
k: torch.Tensor, |
|
|
v: torch.Tensor, |
|
|
h: torch.Tensor, |
|
|
g: Optional[torch.Tensor] = None, |
|
|
scale: Optional[float] = None, |
|
|
offsets: Optional[torch.LongTensor] = None, |
|
|
indices: Optional[torch.LongTensor] = None, |
|
|
head_first: bool = True, |
|
|
chunk_size: int = 64 |
|
|
) -> torch.Tensor: |
|
|
if head_first: |
|
|
B, H, T, K, V = *q.shape, v.shape[-1] |
|
|
else: |
|
|
B, T, H, K, V = *q.shape, v.shape[-1] |
|
|
if scale is None: |
|
|
scale = k.shape[-1] ** -0.5 |
|
|
BT = min(chunk_size, max(16, triton.next_power_of_2(T))) |
|
|
NT = triton.cdiv(T, BT) if offsets is None else len(indices) |
|
|
|
|
|
o = torch.empty_like(v) |
|
|
|
|
|
def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) |
|
|
chunk_fwd_kernel_o[grid]( |
|
|
q, |
|
|
k, |
|
|
v, |
|
|
h, |
|
|
g, |
|
|
o, |
|
|
offsets, |
|
|
indices, |
|
|
scale, |
|
|
T=T, |
|
|
H=H, |
|
|
K=K, |
|
|
V=V, |
|
|
BT=BT, |
|
|
HEAD_FIRST=head_first |
|
|
) |
|
|
return o |
|
|
|
|
|
|
|
|
def chunk_bwd_dv( |
|
|
q: torch.Tensor, |
|
|
k: torch.Tensor, |
|
|
g: torch.Tensor, |
|
|
do: torch.Tensor, |
|
|
dh: torch.Tensor, |
|
|
scale: float, |
|
|
offsets: Optional[torch.LongTensor] = None, |
|
|
indices: Optional[torch.LongTensor] = None, |
|
|
head_first: bool = True, |
|
|
chunk_size: int = 64 |
|
|
) -> torch.Tensor: |
|
|
if head_first: |
|
|
B, H, T, K, V = *k.shape, do.shape[-1] |
|
|
else: |
|
|
B, T, H, K, V = *k.shape, do.shape[-1] |
|
|
BT = min(chunk_size, max(16, triton.next_power_of_2(T))) |
|
|
|
|
|
if check_shared_mem('hopper', k.device.index): |
|
|
CONST_TILING = 128 |
|
|
elif check_shared_mem: |
|
|
CONST_TILING = 64 |
|
|
else: |
|
|
CONST_TILING = 32 |
|
|
BK = min(triton.next_power_of_2(K), CONST_TILING) |
|
|
BV = min(triton.next_power_of_2(V), CONST_TILING) |
|
|
NT = triton.cdiv(T, BT) if offsets is None else len(indices) |
|
|
NV = triton.cdiv(V, BV) |
|
|
|
|
|
dv = torch.empty_like(do) |
|
|
grid = (NV, NT, B * H) |
|
|
chunk_bwd_kernel_dv[grid]( |
|
|
q, |
|
|
k, |
|
|
g, |
|
|
do, |
|
|
dv, |
|
|
dh, |
|
|
offsets, |
|
|
indices, |
|
|
scale, |
|
|
T=T, |
|
|
H=H, |
|
|
K=K, |
|
|
V=V, |
|
|
BT=BT, |
|
|
BK=BK, |
|
|
BV=BV, |
|
|
HEAD_FIRST=head_first |
|
|
) |
|
|
return dv |
|
|
|
|
|
|
|
|
def chunk_bwd_dv_local( |
|
|
q: torch.Tensor, |
|
|
k: torch.Tensor, |
|
|
g: torch.Tensor, |
|
|
do: torch.Tensor, |
|
|
dh: torch.Tensor, |
|
|
scale: float, |
|
|
offsets: Optional[torch.LongTensor] = None, |
|
|
indices: Optional[torch.LongTensor] = None, |
|
|
head_first: bool = True, |
|
|
chunk_size: int = 64 |
|
|
) -> torch.Tensor: |
|
|
if head_first: |
|
|
B, H, T, K, V = *k.shape, do.shape[-1] |
|
|
else: |
|
|
B, T, H, K, V = *k.shape, do.shape[-1] |
|
|
BT = min(chunk_size, max(16, triton.next_power_of_2(T))) |
|
|
|
|
|
if check_shared_mem('hopper', k.device.index): |
|
|
CONST_TILING = 128 |
|
|
elif check_shared_mem: |
|
|
CONST_TILING = 64 |
|
|
else: |
|
|
CONST_TILING = 32 |
|
|
BK = min(triton.next_power_of_2(K), CONST_TILING) |
|
|
BV = min(triton.next_power_of_2(V), CONST_TILING) |
|
|
NT = triton.cdiv(T, BT) if offsets is None else len(indices) |
|
|
|
|
|
dv = torch.empty_like(do) |
|
|
grid = (NT, B * H) |
|
|
chunk_bwd_kernel_dv_local[grid]( |
|
|
q, |
|
|
k, |
|
|
g, |
|
|
do, |
|
|
dv, |
|
|
offsets, |
|
|
indices, |
|
|
scale, |
|
|
T=T, |
|
|
H=H, |
|
|
K=K, |
|
|
V=V, |
|
|
BT=BT, |
|
|
BK=BK, |
|
|
BV=BV, |
|
|
HEAD_FIRST=head_first |
|
|
) |
|
|
return dv |
|
|
|
|
|
|
|
|
def chunk_bwd_dqkwg( |
|
|
q: torch.Tensor, |
|
|
k: torch.Tensor, |
|
|
v: torch.Tensor, |
|
|
g: torch.Tensor, |
|
|
do: torch.Tensor, |
|
|
h: torch.Tensor, |
|
|
dh: torch.Tensor, |
|
|
dv: Optional[torch.Tensor] = None, |
|
|
w: Optional[torch.Tensor] = None, |
|
|
offsets: Optional[torch.LongTensor] = None, |
|
|
indices: Optional[torch.LongTensor] = None, |
|
|
chunk_size: int = 64, |
|
|
scale: float = 1.0, |
|
|
head_first: bool = True, |
|
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
|
|
|
|
if head_first: |
|
|
B, H, T, K, V = *k.shape, v.shape[-1] |
|
|
else: |
|
|
B, T, H, K, V = *k.shape, v.shape[-1] |
|
|
BT = min(chunk_size, max(16, triton.next_power_of_2(T))) |
|
|
NT = triton.cdiv(T, BT) if offsets is None else len(indices) |
|
|
|
|
|
CONST_TILING = 64 if check_shared_mem() else 32 |
|
|
BK = min(triton.next_power_of_2(K), CONST_TILING) |
|
|
BV = min(triton.next_power_of_2(V), CONST_TILING) |
|
|
NK = triton.cdiv(K, BK) |
|
|
dq = torch.empty_like(q) |
|
|
dk = torch.empty_like(k) |
|
|
dg = torch.empty(NK, *g.shape, dtype=torch.float32, device=g.device) if g is not None else None |
|
|
dw = torch.empty_like(w) if w is not None else None |
|
|
|
|
|
grid = (NK, NT, B * H) |
|
|
chunk_bwd_kernel_dqkwg[grid]( |
|
|
q=q, |
|
|
k=k, |
|
|
v=v, |
|
|
h=h, |
|
|
g=g, |
|
|
do=do, |
|
|
dh=dh, |
|
|
dv=dv, |
|
|
w=w, |
|
|
dw=dw, |
|
|
dq=dq, |
|
|
dk=dk, |
|
|
dg=dg, |
|
|
offsets=offsets, |
|
|
indices=indices, |
|
|
scale=scale, |
|
|
B=B, |
|
|
T=T, |
|
|
H=H, |
|
|
K=K, |
|
|
V=V, |
|
|
BT=BT, |
|
|
BK=BK, |
|
|
BV=BV, |
|
|
HEAD_FIRST=head_first |
|
|
) |
|
|
|
|
|
if dg is not None: |
|
|
dg = dg.sum(0) |
|
|
return dq, dk, dw, dg |
|
|
|