参考:https://zhuanlan.zhihu.com/p/669926191
import torch
N = 32 # seqlen
d = 128 # hidden size
Q = torch.randn(N, d)
K = torch.randn(N, d)
V = torch.randn(N, d)
def fa(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
T = 4 # 分块数。假设 Q 跟 KV 的分块数一样。
Qs = Q.reshape(T, N // T, d) # 模拟对一个 (N, d) 的矩阵分块。下同。
Ks = K.reshape(T, N // T, d)
Vs = V.reshape(T, N // T, d)
Os = torch.zeros(T, N // T, d)
L = torch.zeros(T, N // T, 1) # 行之和
M = torch.full((T, N // T, 1), -torch.inf) # 行之最大值
for j in range(T):
Kj, Vj = Ks[j], Vs[j] # load to SRAM
for i in range(T):
Qi, Oi, Li, Mi = Qs[i], Os[i], L[i], M[i] # load to SRAM
Sij = Qi @ Kj.T
mij = torch.max(Sij, dim=-1, keepdim=True).values
Pij = torch.exp(Sij - mij)
Lij = torch.sum(Pij, dim=-1, keepdim=True)
Mi_new = torch.max(Mi, mij)
Li_new = torch.exp(Mi - Mi_new) * Li + torch.exp(mij - Mi_new) * Lij
# update
Os[i] = (torch.exp(Mi - Mi_new) * Oi * Li + torch.exp(mij - Mi_new) * Pij @ Vj) / Li_new
L[i] = Li_new
M[i] = Mi_new
return Os.reshape(N, d)
print(a := fa(Q, K, V))
print(b := torch.softmax(Q @ K.T, dim=-1) @ V)
print((a - b).abs().max())
为什么把 K V 放到外层循环?
FA2 优化点:
参考:https://zhuanlan.zhihu.com/p/691067658
核心观察:
这就有了 ring attention。
所谓 ring,就是要把当前计算的 Ki 和 Vi 在不同的卡之间流转起来,像一个环。并且这个通信过程可以跟计算过程 overlap 起来。
跟 flash attention v2 的区别在于,fa2 的 Q 是外层循环,而 ring attention 的 Q 在不同卡上。