import math
from typing import Optional, List

import torch

import tilelang
import tilelang.language as T


@tilelang.jit(
    out_idx=[6],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flash_mla_varlen_func_kernel(
    UQ, UKV, heads, dim_qk, dim_vo, softmax_scale, is_causal,
    block_M=64, block_N=64, num_stages=1, threads=128,
):
    batch_size = T.dynamic("batch_size")
    # The online softmax below uses exp2/log2, so fold log2(e) into the scale.
    scale = softmax_scale * 1.44269504

    q_shape = [UQ, heads, dim_qk]
    k_shape = [UKV, heads, dim_qk]
    v_shape = [UKV, heads, dim_vo]
    o_shape = [UQ, heads, dim_vo]

    dtype = T.bfloat16
    accum_dtype = T.float32

    @T.prim_func
    def main(
        Q_unpad: T.Tensor(q_shape, dtype),
        K_unpad: T.Tensor(k_shape, dtype),
        V_unpad: T.Tensor(v_shape, dtype),
        cu_seqlens_q: T.Tensor([batch_size + 1], T.int32),
        cu_seqlens_k: T.Tensor([batch_size + 1], T.int32),
        max_seqlen_q: T.int32,
        Output_unpad: T.Tensor(o_shape, dtype),
    ):
        # Grid: (query blocks of the longest sequence, heads, batch).
        with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
            K_shared = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_N, dim_vo], dtype)
            O_shared = T.alloc_shared([block_M, dim_vo], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim_vo], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            batch_idx = bz
            head_idx = by

            # Token ranges of this sequence in the unpadded (varlen) layout.
            q_start_idx = cu_seqlens_q[batch_idx]
            kv_start_idx = cu_seqlens_k[batch_idx]
            q_end_idx = cu_seqlens_q[batch_idx + 1]
            kv_end_idx = cu_seqlens_k[batch_idx + 1]

            q_current_seqlen = q_end_idx - q_start_idx
            kv_current_seqlen = kv_end_idx - kv_start_idx

            T.copy(
                Q_unpad[q_start_idx + bx * block_M : q_start_idx + bx * block_M + block_M, head_idx, :], Q_shared
            )

            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            # Causal masking is aligned to the end of the key sequence: query i may
            # attend keys up to i + offset, so only the first KV blocks up to
            # ceil((offset + (bx + 1) * block_M) / block_N) contribute to this query block.
            offset = kv_current_seqlen - q_current_seqlen
            loop_range = (
                T.min(T.ceildiv(offset + (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
                if is_causal
                else T.ceildiv(kv_current_seqlen, block_N)
            )

            for k in T.Pipelined(loop_range, num_stages=num_stages):
                T.copy(
                    K_unpad[kv_start_idx + k * block_N : kv_start_idx + k * block_N + block_N, head_idx, :], K_shared
                )
                # Seed the score tile with the mask bias: -1e9 for out-of-range or
                # non-causal positions, 0 otherwise; the GEMM then accumulates QK^T on top.
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(
                            (bx * block_M + i + offset < k * block_N + j)
                            or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen),
                            -1e9,
                            0,
                        )
                else:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(
                            (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), -1e9, 0
                        )

                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

                # Online softmax: refresh the running row maximum, rescale previous
                # partial results, then accumulate this tile.
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])

                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)

                for i, j in T.Parallel(block_M, dim_vo):
                    acc_o[i, j] *= scores_scale[i]

                T.copy(
                    V_unpad[kv_start_idx + k * block_N : kv_start_idx + k * block_N + block_N, head_idx, :], V_shared
                )
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

            # Normalize, then write back only rows that belong to this sequence.
            for i, j in T.Parallel(block_M, dim_vo):
                acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i]

            T.copy(acc_o, O_shared)
            for i, d in T.Parallel(block_M, dim_vo):
                if bx * block_M + i < q_current_seqlen:
                    Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d]

    return main


def flash_mla_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: float,
    causal: bool,
):
    assert causal, "only causal attention is supported"

    nnz_qo, num_heads_q, head_dim_qk = q.shape
    nnz_kv, num_heads_kv, head_dim_vo = v.shape
    assert num_heads_q == num_heads_kv

    kernel = flash_mla_varlen_func_kernel(
        nnz_qo, nnz_kv, num_heads_q, head_dim_qk, head_dim_vo, softmax_scale, True,
        block_M=128, block_N=128, num_stages=1, threads=256,
    )
    return kernel(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
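

# A minimal reference check for flash_mla_varlen_func, written against the unpadded
# tensor layout used above. Illustrative sketch only: the helper name is an
# assumption (not part of any API), and it assumes each sequence has kv_len >= q_len,
# as in standard causal prefill/decoding.
def _naive_varlen_attention_reference(q, k, v, cu_seqlens_q, cu_seqlens_k, softmax_scale):
    """Naive per-sequence causal attention over unpadded tensors, for correctness checks."""
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2], dtype=torch.float32, device=q.device)
    for b in range(cu_seqlens_q.numel() - 1):
        qs, qe = int(cu_seqlens_q[b]), int(cu_seqlens_q[b + 1])
        ks, ke = int(cu_seqlens_k[b]), int(cu_seqlens_k[b + 1])
        qb, kb, vb = q[qs:qe].float(), k[ks:ke].float(), v[ks:ke].float()
        scores = torch.einsum("qhd,khd->hqk", qb, kb) * softmax_scale  # [heads, sq, sk]
        sq, sk = qe - qs, ke - ks
        # Causal mask aligned to the end of the key sequence, matching the kernel's
        # offset = kv_current_seqlen - q_current_seqlen.
        idx_q = torch.arange(sq, device=q.device)[:, None]
        idx_k = torch.arange(sk, device=q.device)[None, :]
        scores = scores.masked_fill(idx_q + (sk - sq) < idx_k, float("-inf"))
        out[qs:qe] = torch.einsum("hqk,khd->qhd", torch.softmax(scores, dim=-1), vb)
    return out.to(q.dtype)

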
@tilelang.jit(
    out_idx=[6],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flash_mla_with_kvcache_kernel(
    batch,
    seqlen_q,
    h_q,
    h_kv,
    dv,
    dpe,
    block_N,
    block_H,
    num_split,
    block_size,
    num_pages,
    max_num_blocks_per_seq,
    softmax_scale=None,
):
    if softmax_scale is None:
        softmax_scale = (dv + dpe) ** -0.5
    scale = float(softmax_scale * 1.44269504)
    dtype = T.bfloat16
    accum_dtype = T.float32

    # MLA decode: the paged cache stores, per token, the dv-dim latent part (used as
    # both K and V) followed by the dpe-dim rope part (used only for the scores).
    assert h_kv == 1, "h_kv must be 1"

    # All query heads share the single KV head, so heads are tiled in blocks of
    # block_H (at least 16 for the tensor-core GEMMs).
    kv_group_num = h_q
    VALID_BLOCK_H = min(block_H, kv_group_num)
    VALID_BLOCK_H = max(VALID_BLOCK_H, 16)
    block_H = VALID_BLOCK_H

    assert block_size >= block_N and block_size % block_N == 0, \
        "block_size must be at least block_N and a multiple of block_N"

    @T.prim_func
    def main_split(
        Q: T.Tensor([batch, seqlen_q, h_q, dv + dpe], dtype),
        KV: T.Tensor([num_pages, block_size, 1, dv + dpe], dtype),
        block_table: T.Tensor([batch, max_num_blocks_per_seq], T.int32),
        cache_seqlens: T.Tensor([batch], T.int32),
        glse: T.Tensor([batch, seqlen_q, h_q, num_split], dtype),
        Output_partial: T.Tensor([batch, seqlen_q, h_q, num_split, dv], dtype),
        Output: T.Tensor([batch, seqlen_q, h_q, dv], dtype),
    ):
        # First kernel: each CTA handles one (batch, query position, head block, split)
        # and writes a partial output plus its log-sum-exp into Output_partial / glse.
        with T.Kernel(batch, seqlen_q * ((h_q + VALID_BLOCK_H - 1) // VALID_BLOCK_H), num_split, threads=256) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_H, dv], dtype)
            S_shared = T.alloc_shared([block_H, block_N], dtype)
            Q_pe_shared = T.alloc_shared([block_H, dpe], dtype)
            KV_shared = T.alloc_shared([block_N, dv], dtype)
            K_pe_shared = T.alloc_shared([block_N, dpe], dtype)
            O_shared = T.alloc_shared([block_H, dv], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o = T.alloc_fragment([block_H, dv], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)

            seq_q_head_block_idx = by
            seq_q_idx = T.floordiv(seq_q_head_block_idx, (h_q + VALID_BLOCK_H - 1) // VALID_BLOCK_H)
            head_block_idx = T.floormod(seq_q_head_block_idx, (h_q + VALID_BLOCK_H - 1) // VALID_BLOCK_H)

            T.use_swizzle(10)

            T.copy(Q[bx, seq_q_idx, head_block_idx * VALID_BLOCK_H : (head_block_idx + 1) * VALID_BLOCK_H, :dv], Q_shared)
            T.copy(Q[bx, seq_q_idx, head_block_idx * VALID_BLOCK_H : (head_block_idx + 1) * VALID_BLOCK_H, dv:], Q_pe_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            # Causal: query position seq_q_idx sees the first
            # cache_seqlens[bx] - seqlen_q + seq_q_idx + 1 cached tokens. Distribute
            # those KV blocks over num_split CTAs; the first `remaining_blocks`
            # splits take one extra block.
            total_blocks = T.ceildiv(cache_seqlens[bx] - seqlen_q + seq_q_idx + 1, block_N)
            blocks_per_split = T.floordiv(total_blocks, num_split)
            remaining_blocks = T.floormod(total_blocks, num_split)
            loop_range = blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0)
            start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N

            for k in T.Pipelined(loop_range, num_stages=2):
                global_token_idx = start + k * block_N

                # Translate the logical token position into a physical cache page.
                logical_page_idx = global_token_idx // block_size
                page_offset = global_token_idx % block_size
                physical_page_id = block_table[bx, logical_page_idx]

                T.copy(KV[physical_page_id, page_offset : page_offset + block_N, 0, :dv], KV_shared)
                T.copy(KV[physical_page_id, page_offset : page_offset + block_N, 0, dv:], K_pe_shared)

                # Scores = Q_nope . KV_nope^T + Q_pe . K_pe^T.
                T.clear(acc_s)
                T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)

                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.if_then_else(start + k * block_N + j >= cache_seqlens[bx] - seqlen_q + seq_q_idx + 1, -T.infinity(accum_dtype), acc_s[i, j])

                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_H):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)

                T.reduce_sum(acc_s, scores_sum, dim=1)
                T.copy(acc_s, S_shared)
                T.copy(S_shared, acc_s_cast)

                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                for i, j in T.Parallel(block_H, dv):
                    acc_o[i, j] *= scores_scale[i]

                # The nope part of the cache doubles as V.
                T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)

            for i, j in T.Parallel(block_H, dv):
                acc_o[i, j] /= logsum[i]
            # Per-split log-sum-exp (log2 domain) for the final combine step.
            for i in T.Parallel(block_H):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale

            T.copy(logsum, glse[bx, seq_q_idx, head_block_idx * VALID_BLOCK_H : (head_block_idx + 1) * VALID_BLOCK_H, bz])
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output_partial[bx, seq_q_idx, head_block_idx * VALID_BLOCK_H : (head_block_idx + 1) * VALID_BLOCK_H, bz, :])

        # Second kernel: combine the per-split partial outputs with a numerically
        # stable log-sum-exp reduction over splits (still in the log2 domain).
        with T.Kernel(seqlen_q * h_q, batch, threads=128) as (by, bz):
            po_local = T.alloc_fragment([dv], dtype)
            o_accum_local = T.alloc_fragment([dv], accum_dtype)
            lse_local_split = T.alloc_fragment([1], accum_dtype)
            lse_logsum_local = T.alloc_fragment([1], accum_dtype)
            lse_max_local = T.alloc_fragment([1], accum_dtype)
            scale_local = T.alloc_fragment([1], accum_dtype)

            seq_q_head_idx = by
            seq_q_idx = T.floordiv(seq_q_head_idx, h_q)
            head_idx = T.floormod(seq_q_head_idx, h_q)

            lse_logsum_local[0] = 0.0
            T.clear(o_accum_local)
            lse_max_local[0] = -T.infinity(accum_dtype)
            for k in T.serial(num_split):
                lse_max_local[0] = T.max(lse_max_local[0], glse[bz, seq_q_idx, head_idx, k])
            for k in T.Pipelined(num_split, num_stages=1):
                lse_local_split[0] = glse[bz, seq_q_idx, head_idx, k]
                lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
            lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
            # Output = sum_k exp2(lse_k - lse_total) * partial_k.
            for k in T.serial(num_split):
                for i in T.Parallel(dv):
                    po_local[i] = Output_partial[bz, seq_q_idx, head_idx, k, i]
                lse_local_split[0] = glse[bz, seq_q_idx, head_idx, k]
                scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
                for i in T.Parallel(dv):
                    o_accum_local[i] += po_local[i] * scale_local[0]
            for i in T.Parallel(dv):
                Output[bz, seq_q_idx, head_idx, i] = o_accum_local[i]

    @T.prim_func
    def main_no_split(
        Q: T.Tensor([batch, seqlen_q, h_q, dv + dpe], dtype),
        KV: T.Tensor([num_pages, block_size, 1, dv + dpe], dtype),
        block_table: T.Tensor([batch, max_num_blocks_per_seq], T.int32),
        cache_seqlens: T.Tensor([batch], T.int32),
        glse: T.Tensor([batch, seqlen_q, h_q, num_split], dtype),
        Output_partial: T.Tensor([batch, seqlen_q, h_q, num_split, dv], dtype),
        Output: T.Tensor([batch, seqlen_q, h_q, dv], dtype),
    ):
        with T.Kernel(batch, seqlen_q * h_q // VALID_BLOCK_H, threads=256) as (bx, by):
            Q_shared = T.alloc_shared([block_H, dv], dtype)
            S_shared = T.alloc_shared([block_H, block_N], dtype)
            Q_pe_shared = T.alloc_shared([block_H, dpe], dtype)
            KV_shared = T.alloc_shared([block_N, dv], dtype)
            K_pe_shared = T.alloc_shared([block_N, dpe], dtype)
            O_shared = T.alloc_shared([block_H, dv], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_o = T.alloc_fragment([block_H, dv], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)

            seq_q_head_block_idx = by
            seq_q_idx = T.floordiv(seq_q_head_block_idx, (h_q + VALID_BLOCK_H - 1) // VALID_BLOCK_H)
            head_block_idx = T.floormod(seq_q_head_block_idx, (h_q + VALID_BLOCK_H - 1) // VALID_BLOCK_H)

            T.use_swizzle(10)

            T.copy(Q[bx, seq_q_idx, head_block_idx * VALID_BLOCK_H : (head_block_idx + 1) * VALID_BLOCK_H, :dv], Q_shared)
            T.copy(Q[bx, seq_q_idx, head_block_idx * VALID_BLOCK_H : (head_block_idx + 1) * VALID_BLOCK_H, dv:], Q_pe_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            # Walk KV blocks from last to first so the out-of-range mask is only
            # needed on the first iteration (the final, possibly partial block).
            loop_range = T.ceildiv(cache_seqlens[bx] - seqlen_q + seq_q_idx + 1, block_N)
            for kr in T.Pipelined(loop_range, num_stages=2):
                k = loop_range - 1 - kr
                global_token_idx = k * block_N

                logical_page_idx = global_token_idx // block_size
                page_offset = global_token_idx % block_size
                physical_page_id = block_table[bx, logical_page_idx]

                T.copy(KV[physical_page_id, page_offset : page_offset + block_N, 0, :dv], KV_shared)
                T.copy(KV[physical_page_id, page_offset : page_offset + block_N, 0, dv:], K_pe_shared)

                T.clear(acc_s)
                T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)

                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                if kr == 0:
                    for i, j in T.Parallel(block_H, block_N):
                        acc_s[i, j] = T.if_then_else(k * block_N + j >= cache_seqlens[bx] - seqlen_q + seq_q_idx + 1, -T.infinity(accum_dtype), acc_s[i, j])

                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_H):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)

                T.reduce_sum(acc_s, scores_sum, dim=1)
                T.copy(acc_s, S_shared)
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                for i, j in T.Parallel(block_H, dv):
                    acc_o[i, j] *= scores_scale[i]

                T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)

            for i, j in T.Parallel(block_H, dv):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bx, seq_q_idx, head_block_idx * VALID_BLOCK_H : (head_block_idx + 1) * VALID_BLOCK_H, :])

    if num_split > 1:
        return main_split
    else:
        return main_no_split


def get_splits(
    batch_size: int,
    num_heads: int,
    seqlen_q: int,
    avg_seqlen_k: int,
    block_size_h: int = 64,
    block_size_n: int = 128,
    streaming_info: Optional[List[int]] = None,
):
    """
    Calculate the smallest static num_splits that saturates the GPU
    without incurring unnecessary split-reduction overhead.
    """
    device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)
    num_sms = props.multi_processor_count

    # Aim for roughly two CTAs per SM.
    target_blocks_per_sm = 2
    target_total_blocks = num_sms * target_blocks_per_sm

    # With streaming (sink + local) attention, only a bounded window of keys is
    # visited, which caps how much splitting can help.
    effective_seqlen_k = avg_seqlen_k
    if streaming_info is not None and len(streaming_info) >= 2:
        sink_block_num = streaming_info[0]
        local_block_num = streaming_info[1]
        effective_seqlen_k = min(avg_seqlen_k, (sink_block_num + local_block_num) * block_size_n)

    # Parallelism already available without splitting.
    num_head_blocks = (num_heads + block_size_h - 1) // block_size_h
    natural_blocks = batch_size * seqlen_q * num_head_blocks

    if natural_blocks >= target_total_blocks:
        return 1

    needed_splits = (target_total_blocks + natural_blocks - 1) // natural_blocks

    # Each split must cover at least one KV tile, so the (effective) sequence
    # length bounds the number of useful splits.
    min_tokens_per_tile = block_size_n
    max_splits_possible = max(1, effective_seqlen_k // min_tokens_per_tile)

    optimal_split = min(needed_splits, max_splits_possible)

    # Round down to a power of two.
    if optimal_split > 1:
        optimal_split = 2 ** math.floor(math.log2(optimal_split))

    return int(max(1, optimal_split))
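

# Worked example (the SM count is an assumption; at runtime it comes from
# torch.cuda.get_device_properties): with 132 SMs, batch_size=4, num_heads=128,
# seqlen_q=1, avg_seqlen_k=8192, block_size_h=64 and block_size_n=128:
#   natural_blocks = 4 * 1 * ceil(128 / 64) = 8, target = 132 * 2 = 264
#   needed_splits = ceil(264 / 8) = 33, capped at 8192 // 128 = 64
#   rounded down to a power of two -> get_splits(...) == 32

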
def flash_mla_with_kvcache(
    q: torch.Tensor,
    kv: torch.Tensor,
    cache_seqlens: torch.Tensor,
    block_table: torch.Tensor,
    dim_nope: int,
    softmax_scale: float,
    causal: bool,
):
    assert causal, "only causal attention is supported"

    batch_size, seqlen_q, num_heads_q, head_dim_qk = q.shape
    num_pages, page_size, num_heads_kv, head_dim_kv = kv.shape
    assert num_heads_kv == 1
    # The paged cache stores the full nope + rope vector per token.
    assert head_dim_qk == head_dim_kv

    block_H = 64
    block_N = 64
    num_splits = get_splits(
        batch_size, num_heads_q, seqlen_q, torch.mean(cache_seqlens.float()).int().item(), block_H, block_N, None
    )

    # Scratch buffers for the split path: per-split log-sum-exp and partial outputs.
    glse = torch.empty(batch_size, seqlen_q, num_heads_q, num_splits, dtype=q.dtype, device=q.device)
    out_partial = torch.empty(batch_size, seqlen_q, num_heads_q, num_splits, dim_nope, dtype=q.dtype, device=q.device)

    kernel = flash_mla_with_kvcache_kernel(
        batch_size, seqlen_q, num_heads_q, num_heads_kv, dim_nope, head_dim_qk - dim_nope,
        block_N, block_H, num_splits, page_size, num_pages, block_table.shape[1], softmax_scale,
    )
    return kernel(q, kv, block_table, cache_seqlens, glse, out_partial)
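

# Usage sketch: the shapes below are assumptions chosen to match the layout the
# kernel expects (a dim_nope latent part plus a rope tail, one shared KV head, and
# a paged cache addressed through block_table). It is illustrative and is not
# called anywhere in this module.
def _example_flash_mla_with_kvcache():
    batch, seqlen_q, heads = 4, 1, 128
    dim_nope, dim_rope, page_size = 512, 64, 64
    max_pages_per_seq = 128
    num_pages = batch * max_pages_per_seq
    q = torch.randn(batch, seqlen_q, heads, dim_nope + dim_rope,
                    dtype=torch.bfloat16, device="cuda")
    kv = torch.randn(num_pages, page_size, 1, dim_nope + dim_rope,
                     dtype=torch.bfloat16, device="cuda")
    # Identity page mapping: logical page p of sequence b lives at physical page
    # b * max_pages_per_seq + p.
    block_table = torch.arange(num_pages, dtype=torch.int32,
                               device="cuda").reshape(batch, max_pages_per_seq)
    cache_seqlens = torch.full((batch,), 4096, dtype=torch.int32, device="cuda")
    return flash_mla_with_kvcache(
        q, kv, cache_seqlens, block_table, dim_nope,
        softmax_scale=(dim_nope + dim_rope) ** -0.5, causal=True,
    )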
|
|
|