import math
from typing import Any

import torch
from einops import rearrange
from diffusers.models.attention_processor import Attention

EPSILON = 1e-6


# Memory-efficient (chunked) attention in the style of FlashAttention: the
# softmax is accumulated block by block with running row maxima and row sums,
# so the full (n_q x n_k) score matrix is never materialized at once.
class FlashAttentionFunction(torch.autograd.Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """Algorithm 2 (forward pass) in the FlashAttention paper."""
        device = q.device
        dtype = q.dtype
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
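
        # Running output plus per-row softmax statistics (running sum and max),
        # updated online as each key/value chunk is processed.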
        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
        all_row_maxes = torch.full(
            (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
        )

        scale = q.shape[-1] ** -0.5
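
        # The optional (b, n) boolean mask is reshaped for broadcasting and split
        # into chunks; without one, a tuple of Nones keeps the zip below aligned.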
        if mask is None:
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = rearrange(mask, "b n -> b 1 1 n")
            mask = mask.split(q_bucket_size, dim=-1)
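
        # Outer loop: walk over query chunks (rows of the attention matrix)
        # together with their slices of the output and running statistics.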
        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff
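
            # Inner loop: walk over key/value chunks (columns) for this query chunk.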
            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size
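
                # Scaled dot-product scores for this (query chunk, key chunk) tile.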
                attn_weights = (
                    torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
                )

                if row_mask is not None:
                    attn_weights.masked_fill_(~row_mask, max_neg_value)
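
                # Mask out future positions when causal; tiles that lie entirely
                # in the past skip this step.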
                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)
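
                # Softmax numerator for this tile, stabilized by the tile's row maxima.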
                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                if row_mask is not None:
                    exp_weights.masked_fill_(~row_mask, 0.0)

                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
                    min=EPSILON
                )
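
                # Online softmax update: rescale the previously accumulated output
                # and row sums to the new running maxima, then fold in this tile.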
                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_values = torch.einsum(
                    "... i j, ... j d -> ... i d", exp_weights, vc
                )

                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                new_row_sums = (
                    exp_row_max_diff * row_sums
                    + exp_block_row_max_diff * block_row_sums
                )

                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
                    (exp_block_row_max_diff / new_row_sums) * exp_values
                )

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)
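
        # Stash everything the backward pass needs to recompute attention tile by tile.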
        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """Algorithm 4 (backward pass) in the FlashAttention paper."""
        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l, m = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
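
        # Gradient accumulators for q, k and v, filled tile by tile below.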
        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            do.split(q_bucket_size, dim=-2),
            mask,
            l.split(q_bucket_size, dim=-2),
            m.split(q_bucket_size, dim=-2),
            dq.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                dk.split(k_bucket_size, dim=-2),
                dv.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = (
                    torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
                )

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)
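
                # Recompute this tile's softmax probabilities from the saved row
                # maxima (m) and row sums (l) instead of storing the full matrix.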
                exp_attn_weights = torch.exp(attn_weights - mc)

                if row_mask is not None:
                    exp_attn_weights.masked_fill_(~row_mask, 0.0)

                p = exp_attn_weights / lc
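
                # Tile-local gradients: dV from the probabilities, then dS via the
                # softmax Jacobian (D is the row-wise dot product of dO and O),
                # and finally dQ and dK from dS.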
                dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
                dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)

                D = (doc * oc).sum(dim=-1, keepdims=True)
                ds = p * scale * (dp - D)

                dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
                dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None
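

# A diffusers attention processor: projects q/k/v with the Attention module's
# layers and runs them through the chunked FlashAttentionFunction above.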
class FlashAttnProcessor:
    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ) -> Any:
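        # Query/key chunk sizes: larger buckets use more memory but need fewer passes.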
        q_bucket_size = 512
        k_bucket_size = 1024

        h = attn.heads
        q = attn.to_q(hidden_states)

        encoder_hidden_states = (
            encoder_hidden_states
            if encoder_hidden_states is not None
            else hidden_states
        )
        encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)
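
        # An attached hypernetwork, if present, rewrites the key/value context
        # before the k/v projections are applied.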
        if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
            context_k, context_v = attn.hypernetwork.forward(
                hidden_states, encoder_hidden_states
            )
            context_k = context_k.to(hidden_states.dtype)
            context_v = context_v.to(hidden_states.dtype)
        else:
            context_k = encoder_hidden_states
            context_v = encoder_hidden_states

        k = attn.to_k(context_k)
        v = attn.to_v(context_v)
        del encoder_hidden_states, hidden_states
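
        # Split heads: (b, n, h * d) -> (b, h, n, d).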
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
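
        # Non-causal attention (causal=False): every query may attend to every key.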
        out = FlashAttentionFunction.apply(
            q, k, v, attention_mask, False, q_bucket_size, k_bucket_size
        )

        # Merge heads back, then apply the output projection and dropout.
        out = rearrange(out, "b h n d -> b n (h d)")

        out = attn.to_out[0](out)
        out = attn.to_out[1](out)
        return out
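

# Example usage (sketch): swapping the processor into a diffusers model.
# `Attention.set_processor` and `UNet2DConditionModel.set_attn_processor` are the
# standard diffusers hooks for this; the `unet` name below is assumed, not defined here.
#
#   unet.set_attn_processor(FlashAttnProcessor())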