import math

import torch

def get_freqs(context_len, hidden_dim):
    assert hidden_dim % 2 == 0
    base = 10000.0
    theta = 1 / base ** (torch.arange(0, hidden_dim, 2, dtype=torch.float32) / hidden_dim)

    seq = torch.arange(context_len, dtype=torch.float32)
    m_theta = torch.outer(seq, theta)  # seqlen hidden_dim // 2

    full_dim_theta = torch.cat([m_theta, m_theta], dim=-1)  # seqlen hidden_dim
    return full_dim_theta

def apply_sujianlin(q, freqs):
    # freqs 的 shape 是 seqlen hidden_dim。所以 q 的后两个维度也应是 seqlen hidden_dim。
    # 苏剑林原版 rope 实现。
    q_for_cos = q
    shape = q.shape
    x1, x2 = torch.chunk(q.reshape(*shape[:-1], -1, 2), 2, dim=-1)
    q_for_sin = torch.cat([-x2, x1], dim=-1).reshape(shape)
    return q_for_cos * torch.cos(freqs) + q_for_sin * torch.sin(freqs)

def apply_megatron(q, freqs):
    # freqs 的 shape 是 seqlen hidden_dim。所以 q 的后两个维度也应是 seqlen hidden_dim。
    # megatron 实现,跟原版 rope 有点差别,但计算量更少一些,且仍然满足相对位置编码。
    q_for_cos = q
    x1, x2 = torch.chunk(q, 2, dim=-1)
    q_for_sin = torch.cat((-x2, x1), dim=-1)
    return q_for_cos * torch.cos(freqs) + q_for_sin * torch.sin(freqs)

def rope(Q: torch.Tensor, max_seqlen, prev_seqlen):
    batch_size, num_head, seqlen, head_dim = Q.shape
    freqs = get_freqs(max_seqlen, head_dim)[prev_seqlen : prev_seqlen + seqlen]
    return apply_megatron(Q, freqs)

# hidden_state: batch seqlen hidden
# Wq: hidden num_head*head_dim_k
# Wk: hidden num_kv_head*head_dim_k
# Wv: hidden num_kv_head*head_dim_v
# Wo: num_head*head_dim_v hidden
# K_cache: batch num_kv_head seqlen head_dim_k
# V_cache: batch num_kv_head seqlen head_dim_v
# return: batch seqlen hidden
def multi_head_attention(
    hidden_state: torch.Tensor,
    max_seqlen: int,
    Wq: torch.Tensor,
    Wk: torch.Tensor,
    Wv: torch.Tensor,
    Wo: torch.Tensor,
    num_head: int,
    num_kv_head: int,
    K_cache: torch.Tensor | None,
    V_cache: torch.Tensor | None,
):
    batch_size, seqlen, hidden_size = hidden_state.shape
    head_dim_q, rq = divmod(Wq.shape[1], num_head)
    head_dim_k, rk = divmod(Wk.shape[1], num_kv_head)
    head_dim_v, rv = divmod(Wv.shape[1], num_kv_head)

    assert num_head % num_kv_head == 0
    assert Wq.shape[0] == Wk.shape[0] == Wv.shape[0] == Wo.shape[1] == hidden_size
    assert head_dim_q == head_dim_k and rq == 0 and rk == 0 and rv == 0

    Q = hidden_state @ Wq  # batch seqlen num_head*head_dim_k
    K = hidden_state @ Wk  # batch seqlen num_kv_head*head_dim_k
    V = hidden_state @ Wv  # batch seqlen num_kv_head*head_dim_v

    # 这里转置一下是为了把 num_head 提前,让 seqlen x head_dim 参与矩阵乘。
    # 相当于 batch_size x num_head 个 seqlen x head_dim 的矩阵。
    # batch num_head seqlen head_dim_k
    Q = Q.reshape(batch_size, seqlen, num_head, head_dim_k).transpose(1, 2)
    # batch num_kv_head seqlen head_dim_k
    K = K.reshape(batch_size, seqlen, num_kv_head, head_dim_k).transpose(1, 2)
    # batch num_kv_head seqlen head_dim_v
    V = V.reshape(batch_size, seqlen, num_kv_head, head_dim_v).transpose(1, 2)

    # 如果有 KV cache,就需要注意 rope 时,当前的 token 是第几个 token。
    Q = rope(Q, max_seqlen, prev_seqlen=K_cache.shape[-2] if K_cache is not None else 0)
    K = rope(K, max_seqlen, prev_seqlen=K_cache.shape[-2] if K_cache is not None else 0)
    # 更新 KV_cache
    new_K_cache = K if K_cache is None else torch.cat([K_cache, K], dim=-2)
    new_V_cache = V if V_cache is None else torch.cat([V_cache, V], dim=-2)
    # 包含历史 token,到目前为止的 seqlen
    seqlen_so_far = new_K_cache.shape[-2]

    # ################### 计算 attention ###################
    # GQA: GQA groups 跟 num_key_value_heads 是一个意思。
    # 如果 num_kv_head == num_head,GQA 就是 MHA。如果 num_kv_head == 1,那就是 MQA。
    # 如果 num_kv_head < num_head,那必须满足 num_head 是 num_kv_head 的整数倍。
    # 这样才能让 K 和 V 在 grouped 维度去广播。
    # 因为要广播,所以 GQA 不影响计算量,只是会做参数共享,减少显存占用。
    # 要实现广播,简单的方法就是用 expand,不会拷贝数据,自动在指定维度广播。
    if num_kv_head < num_head:
        num_q_head_per_group = num_head // num_kv_head
        Q = Q.view(batch_size, num_kv_head, num_q_head_per_group, seqlen, head_dim_k)
        # 如果有 KV cache,K 和 V 可以先合并再 expand。反过来的话浪费显存。
        # 合并后刚好就是新的 new_K_cache / new_V_cache,所以直接拿过来用了。
        # 记得调整 attention mask。
        # 注意到不管有没有 KV cache,都可以用 new_K_cache / new_V_cache,来表示当前需要参与计算的 K 和 V,所以可以直接用。
        K = new_K_cache.unsqueeze(2).expand(
            batch_size, num_kv_head, num_q_head_per_group, seqlen_so_far, head_dim_k
        )
        V = new_V_cache.unsqueeze(2).expand(
            batch_size, num_kv_head, num_q_head_per_group, seqlen_so_far, head_dim_v
        )

    score = Q @ K.transpose(-1, -2) / math.sqrt(head_dim_k)
    # MHA 时,score 是 batch num_head seqlen seqlen
    # GQA 时,score 是 batch num_kv_head num_q_head_per_group seqlen seqlen

    # 对于计算出来 batch_size x num_head 个 seqlen x seqlen 的 score 矩阵
    # 每一个 seqlen x seqlen 的行,对应一个 Q 的行,表示这一行 Q 代表的 token 对其他 token 的注意力。
    # 所以,为了不让 token 看到未来的 token,需要把上三角(而不是下三角),的部分置为负无穷,不包括对角线,这样 softmax 之后变成 0
    # 注意,token 能看到自身。所以不要把对角线 mask 掉。
    # 一般 attention_mask 里面用 1 表示参与计算的位置,所以需要一个下三角的 mask,并且对角线为 1
    # torch.tril(tensor) 函数保留 tensor 的下三角,以及对角线。
    # diagonal 表示对角线的偏移量,0 表示不偏移,保留主对角线,并不是把对角线置为 0 的意思。
    # 还得考虑启用 KV cache 时,attention_mask 不是方阵,需要靠调整 diagonal 来移动对角线。
    # 调整 diagonal 是需要保证 mask 的右下角那个数字是 1,并且它的上方全 0。
    attention_mask = torch.tril(torch.ones(seqlen, seqlen_so_far), diagonal=seqlen_so_far - seqlen)

    score.masked_fill_(attention_mask == 0, -torch.inf)  # inplace 操作
    score = score.softmax(dim=-1)

    after_qkv = score @ V
    # MHA 时,after_qkv 是 batch num_head seqlen head_dim_v
    # GQA 时,after_qkv 是 batch num_kv_head num_q_head_per_group seqlen head_dim_v

    # 统一一下 MHA 和 GQA 的输出 shape
    after_qkv = after_qkv.reshape(batch_size, num_head, seqlen, head_dim_v)
    # ################### attention 结束 ###################

    # 把 seqlen 转到前面,把 num_head 和 head_dim_v 合并到一起,也就是把所有 head 的 score 拼在一起。
    # batch seqlen num_head*head_dim_v
    after_qkv = after_qkv.transpose(1, 2).reshape(batch_size, seqlen, -1)

    # 这里的 after_qkv 可以看作是把所有 num_head 个 head 的 score 拼在一起了。
    # 所以 Wo 也可以看作是做了 num_head 次 projection,并且把每个 head 的 projection 结果合并起来了。
    # 实际操作中,没必要区分 Wo 的 num_head,直接整体矩阵乘,就是了。
    attn_output = after_qkv @ Wo  # batch seqlen hidden
    return attn_output, new_K_cache, new_V_cache  # batch seqlen hidden

if __name__ == "__main__":
    batch_size = 1
    seqlen = 64
    hidden = 32
    num_head = 4
    num_kv_head = 2
    head_dim_k = 4
    head_dim_v = 12

    # 这里的 head_dim_k 和 head_dim_v 不一定要相等,也不一定要等于 hidden_size / num_head。
    # 事实上可以随便取。
    # megatron 里面是让他俩相等,而且让 head_dim_k * num_head 等于 hidden
    # 但是 head_dim_q 一定要等于 head_dim_k,不然 Q @ K.T 不能计算

    hidden_state = torch.randn(batch_size, seqlen, hidden)
    Wq = torch.randn(hidden, num_head * head_dim_k)
    Wk = torch.randn(hidden, num_kv_head * head_dim_k)
    Wv = torch.randn(hidden, num_kv_head * head_dim_v)
    Wo = torch.randn(num_head * head_dim_v, hidden)

    hidden_state1, hidden_state2 = hidden_state.split([seqlen - 1, 1], dim=-2)

    # 测试 KV cache。先把前 seqlen - 1 个token 传入,获取 KV cache。
    attn_output1, K_cache, V_cache = multi_head_attention(
        hidden_state1, seqlen, Wq, Wk, Wv, Wo, num_head, num_kv_head, None, None
    )
    # 再把最后一个 token 和 KV cache 传入。
    attn_output2, _, _ = multi_head_attention(
        hidden_state2, seqlen, Wq, Wk, Wv, Wo, num_head, num_kv_head, K_cache, V_cache
    )
    print(attn_output2.shape)
    print(attn_output2)

    # 再整体计算。看最后一个 token 的输出是否一致。
    attn_output, _, _ = multi_head_attention(
        hidden_state, seqlen, Wq, Wk, Wv, Wo, num_head, num_kv_head, None, None
    )
    print(attn_output[:, [-1], :].shape)
    print(attn_output[:, -1, :])

    print("max diff:", (attn_output2 - attn_output[:, [-1], :]).abs().max())