参考:https://zhuanlan.zhihu.com/p/669926191

import torch

N = 32  # seqlen
d = 128  # hidden size

Q = torch.randn(N, d)
K = torch.randn(N, d)
V = torch.randn(N, d)

def fa(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
    T = 4  # 分块数。假设 Q 跟 KV 的分块数一样。
    Qs = Q.reshape(T, N // T, d)  # 模拟对一个 (N, d) 的矩阵分块。下同。
    Ks = K.reshape(T, N // T, d)
    Vs = V.reshape(T, N // T, d)

    Os = torch.zeros(T, N // T, d)
    L = torch.zeros(T, N // T, 1)  # 行之和
    M = torch.full((T, N // T, 1), -torch.inf)  # 行之最大值

    for j in range(T):
        Kj, Vj = Ks[j], Vs[j]  # load to SRAM
        for i in range(T):
            Qi, Oi, Li, Mi = Qs[i], Os[i], L[i], M[i]  # load to SRAM

            Sij = Qi @ Kj.T
            mij = torch.max(Sij, dim=-1, keepdim=True).values
            Pij = torch.exp(Sij - mij)
            Lij = torch.sum(Pij, dim=-1, keepdim=True)

            Mi_new = torch.max(Mi, mij)
            Li_new = torch.exp(Mi - Mi_new) * Li + torch.exp(mij - Mi_new) * Lij

            # update
            Os[i] = (torch.exp(Mi - Mi_new) * Oi * Li + torch.exp(mij - Mi_new) * Pij @ Vj) / Li_new
            L[i] = Li_new
            M[i] = Mi_new

    return Os.reshape(N, d)

print(a := fa(Q, K, V))
print(b := torch.softmax(Q @ K.T, dim=-1) @ V)
print((a - b).abs().max())

image.png

为什么把 K V 放到外层循环?

FA2 优化点:

参考:https://zhuanlan.zhihu.com/p/691067658

ring attention

image.png

核心观察:

这就有了 ring attention。

所谓 ring,就是要把当前计算的 Ki 和 Vi 在不同的卡之间流转起来,像一个环。并且这个通信过程可以跟计算过程 overlap 起来。

跟 flash attention v2 的区别在于,fa2 的 Q 是外层循环,而 ring attention 的 Q 在不同卡上。