# -*- coding: utf-8 -*-
"""FlashMLA attention kernels written in TileLang.

Contains:
  * ``flash_mla_varlen_func_kernel`` / ``flash_mla_varlen_func`` — variable-length
    (unpadded, cu_seqlens-indexed) causal attention for prefill.
  * ``flash_mla_with_kvcache_kernel`` / ``flash_mla_with_kvcache`` — paged-KV-cache
    decode attention with optional split-KV (flash-decoding style) reduction.
  * ``get_splits`` — heuristic choosing a static number of KV splits.
"""
import math
from typing import Optional, List

import torch

import tilelang
import tilelang.language as T


@tilelang.jit(
    out_idx=[6],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flash_mla_varlen_func_kernel(UQ, UKV, heads, dim_qk, dim_vo, softmax_scale,
                                 is_causal, block_M=64, block_N=64,
                                 num_stages=1, threads=128):
    """Build a varlen (unpadded) attention kernel.

    Args:
        UQ: total number of query tokens across the batch (nnz_q).
        UKV: total number of key/value tokens across the batch (nnz_kv).
        heads: number of attention heads (Q and KV head counts are equal here).
        dim_qk: head dimension of Q/K.
        dim_vo: head dimension of V/output.
        softmax_scale: scaling applied to QK^T logits before softmax.
        is_causal: apply a right-aligned causal mask when True.
        block_M / block_N: tile sizes along the query / key axes.
        num_stages: software-pipelining depth of the KV loop.
        threads: threads per CTA.

    Returns:
        A TileLang prim_func; output tensor is allocated by the JIT wrapper
        (``out_idx=[6]``).
    """
    batch_size = T.dynamic("batch_size")
    # Fold log2(e) into the scale so softmax can use exp2 (see comment in the loop).
    scale = softmax_scale * 1.44269504  # log2(e)
    q_shape = [UQ, heads, dim_qk]
    k_shape = [UKV, heads, dim_qk]
    v_shape = [UKV, heads, dim_vo]
    o_shape = [UQ, heads, dim_vo]
    dtype = T.bfloat16
    accum_dtype = T.float32

    @T.prim_func
    def main(
        Q_unpad: T.Tensor(q_shape, dtype),
        K_unpad: T.Tensor(k_shape, dtype),
        V_unpad: T.Tensor(v_shape, dtype),
        cu_seqlens_q: T.Tensor([batch_size + 1], T.int32),
        cu_seqlens_k: T.Tensor([batch_size + 1], T.int32),
        max_seqlen_q: T.int32,
        Output_unpad: T.Tensor(o_shape, dtype),
    ):
        # Grid: (query tiles, heads, batch).
        with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size,
                      threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
            K_shared = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_N, dim_vo], dtype)
            O_shared = T.alloc_shared([block_M, dim_vo], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim_vo], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            batch_idx = bz
            head_idx = by
            # Per-sequence token ranges from the cumulative-length prefix sums.
            q_start_idx = cu_seqlens_q[batch_idx]
            kv_start_idx = cu_seqlens_k[batch_idx]
            q_end_idx = cu_seqlens_q[batch_idx + 1]
            kv_end_idx = cu_seqlens_k[batch_idx + 1]
            q_current_seqlen = q_end_idx - q_start_idx
            kv_current_seqlen = kv_end_idx - kv_start_idx

            T.copy(
                Q_unpad[q_start_idx + bx * block_M : q_start_idx + bx * block_M + block_M,
                        head_idx, :],
                Q_shared
            )  # OOB positions will be handled below
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            offset = kv_current_seqlen - q_current_seqlen  # always align on the right
            # Causal: only KV tiles up to the diagonal of the last row of this Q tile.
            loop_range = (
                T.min(T.ceildiv(offset + (bx + 1) * block_M, block_N),
                      T.ceildiv(kv_current_seqlen, block_N))
                if is_causal else T.ceildiv(kv_current_seqlen, block_N)
            )

            for k in T.Pipelined(loop_range, num_stages=num_stages):
                # Q * K
                T.copy(
                    K_unpad[kv_start_idx + k * block_N : kv_start_idx + k * block_N + block_N,
                            head_idx, :],
                    K_shared
                )  # OOB positions will be handled below
                # Pre-bias acc_s with a large negative value on masked / OOB
                # positions; the subsequent GEMM accumulates logits on top.
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(
                            (bx * block_M + i + offset < k * block_N + j)
                            or (bx * block_M + i >= q_current_seqlen
                                or k * block_N + j >= kv_current_seqlen),
                            -1e9,
                            0,
                        )
                else:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(
                            (bx * block_M + i >= q_current_seqlen
                             or k * block_N + j >= kv_current_seqlen),
                            -1e9, 0
                        )
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True,
                       policy=T.GemmWarpPolicy.FullRow)

                # Softmax (online, FlashAttention-style running max/sum).
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                # To do causal softmax, we need to set the scores_max to 0 if it is -inf
                # This process is called Check_inf in FlashAttention3 code, and it only need to be done
                # in the first ceil_div(kBlockM, kBlockN) steps.
                # for i in T.Parallel(block_M):
                #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, block_N):
                    # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
                    # max * log_2(e)) This allows the compiler to use the ffma
                    # instruction instead of fadd and fmul separately.
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)

                # Rescale previously accumulated output by the max correction.
                for i, j in T.Parallel(block_M, dim_vo):
                    acc_o[i, j] *= scores_scale[i]

                # V * softmax(Q * K)
                T.copy(
                    V_unpad[kv_start_idx + k * block_N : kv_start_idx + k * block_N + block_N,
                            head_idx, :],
                    V_shared
                )  # OOB positions' weights are 0
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

            for i, j in T.Parallel(block_M, dim_vo):
                # When sq > skv, some tokens can see nothing
                acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i]
            T.copy(acc_o, O_shared)
            # Guarded store: only rows inside this sequence are written back.
            for i, d in T.Parallel(block_M, dim_vo):
                if bx * block_M + i < q_current_seqlen:
                    Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d]

    return main


def flash_mla_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: float,
    causal: bool,
):
    """Varlen attention entry point (prefill).

    ``q``/``k`` are [nnz, heads, dim_qk]; ``v`` is [nnz_kv, heads, dim_vo].
    Only causal attention is supported. Returns the unpadded output tensor
    [nnz_q, heads, dim_vo] allocated by the JIT wrapper.
    """
    assert causal, "only causal attention is supported"
    nnz_qo, num_heads_q, head_dim_qk = q.shape
    # nnz_kv, num_heads_kv, head_dim_qk = v.shape
    nnz_kv, num_heads_kv, head_dim_vo = v.shape
    assert num_heads_q == num_heads_kv
    kernel = flash_mla_varlen_func_kernel(
        nnz_qo, nnz_kv, num_heads_q, head_dim_qk, head_dim_vo, softmax_scale,
        True, block_M=128, block_N=128, num_stages=1, threads=256
    )
    return kernel(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)


# NOTE(review): these re-imports are redundant — `import tilelang as T` is
# immediately shadowed by the next line, restoring T = tilelang.language
# (the same binding established at the top of the file). Kept for safety.
import tilelang as T
import tilelang.language as T


@tilelang.jit(
    out_idx=[6],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flash_mla_with_kvcache_kernel(
    batch, seqlen_q, h_q, h_kv, dv, dpe, block_N, block_H, num_split,
    block_size, num_pages, max_num_blocks_per_seq, softmax_scale=None
):
    """Build a paged-KV-cache decode kernel (MLA: latent dim ``dv`` + RoPE dim ``dpe``).

    Returns ``main_split`` (flash-decoding style: per-split partials + LSE,
    then a combine kernel) when ``num_split > 1``, else ``main_no_split``.
    The output tensor is allocated by the JIT wrapper (``out_idx=[6]``).
    """
    if softmax_scale is None:
        softmax_scale = (dv + dpe) ** -0.5
    scale = float(softmax_scale * 1.44269504)  # log2(e)
    dtype = T.bfloat16
    accum_dtype = T.float32

    # Enforce constraints for this specific kernel version
    assert h_kv == 1, "h_kv must be 1"
    # kv_group_num equals h_q when h_kv is 1
    kv_group_num = h_q
    VALID_BLOCK_H = min(block_H, kv_group_num)
    VALID_BLOCK_H = max(VALID_BLOCK_H, 16)
    block_H = VALID_BLOCK_H
    assert block_size >= block_N and block_size % block_N == 0, \
        "block_size must be larger than block_N and a multiple of block_N"

    # Number of head-tiles per query position; both kernels decode (seq_q_idx,
    # head_block_idx) from the flattened grid index with this factor, so the
    # launched grid must be exactly seqlen_q * num_head_blocks (see below).
    num_head_blocks = (h_q + VALID_BLOCK_H - 1) // VALID_BLOCK_H

    @T.prim_func
    def main_split(
        Q: T.Tensor([batch, seqlen_q, h_q, dv + dpe], dtype),
        # Paged KV: [num_pages, page_size, heads=1, dim]
        KV: T.Tensor([num_pages, block_size, 1, dv + dpe], dtype),
        block_table: T.Tensor([batch, max_num_blocks_per_seq], T.int32),
        cache_seqlens: T.Tensor([batch], T.int32),
        glse: T.Tensor([batch, seqlen_q, h_q, num_split], dtype),
        Output_partial: T.Tensor([batch, seqlen_q, h_q, num_split, dv], dtype),
        Output: T.Tensor([batch, seqlen_q, h_q, dv], dtype),
    ):
        # Grid: batch, split_Q_heads, split_seq
        # Since h_kv=1, all Q heads within the tile attend to the same KV head (0)
        # FIX: grid uses seqlen_q * ceildiv(h_q, VALID_BLOCK_H) with explicit
        # parentheses — the original `seqlen_q * (h_q + VBH - 1) // VBH` bound
        # the multiply before the floor-div, which matches the index decode
        # below only when h_q is a multiple of VALID_BLOCK_H.
        with T.Kernel(batch, seqlen_q * num_head_blocks, num_split,
                      threads=256) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_H, dv], dtype)
            S_shared = T.alloc_shared([block_H, block_N], dtype)
            Q_pe_shared = T.alloc_shared([block_H, dpe], dtype)
            KV_shared = T.alloc_shared([block_N, dv], dtype)
            K_pe_shared = T.alloc_shared([block_N, dpe], dtype)
            O_shared = T.alloc_shared([block_H, dv], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o = T.alloc_fragment([block_H, dv], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)

            seq_q_head_block_idx = by
            seq_q_idx = T.floordiv(seq_q_head_block_idx, num_head_blocks)
            head_block_idx = T.floormod(seq_q_head_block_idx, num_head_blocks)
            T.use_swizzle(10)

            # Split Q into the "nope" (latent) part and the RoPE part.
            T.copy(Q[bx, seq_q_idx,
                     head_block_idx * VALID_BLOCK_H : (head_block_idx + 1) * VALID_BLOCK_H,
                     :dv], Q_shared)
            T.copy(Q[bx, seq_q_idx,
                     head_block_idx * VALID_BLOCK_H : (head_block_idx + 1) * VALID_BLOCK_H,
                     dv:], Q_pe_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            # Causal horizon for this query position, divided across splits;
            # the first `remaining_blocks` splits take one extra KV tile.
            total_blocks = T.ceildiv(cache_seqlens[bx] - seqlen_q + seq_q_idx + 1, block_N)
            blocks_per_split = T.floordiv(total_blocks, num_split)
            remaining_blocks = T.floormod(total_blocks, num_split)
            loop_range = blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0)
            start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N

            for k in T.Pipelined(loop_range, num_stages=2):
                # Translate the logical token index through the page table.
                global_token_idx = start + k * block_N
                logical_page_idx = global_token_idx // block_size
                page_offset = global_token_idx % block_size
                physical_page_id = block_table[bx, logical_page_idx]
                # KV Head is fixed to 0
                T.copy(KV[physical_page_id, page_offset : page_offset + block_N, 0, :dv],
                       KV_shared)
                T.copy(KV[physical_page_id, page_offset : page_offset + block_N, 0, dv:],
                       K_pe_shared)
                T.clear(acc_s)
                # Logits = Q_nope @ K_nope^T + Q_pe @ K_pe^T.
                T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True,
                       policy=T.GemmWarpPolicy.FullCol)
                T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True,
                       policy=T.GemmWarpPolicy.FullCol)
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                # Mask tokens beyond the causal horizon of this query position.
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.if_then_else(
                        start + k * block_N + j >= cache_seqlens[bx] - seqlen_q + seq_q_idx + 1,
                        -T.infinity(accum_dtype), acc_s[i, j])
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_H):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                T.copy(acc_s, S_shared)
                T.copy(S_shared, acc_s_cast)
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                for i, j in T.Parallel(block_H, dv):
                    acc_o[i, j] *= scores_scale[i]
                # MLA: V is the first dv columns of KV (shared latent).
                T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)

            for i, j in T.Parallel(block_H, dv):
                acc_o[i, j] /= logsum[i]
            # Store the per-split log-sum-exp (base 2) for the combine pass.
            for i in T.Parallel(block_H):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, glse[bx, seq_q_idx,
                                head_block_idx * VALID_BLOCK_H : (head_block_idx + 1) * VALID_BLOCK_H,
                                bz])
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output_partial[bx, seq_q_idx,
                                            head_block_idx * VALID_BLOCK_H : (head_block_idx + 1) * VALID_BLOCK_H,
                                            bz, :])

        # Combine kernel: merge the num_split partial outputs with LSE weights.
        with T.Kernel(seqlen_q * h_q, batch, threads=128) as (by, bz):
            po_local = T.alloc_fragment([dv], dtype)
            o_accum_local = T.alloc_fragment([dv], accum_dtype)
            # lse_local_split = T.alloc_var(accum_dtype)
            # lse_logsum_local = T.alloc_var(accum_dtype)
            # lse_max_local = T.alloc_var(accum_dtype)
            # scale_local = T.alloc_var(accum_dtype)
            lse_local_split = T.alloc_fragment([1], accum_dtype)
            lse_logsum_local = T.alloc_fragment([1], accum_dtype)
            lse_max_local = T.alloc_fragment([1], accum_dtype)
            scale_local = T.alloc_fragment([1], accum_dtype)

            seq_q_head_idx = by
            seq_q_idx = T.floordiv(seq_q_head_idx, h_q)
            head_idx = T.floormod(seq_q_head_idx, h_q)

            lse_logsum_local[0] = 0.0
            T.clear(o_accum_local)
            # Global LSE = log2(sum_k 2^(lse_k - max)) + max, computed stably.
            lse_max_local[0] = -T.infinity(accum_dtype)
            for k in T.serial(num_split):
                lse_max_local[0] = T.max(lse_max_local[0], glse[bz, seq_q_idx, head_idx, k])
            for k in T.Pipelined(num_split, num_stages=1):
                lse_local_split[0] = glse[bz, seq_q_idx, head_idx, k]
                lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
            lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
            for k in T.serial(num_split):
                for i in T.Parallel(dv):
                    po_local[i] = Output_partial[bz, seq_q_idx, head_idx, k, i]
                lse_local_split[0] = glse[bz, seq_q_idx, head_idx, k]
                scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
                for i in T.Parallel(dv):
                    o_accum_local[i] += po_local[i] * scale_local[0]
            for i in T.Parallel(dv):
                Output[bz, seq_q_idx, head_idx, i] = o_accum_local[i]

    @T.prim_func
    def main_no_split(
        Q: T.Tensor([batch, seqlen_q, h_q, dv + dpe], dtype),
        KV: T.Tensor([num_pages, block_size, 1, dv + dpe], dtype),
        block_table: T.Tensor([batch, max_num_blocks_per_seq], T.int32),
        cache_seqlens: T.Tensor([batch], T.int32),
        glse: T.Tensor([batch, seqlen_q, h_q, num_split], dtype),
        Output_partial: T.Tensor([batch, seqlen_q, h_q, num_split, dv], dtype),
        Output: T.Tensor([batch, seqlen_q, h_q, dv], dtype),
    ):
        # FIX: the original grid used `seqlen_q * h_q // VALID_BLOCK_H` (floor
        # division), which drops the tail head-block whenever h_q is not a
        # multiple of VALID_BLOCK_H, leaving those heads unwritten. Use the
        # same ceildiv-based factor the index decode below assumes.
        with T.Kernel(batch, seqlen_q * num_head_blocks, threads=256) as (bx, by):
            Q_shared = T.alloc_shared([block_H, dv], dtype)
            S_shared = T.alloc_shared([block_H, block_N], dtype)
            Q_pe_shared = T.alloc_shared([block_H, dpe], dtype)
            KV_shared = T.alloc_shared([block_N, dv], dtype)
            K_pe_shared = T.alloc_shared([block_N, dpe], dtype)
            O_shared = T.alloc_shared([block_H, dv], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_o = T.alloc_fragment([block_H, dv], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)

            seq_q_head_block_idx = by
            seq_q_idx = T.floordiv(seq_q_head_block_idx, num_head_blocks)
            head_block_idx = T.floormod(seq_q_head_block_idx, num_head_blocks)
            T.use_swizzle(10)

            T.copy(Q[bx, seq_q_idx,
                     head_block_idx * VALID_BLOCK_H : (head_block_idx + 1) * VALID_BLOCK_H,
                     :dv], Q_shared)
            T.copy(Q[bx, seq_q_idx,
                     head_block_idx * VALID_BLOCK_H : (head_block_idx + 1) * VALID_BLOCK_H,
                     dv:], Q_pe_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = T.ceildiv(cache_seqlens[bx] - seqlen_q + seq_q_idx + 1, block_N)
            # Iterate KV tiles back-to-front so only the first iteration
            # (kr == 0, the last, possibly partial tile) needs masking.
            for kr in T.Pipelined(loop_range, num_stages=2):
                k = loop_range - 1 - kr
                global_token_idx = k * block_N
                logical_page_idx = global_token_idx // block_size
                page_offset = global_token_idx % block_size
                physical_page_id = block_table[bx, logical_page_idx]
                # KV Head is fixed to 0
                T.copy(KV[physical_page_id, page_offset : page_offset + block_N, 0, :dv],
                       KV_shared)
                T.copy(KV[physical_page_id, page_offset : page_offset + block_N, 0, dv:],
                       K_pe_shared)
                T.clear(acc_s)
                T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True,
                       policy=T.GemmWarpPolicy.FullCol)
                T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True,
                       policy=T.GemmWarpPolicy.FullCol)
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                if kr == 0:
                    for i, j in T.Parallel(block_H, block_N):
                        acc_s[i, j] = T.if_then_else(
                            k * block_N + j >= cache_seqlens[bx] - seqlen_q + seq_q_idx + 1,
                            -T.infinity(accum_dtype), acc_s[i, j])
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_H):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                T.copy(acc_s, S_shared)
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                for i, j in T.Parallel(block_H, dv):
                    acc_o[i, j] *= scores_scale[i]
                T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)

            for i, j in T.Parallel(block_H, dv):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bx, seq_q_idx,
                                    head_block_idx * VALID_BLOCK_H : (head_block_idx + 1) * VALID_BLOCK_H,
                                    :])

    if num_split > 1:
        return main_split
    else:
        return main_no_split


def get_splits(
    batch_size: int,
    num_heads: int,
    seqlen_q: int,
    avg_seqlen_k: int,
    block_size_h: int = 64,
    block_size_n: int = 128,
    streaming_info: Optional[List[int]] = None,
):
    """
    Calculates the optimal static num_splits to saturate the GPU without
    incurring unnecessary reduction overhead.

    Requires a CUDA device (reads SM count via torch). Returns an int >= 1,
    rounded down to a power of two.
    """
    # 1. Get Device Capabilities
    device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)
    num_sms = props.multi_processor_count

    # 2. Define Saturation Target
    # We usually want 1 to 2 waves of blocks per SM to hide latency.
    # For MLA (memory bound), 1 full wave is often enough, but 2 is safer.
    target_blocks_per_sm = 2
    target_total_blocks = num_sms * target_blocks_per_sm

    # 3. Calculate Effective KV Length (Maybe The "Streaming" Adjustment)
    # The effective length is capped by the cache capacity
    # (Sink + Local) when sequence length exceeds the window.
    effective_seqlen_k = avg_seqlen_k
    if streaming_info is not None and len(streaming_info) >= 2:
        sink_block_num = streaming_info[0]
        local_block_num = streaming_info[1]
        effective_seqlen_k = min(avg_seqlen_k,
                                 (sink_block_num + local_block_num) * block_size_n)

    # 4. Calculate "Natural" Parallelism
    # The number of independent tasks available without splitting KV.
    # Usually: Batch * Heads (since seqlen_q is typically 1 in decoding)
    # ceil_div(num_heads, block_size_h) handles cases where heads are grouped.
    num_head_blocks = (num_heads + block_size_h - 1) // block_size_h
    natural_blocks = batch_size * seqlen_q * num_head_blocks

    # 5. Determine Split Ratio needed to hit target
    if natural_blocks >= target_total_blocks:
        # We already have enough parallelism to saturate the GPU.
        # Splitting further just adds overhead.
        return 1

    # We need to split to create more blocks
    needed_splits = (target_total_blocks + natural_blocks - 1) // natural_blocks

    # 6. Clamp based on Sequence Length (The "Don't shred it" check)
    # We don't want a tile processing fewer than, say, 128 tokens.
    # Otherwise, loop overhead dominates the math.
    min_tokens_per_tile = block_size_n
    max_splits_possible = max(1, effective_seqlen_k // min_tokens_per_tile)
    optimal_split = min(needed_splits, max_splits_possible)

    # 7. (Optional) Power of 2 alignment often helps compiler optimizations
    # Rounds to nearest power of 2 (1, 2, 4, 8, 16...)
    if optimal_split > 1:
        optimal_split = 2 ** math.floor(math.log2(optimal_split))

    return int(max(1, optimal_split))


def flash_mla_with_kvcache(
    q: torch.Tensor,
    kv: torch.Tensor,
    cache_seqlens: torch.Tensor,
    block_table: torch.Tensor,
    dim_nope: int,
    softmax_scale: float,
    causal: bool,
):
    """Paged-KV-cache decode entry point.

    ``q`` is [batch, seqlen_q, heads, dv+dpe]; ``kv`` is
    [num_pages, page_size, 1, dv+dpe]. ``dim_nope`` is the latent (value)
    dimension dv; the remaining columns are the RoPE part. Only causal
    attention is supported. Returns [batch, seqlen_q, heads, dv].
    """
    assert causal, "only causal attention is supported"
    batch_size, seqlen_q, num_heads_q, head_dim_qk = q.shape
    num_pages, page_size, num_heads_kv, head_dim_vo = kv.shape
    assert num_heads_kv == 1
    assert head_dim_qk == head_dim_vo
    block_H = 64
    block_N = 64
    num_splits = get_splits(batch_size, num_heads_q, seqlen_q,
                            torch.mean(cache_seqlens.float()).int().item(),
                            block_H, block_N, None)
    # Scratch buffers for the split-KV path (unused when num_splits == 1).
    glse = torch.empty(batch_size, seqlen_q, num_heads_q, num_splits,
                       dtype=q.dtype, device=q.device)
    out_partial = torch.empty(batch_size, seqlen_q, num_heads_q, num_splits, dim_nope,
                              dtype=q.dtype, device=q.device)
    kernel = flash_mla_with_kvcache_kernel(
        batch_size, seqlen_q, num_heads_q, num_heads_kv, dim_nope,
        head_dim_qk - dim_nope, block_N, block_H, num_splits, page_size,
        num_pages, block_table.shape[1], softmax_scale
    )
    return kernel(q, kv, block_table, cache_seqlens, glse, out_partial)