import math
import torch
def get_freqs(context_len, hidden_dim):
assert hidden_dim % 2 == 0
base = 10000.0
theta = 1 / base ** (torch.arange(0, hidden_dim, 2, dtype=torch.float32) / hidden_dim)
seq = torch.arange(context_len, dtype=torch.float32)
m_theta = torch.outer(seq, theta) # seqlen hidden_dim // 2
full_dim_theta = torch.cat([m_theta, m_theta], dim=-1) # seqlen hidden_dim
return full_dim_theta
def apply_sujianlin(q, freqs):
# freqs 的 shape 是 seqlen hidden_dim。所以 q 的后两个维度也应是 seqlen hidden_dim。
# 苏剑林原版 rope 实现。
q_for_cos = q
shape = q.shape
x1, x2 = torch.chunk(q.reshape(*shape[:-1], -1, 2), 2, dim=-1)
q_for_sin = torch.cat([-x2, x1], dim=-1).reshape(shape)
return q_for_cos * torch.cos(freqs) + q_for_sin * torch.sin(freqs)
def apply_megatron(q, freqs):
# freqs 的 shape 是 seqlen hidden_dim。所以 q 的后两个维度也应是 seqlen hidden_dim。
# megatron 实现,跟原版 rope 有点差别,但计算量更少一些,且仍然满足相对位置编码。
q_for_cos = q
x1, x2 = torch.chunk(q, 2, dim=-1)
q_for_sin = torch.cat((-x2, x1), dim=-1)
return q_for_cos * torch.cos(freqs) + q_for_sin * torch.sin(freqs)
def rope(Q: torch.Tensor, max_seqlen, prev_seqlen):
batch_size, num_head, seqlen, head_dim = Q.shape
freqs = get_freqs(max_seqlen, head_dim)[prev_seqlen : prev_seqlen + seqlen]
return apply_megatron(Q, freqs)
# hidden_state: batch seqlen hidden
# Wq: hidden num_head*head_dim_k
# Wk: hidden num_kv_head*head_dim_k
# Wv: hidden num_kv_head*head_dim_v
# Wo: num_head*head_dim_v hidden
# K_cache: batch num_kv_head seqlen head_dim_k
# V_cache: batch num_kv_head seqlen head_dim_v
# return: batch seqlen hidden
def multi_head_attention(
hidden_state: torch.Tensor,
max_seqlen: int,
Wq: torch.Tensor,
Wk: torch.Tensor,
Wv: torch.Tensor,
Wo: torch.Tensor,
num_head: int,
num_kv_head: int,
K_cache: torch.Tensor | None,
V_cache: torch.Tensor | None,
):
batch_size, seqlen, hidden_size = hidden_state.shape
head_dim_q, rq = divmod(Wq.shape[1], num_head)
head_dim_k, rk = divmod(Wk.shape[1], num_kv_head)
head_dim_v, rv = divmod(Wv.shape[1], num_kv_head)
assert num_head % num_kv_head == 0
assert Wq.shape[0] == Wk.shape[0] == Wv.shape[0] == Wo.shape[1] == hidden_size
assert head_dim_q == head_dim_k and rq == 0 and rk == 0 and rv == 0
Q = hidden_state @ Wq # batch seqlen num_head*head_dim_k
K = hidden_state @ Wk # batch seqlen num_kv_head*head_dim_k
V = hidden_state @ Wv # batch seqlen num_kv_head*head_dim_v
# 这里转置一下是为了把 num_head 提前,让 seqlen x head_dim 参与矩阵乘。
# 相当于 batch_size x num_head 个 seqlen x head_dim 的矩阵。
# batch num_head seqlen head_dim_k
Q = Q.reshape(batch_size, seqlen, num_head, head_dim_k).transpose(1, 2)
# batch num_kv_head seqlen head_dim_k
K = K.reshape(batch_size, seqlen, num_kv_head, head_dim_k).transpose(1, 2)
# batch num_kv_head seqlen head_dim_v
V = V.reshape(batch_size, seqlen, num_kv_head, head_dim_v).transpose(1, 2)
# 如果有 KV cache,就需要注意 rope 时,当前的 token 是第几个 token。
Q = rope(Q, max_seqlen, prev_seqlen=K_cache.shape[-2] if K_cache is not None else 0)
K = rope(K, max_seqlen, prev_seqlen=K_cache.shape[-2] if K_cache is not None else 0)
# 更新 KV_cache
new_K_cache = K if K_cache is None else torch.cat([K_cache, K], dim=-2)
new_V_cache = V if V_cache is None else torch.cat([V_cache, V], dim=-2)
# 包含历史 token,到目前为止的 seqlen
seqlen_so_far = new_K_cache.shape[-2]
# ################### 计算 attention ###################
# GQA: GQA groups 跟 num_key_value_heads 是一个意思。
# 如果 num_kv_head == num_head,GQA 就是 MHA。如果 num_kv_head == 1,那就是 MQA。
# 如果 num_kv_head < num_head,那必须满足 num_head 是 num_kv_head 的整数倍。
# 这样才能让 K 和 V 在 grouped 维度去广播。
# 因为要广播,所以 GQA 不影响计算量,只是会做参数共享,减少显存占用。
# 要实现广播,简单的方法就是用 expand,不会拷贝数据,自动在指定维度广播。
if num_kv_head < num_head:
num_q_head_per_group = num_head // num_kv_head
Q = Q.view(batch_size, num_kv_head, num_q_head_per_group, seqlen, head_dim_k)
# 如果有 KV cache,K 和 V 可以先合并再 expand。反过来的话浪费显存。
# 合并后刚好就是新的 new_K_cache / new_V_cache,所以直接拿过来用了。
# 记得调整 attention mask。
# 注意到不管有没有 KV cache,都可以用 new_K_cache / new_V_cache,来表示当前需要参与计算的 K 和 V,所以可以直接用。
K = new_K_cache.unsqueeze(2).expand(
batch_size, num_kv_head, num_q_head_per_group, seqlen_so_far, head_dim_k
)
V = new_V_cache.unsqueeze(2).expand(
batch_size, num_kv_head, num_q_head_per_group, seqlen_so_far, head_dim_v
)
score = Q @ K.transpose(-1, -2) / math.sqrt(head_dim_k)
# MHA 时,score 是 batch num_head seqlen seqlen
# GQA 时,score 是 batch num_kv_head num_q_head_per_group seqlen seqlen
# 对于计算出来 batch_size x num_head 个 seqlen x seqlen 的 score 矩阵
# 每一个 seqlen x seqlen 的行,对应一个 Q 的行,表示这一行 Q 代表的 token 对其他 token 的注意力。
# 所以,为了不让 token 看到未来的 token,需要把上三角(而不是下三角),的部分置为负无穷,不包括对角线,这样 softmax 之后变成 0
# 注意,token 能看到自身。所以不要把对角线 mask 掉。
# 一般 attention_mask 里面用 1 表示参与计算的位置,所以需要一个下三角的 mask,并且对角线为 1
# torch.tril(tensor) 函数保留 tensor 的下三角,以及对角线。
# diagonal 表示对角线的偏移量,0 表示不偏移,保留主对角线,并不是把对角线置为 0 的意思。
# 还得考虑启用 KV cache 时,attention_mask 不是方阵,需要靠调整 diagonal 来移动对角线。
# 调整 diagonal 是需要保证 mask 的右下角那个数字是 1,并且它的上方全 0。
attention_mask = torch.tril(torch.ones(seqlen, seqlen_so_far), diagonal=seqlen_so_far - seqlen)
score.masked_fill_(attention_mask == 0, -torch.inf) # inplace 操作
score = score.softmax(dim=-1)
after_qkv = score @ V
# MHA 时,after_qkv 是 batch num_head seqlen head_dim_v
# GQA 时,after_qkv 是 batch num_kv_head num_q_head_per_group seqlen head_dim_v
# 统一一下 MHA 和 GQA 的输出 shape
after_qkv = after_qkv.reshape(batch_size, num_head, seqlen, head_dim_v)
# ################### attention 结束 ###################
# 把 seqlen 转到前面,把 num_head 和 head_dim_v 合并到一起,也就是把所有 head 的 score 拼在一起。
# batch seqlen num_head*head_dim_v
after_qkv = after_qkv.transpose(1, 2).reshape(batch_size, seqlen, -1)
# 这里的 after_qkv 可以看作是把所有 num_head 个 head 的 score 拼在一起了。
# 所以 Wo 也可以看作是做了 num_head 次 projection,并且把每个 head 的 projection 结果合并起来了。
# 实际操作中,没必要区分 Wo 的 num_head,直接整体矩阵乘,就是了。
attn_output = after_qkv @ Wo # batch seqlen hidden
return attn_output, new_K_cache, new_V_cache # batch seqlen hidden
if __name__ == "__main__":
batch_size = 1
seqlen = 64
hidden = 32
num_head = 4
num_kv_head = 2
head_dim_k = 4
head_dim_v = 12
# 这里的 head_dim_k 和 head_dim_v 不一定要相等,也不一定要等于 hidden_size / num_head。
# 事实上可以随便取。
# megatron 里面是让他俩相等,而且让 head_dim_k * num_head 等于 hidden
# 但是 head_dim_q 一定要等于 head_dim_k,不然 Q @ K.T 不能计算
hidden_state = torch.randn(batch_size, seqlen, hidden)
Wq = torch.randn(hidden, num_head * head_dim_k)
Wk = torch.randn(hidden, num_kv_head * head_dim_k)
Wv = torch.randn(hidden, num_kv_head * head_dim_v)
Wo = torch.randn(num_head * head_dim_v, hidden)
hidden_state1, hidden_state2 = hidden_state.split([seqlen - 1, 1], dim=-2)
# 测试 KV cache。先把前 seqlen - 1 个token 传入,获取 KV cache。
attn_output1, K_cache, V_cache = multi_head_attention(
hidden_state1, seqlen, Wq, Wk, Wv, Wo, num_head, num_kv_head, None, None
)
# 再把最后一个 token 和 KV cache 传入。
attn_output2, _, _ = multi_head_attention(
hidden_state2, seqlen, Wq, Wk, Wv, Wo, num_head, num_kv_head, K_cache, V_cache
)
print(attn_output2.shape)
print(attn_output2)
# 再整体计算。看最后一个 token 的输出是否一致。
attn_output, _, _ = multi_head_attention(
hidden_state, seqlen, Wq, Wk, Wv, Wo, num_head, num_kv_head, None, None
)
print(attn_output[:, [-1], :].shape)
print(attn_output[:, -1, :])
print("max diff:", (attn_output2 - attn_output[:, [-1], :]).abs().max())