大模型Attention发展历程

1

MHA Multi Head Attention

MHA 相比于传统的 Self-Attention 提升了模型的表达能力。例如一个专家组对比一个专家。
先看代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)

attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights

这个 attention 是 Multi-head self-attention,注意力有头这个概念。其中,Q K V 这三个向量分别是由 hidden_states 乘上一个 Linear 层来计算的。这里面由于是多头注意力机制,所以需要让 Q K V 的 shape 变为 (batch, seq_len, num_head, head_dim)。之后 transpose 成 (batch, num_head, seq_len, head_dim),这么做的目的是为了让后续计算时,每个头可以独立进行注意力计算,让 num_head 变为“批次”的一部分,注意力仅在 seq_len 维度计算,而不把不同的 head 混在一起算。
再算完这几个向量后,需要对 Q K 向量添加位置信息,这就是 apply_rotary_pos_emb 的作用。之后就是 attention 的计算,也就是我们说的公式:

Attention(Q,K,V)=softmax(QKTdk)V\begin{aligned} \text{Attention}(Q, K, V) &= \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \end{aligned}

其中,QK 的意思是 Q 和 K 的点积,表示 Q 和 K 之间的相似性,更进一步,这个向量是词向量,是词在高维空间的数值映射。词向量之间相关度高表示什么?是不是在一定程度上(不是完全)表示,在关注词 A 的时候,应当给予词 B 更多的关注?除以 dk\sqrt{d_k} 的作用是训练时防止梯度消失。
算完 attention 后,再被 o_projection 映射,结果就算出来了。

  • MHA (Multi-Head Attention):假设头个数为32,
    • 配置:每个 Query 头都有自己独立的 Key 头和 Value 头。
    • 数量:32 个 Q 头,对应 32 个 K 头 和 32 个 V 头。
    • 特点:完全独立,信息隔离性好,模型表达能力强。
      总结:MHA 多头自注意力机制是效果最好的也是最慢的,而 MQA 和 GQA 则是通过不同的方式来提升计算效率。
      一些比较出名的模型,如 LLama2-7B、LLama2-13B 使用的都是这种架构。

MQA Multi Query Attention

相比 MHA,MQA 允许一个 Query 头同时关注多个 Key 头和 Value 头,从而在保证信息隔离性的前提下,提升模型的计算效率。

  • MQA (Multi-Query Attention):假设头个数为32,
    • 配置:所有的 Query 头共享同一组 Key 和 Value 头。
    • 数量:32 个 Q 头,对应 1 个 K 头 和 1 个 V 头。
    • 特点:极端压缩,推理速度极快,但模型效果下降明显(因为所有头被迫看同样的 K/V 信息)。
      目前没听说哪个模型全程使用这个架构。但是 DeepSeek V3.2 的 decode 阶段使用了该方式。

GQA Group Query Attention

GQA 相比 MHA,其主要区别在于,GQA 是先进行“组内”注意力计算,再进行“组间”注意力计算。也就是说,GQA 首先将 Q K V 分成多组,然后在组内进行注意力计算,最后在组间进行注意力计算。推理速度更快,普遍用于 70B 及更大的模型中。MHA 则用于稍微小一点的模型中。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights

可以看见代码里有一个 repeat_kv 的操作,作用是将较少的 Key/Value 头复制扩展,使它们能和 Query 的头数一一对应进行矩阵乘法。这正是 GQA 的典型操作:“分组共享,通过复制来对齐”。

  • GQA (Grouped Query Attention):
    • 配置:将 Query 头分组,每组 Query 头共享一组 Key 和 Value。
    • 数量:例如分为 8 组(Group=8)。32 个 Q 头,对应 8 个 K 头 和 8 个 V 头(每4个Q头共享1个K/V头)。
    • 特点:介于 MHA 和 MQA 之间。比 MQA 效果好(保留了部分独立性),比 MHA 速度快(KV Cache 大小减少了 4 倍)。
      GQA 相比于 MHA,效果上几乎相近,而速度却能达到 MQA 的水平,被广泛的应用在现在的大模型上。
      LLama2-70B、LLama3、Qwen2 全系列使用的都是这种架构。

MLA Multi Latent Attention

2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
batch_size, seq_length = hidden_states.shape[:-1]
query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)

if self.q_lora_rank is None:
q_states = self.q_proj(hidden_states)
else:
q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q_states = q_states.view(query_shape).transpose(1, 2)
q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

cos, sin = position_embeddings
if self.config.rope_interleave: # support using interleaved weights for efficiency
q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
else:
q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

query_states = torch.cat((q_pass, q_rot), dim=-1)
key_states = torch.cat((k_pass, k_rot), dim=-1)

if past_key_values is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)

if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
attn_output = attn_output[:, :, :, : self.v_head_dim]

attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights

上面代码是 MLA 的流程。可以分为以下几步:

  1. Query 的生成与压缩
    DeepSeek 对 Q 进行了压缩,减少计算量
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 这一部分对应 Query 的生成
if self.q_lora_rank is None:
# 如果没有设置压缩秩,则使用标准投影
q_states = self.q_proj(hidden_states)
else:
# 【核心 MLA 特征】:Query 低秩压缩
# 1. q_a_proj: 降维 (hidden -> low_rank)
# 2. q_a_layernorm: 训练时加了个layernorm,防止梯度爆炸,归一化
# 3. q_b_proj: 升维 (low_rank -> num_heads * qk_head_dim)
q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))

q_states = q_states.view(query_shape).transpose(1, 2)

# 【核心 MLA 特征】:解耦 RoPE
# 将 Query 分割为两部分:一部分不带位置信息,一部分带位置信息
q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
  1. KV 的压缩与潜在向量生成
1
2
3
4
5
6
7
8
9
# 【核心 MLA 特征】:KV 压缩
# 将输入投影到一个非常小的压缩向量 (Latent Vector)
# 这就是推理时缓存的 "KV Cache",体积非常小!
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)

# 将压缩向量拆分:
# k_pass: 压缩后的 KV 内容特征 (用于后续还原)
# k_rot: 专门用于位置编码的部分
k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
  1. KV 的解压与还原
1
2
3
4
5
6
# 1. 对压缩的内容特征进行 LayerNorm
# 2. kv_b_proj: 升维投影,将低维 Latent 还原成高维的 Key 和 Value
k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)

# 分离还原后的 Key 和 Value
k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

虽然我们在 Cache 里存的是压缩后的向量,但在计算当前时刻的 Attention 分数时,需要将其还原成完整的多头 Key 和 Value。虽然计算量增加了一些,但换来了显存带宽的大幅节省。

  1. 位置编码与拼接
1
2
3
4
5
6
7
8
9
10
# 对专门的位置部分应用 RoPE
cos, sin = position_embeddings
q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)

# 扩展 k_rot 维度以匹配多头

# 【核心 MLA 特征】:重组
# 将内容部分 和位置部分 拼接
query_states = torch.cat((q_pass, q_rot), dim=-1)
key_states = torch.cat((k_pass, k_rot), dim=-1)

最终参与计算的 query_states 和 key_states 是由“非旋转部分”和“旋转部分”拼接而成的。这样做的好处是,位置信息不会破坏压缩向量的低秩结构。为什么位置编码要这么做呢?因为在实现 MLA 的时候,有一个小技巧叫做“矩阵吸收”。

xWdownQWup,iQ(Wup,iK)T(cKV)Tx W_{down}^Q W_{up,i}^Q (W_{up,i}^K{})^T (c^{KV}{})^T

在这个公式中,可以将 kv 的上投影矩阵Wup,iKW_{up,i}^K看成是Wup,iQW_{up,i}^Q,这样子,位置信息部分就可以通过矩阵吸收的方式,被WdownQW_{down}^Q“吸收”进去了。但是如果加了常规旋转位置编码的话,位置信息部分就无法被WdownQW_{down}^Q“吸收”进去了。可以看见被隔断了。无法吸收,会导致推理效率下降。

xWdownQWup,iQRmRnT(Wup,iK)T(cKV)Tx W_{down}^Q W_{up,i}^Q R_{m} R_{n}^T (W_{up,i}^K{})^T (c^{KV}{})^T

所以引入了 Decoupled RoPE (解耦 RoPE)。q_rot 和 k_rot 分开处理,这样位置编码和内容编码就可以分开处理,互不干扰。

  1. 注意力计算
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 更新 KV Cache (这里存储的通常是还原后的 K 和 V,或者是压缩向量,取决于具体的 Cache 实现)
if past_key_values is not None:
key_states, value_states = past_key_values.update(...)

# 【Flash Attention 兼容性处理】
# 如果 QK 的维度和 V 的维度不一样,需要 Padding 对齐
if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

# 调用标准注意力接口计算
attn_output, attn_weights = attention_interface(...)

# 去除 Padding
if ...:
attn_output = attn_output[:, :, :, : self.v_head_dim]

# 输出投影
attn_output = self.o_proj(attn_output)
  • MLA (Multi-Head Latent Attention)
    • 配置:通过生成 compressed_kv,将 KV Cache 的大小压缩了 90% 以上。存储上 compressed_kv 只有一份,接近 MQA。
    • 数量:通过 kv_b_proj 动态还原 KV,保留了多头注意力的表达能力。每个头一个 KV 数量,这一部分接近 MHA。
    • 特点:通过 split 和 cat 操作,把位置信息和内容信息分开处理,使得压缩机制能与 RoPE 完美兼容。

NSA Natural Sparse Attention

该注意力机制首次被提出是在 2025 年的 2 月。由 DeepSeek 针对长序列处理提出的原生稀疏注意力机制,属于 block-wise 的稀疏方案。这个目前只存在于论文中,没有实际用于某个模型。直到 Deepseek-V3.2 中,他们提出了 DSA,可以说是 NSA 的一个变种实现, 属于 token-wise 的稀疏方案。

  1. 压缩块,计算块间注意力
    标准自注意力需要计算序列中每个词元(Token)与其他所有词元的关系。对于一个长度为 N 的序列,需要计算 N*N 个注意力权重,随长度二次增长。
    解决方案是,NSA 不再以单个 token 为基本计算单元,而是将连续的多个 token 聚合成一个 block。例如,将 4096 个词元的序列划分为 64 个块,每个块包含 64 个 token。
    计算流程转变:
    之前(标准注意力): 处理 4096 个 token -> 需要处理 4096 * 4096≈16.7M 个关系对。
    之后(块压缩): 处理 64 个 block -> 先计算块与块之间的注意力。此时,需要处理的关系对数量骤降至 64 * 64=4096 个。
  2. 重要性筛选:选出重要的 K 个块
    在压缩后的 block 中,筛选出需要详细看的部分:
    块压缩是假设所有 block 都同等重要。但实际上,对于当前要处理的 token 来说,某些 block 是关键的而其他的是次要的。NSA 引入动态机制,根据当前的 Q 内容,评估并筛选出最相关的少量关键 block。
    系统会为每个 block 计算一个“重要性分数”。这个分数通常基于当前 Q 与每个 block 的“摘要向量”(通常是 block 内 token 的均值或通过一个小型网络生成)的相似度。
    在块压缩的基础上增加重要性筛选,系统从64个块中选出最相关的 K 个块(例如 K=8),额外计算当前查询向量与这 K 个关键块内全部 token 的注意力。
  3. 滑动窗口
    为了保证局部上下文信息的完整性,token 不但要关注遥远的关键块,还要关注紧挨着它的其他 token。NSA 强制性地规定,无论重要性筛选的结果如何,每个 token 都必须关注以其自身为中心的一个局部窗口内的所有其他token。这个窗口是“滑动”的,因为它随着 token 位置的变化而移动。例如,假设滑动窗口大小为 256。那么对于序列中第 500 个 token,无论如何它都会关注第 500-255 至第 500 这个范围内的所有 token。

DSA DeepSeek Sparse Attention

DSA 首次被应用在 Deepseek V3.2 Exp 中,属于 token-wise 的稀疏方案。DeepSeek-V3.2-Exp 仅在 DeepSeek-V3 的基础上新增了 Lightning Indexer 模块,用于选择参与 attention 的 token。
该模块的主要输入是 compressed_q 与 MLA 的输入矩阵 hidden_states,输出则是每个 token 所对应的 2048 个可参与 attention 计算的历史 token 的 index。
被选中的 2048 个 token 会在 attention 的 mask 阶段发挥作用:通过将未选中的位置的 mask 值设为 inf,在经过 softmax 之后,就能有效去除不需要参与 attention 的 token。这样实现了高效的稀疏化,显著降低了计算量。从O(L^2)降到了O(LK),其中 K<<L 为 2048。
Indexer 代码(来自于官方 DeepSeek V3.2 仓库):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def fp8_index(
q: torch.Tensor,
q_s: torch.Tensor,
k: torch.Tensor,
k_s: torch.Tensor,
) -> torch.Tensor:
"""
Perform index score using FP8 precision.

Args:
q (torch.Tensor): The Q tensor, must be contiguous.
q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
k (torch.Tensor): The K tensor, must be contiguous.
k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.

fp8 q @ fp8 k -> fp32 logits
relu(fp32 logits) * q_s (weights) -> fp32 logits
fp32 logits -> fp32 logits_sum
fp32 logits_sum * k_s (e8m0) -> fp32 index_score
"""
return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s)

class Indexer(torch.nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.dim: int = args.dim
self.n_heads: int = args.index_n_heads
self.n_local_heads = args.index_n_heads // world_size
self.head_dim: int = args.index_head_dim
self.rope_head_dim: int = args.qk_rope_head_dim
self.index_topk: int = args.index_topk
self.q_lora_rank: int = args.q_lora_rank
self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim)
self.wk = Linear(self.dim, self.head_dim)
self.k_norm = LayerNorm(self.head_dim)
# weights_proj in the checkpoint is stored in bf16, while the parameters here are stored in fp32 for convenient.
self.weights_proj = Linear(self.dim, self.n_heads, dtype=torch.float32)
self.softmax_scale = self.head_dim ** -0.5
self.scale_fmt = args.scale_fmt

self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.head_dim, dtype=torch.float8_e4m3fn), persistent=False)
self.register_buffer("k_scale_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.head_dim // block_size, dtype=torch.float32), persistent=False)

def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
bsz, seqlen, _ = x.size()
end_pos = start_pos + seqlen
q = self.wq_b(qr)
q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
# rope in indexer is not interleaved
q_pe = apply_rotary_emb(q_pe, freqs_cis, False)
q = torch.cat([q_pe, q_nope], dim=-1)
k = self.wk(x)
k = self.k_norm(k)
k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
# rope in indexer is not interleaved
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis, False).squeeze(2)
k = torch.cat([k_pe, k_nope], dim=-1)
q = rotate_activation(q)
k = rotate_activation(k)
q_fp8, q_scale = act_quant(q, block_size, self.scale_fmt)
k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt)
self.k_cache[:bsz, start_pos:end_pos] = k_fp8
self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale
weights = self.weights_proj(x.float()) * self.n_heads ** -0.5
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
index_score = fp8_index(q_fp8.contiguous(), weights, self.k_cache[:bsz, :end_pos].contiguous(), self.k_scale_cache[:bsz, :end_pos].contiguous())
if mask is not None:
index_score += mask
topk_indices = index_score.topk(min(self.index_topk, end_pos), dim=-1)[1]
topk_indices_ = topk_indices.clone()
dist.broadcast(topk_indices_, src=0)
assert torch.all(topk_indices == topk_indices_), f"{topk_indices=} {topk_indices_=}"
return topk_indices

公式如下:

It,s=j=1HIwt,jIReLU(qt,jIksI)I_{t,s} = \sum_{j=1}^{H_{I}} w_{t,j}^{I} \cdot \operatorname{ReLU}\left( \boldsymbol{q}_{t,j}^{I} \cdot \boldsymbol{k}_{s}^{I} \right)

总体框架如下:
3
其中,绿色部分是 indexer 实现。而框住的部分是函数 fp8_index 部分。可以看到,虽然说 attention 的计算复杂度降低了,但是实际上选择不同 token 重要性这个 indexer 模块,也算是个 attention。还是需要计算查询相对于之前每个 token 的 attention 的重要性分数,复杂度还是 O(n2)O(n^2)。只不过这个过程是在 fp8 精度下进行的,所以效率会比较高。

为什么需要哈达玛变换"rotate_activation"?
直接量化 q 和 k 向量可能会造成精度损失。引入旋转变换操作:将大范围数值打散到小范围。可以与随机的正交矩阵相乘来做这件事,使得向量数值分布均匀且模长不变。不过矩阵乘法代价较大。于是使用小代价的哈达玛变换。

最后得出的topk_indices,会作为 mask 输入到 MLA 中,从而实现稀疏化。可以看下面代码。
对于 Deepseek V3.2 的 attention 实现而言:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
"""
Forward pass for the Multi-Head Latent Attention (MLA) Layer.

Args:
x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
start_pos (int): Starting position in the sequence for caching.
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.

Returns:
torch.Tensor: Output tensor with the same shape as the input.
"""
bsz, seqlen, _ = x.size()
end_pos = start_pos + seqlen
qr = self.q_norm(self.wq_a(x))
q = self.wq_b(qr)
q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
q_pe = apply_rotary_emb(q_pe, freqs_cis)
kv = self.wkv_a(x)
kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv = self.kv_norm(kv)
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
# we use fp8 kv cache in actual deployment, so here we simulate the precision by casting kv to fp8 and then back to bf16.
kv_fp8, kv_scale = act_quant(kv, block_size, self.scale_fmt)
kv = (kv_fp8.view(-1, block_size).float() * kv_scale.view(-1, 1)).to(kv.dtype).view_as(kv)
self.kv_cache[:bsz, start_pos:end_pos] = kv
self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
if mask is not None: # MHA prefill
q = torch.cat([q_nope, q_pe], dim=-1)
kv = self.wkv_b(kv)
kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
scores = torch.einsum("bshd,bthd->bsht", q, k).mul_(self.softmax_scale)

# indexer
topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
index_mask = torch.full((bsz, seqlen, seqlen), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
index_mask += mask
scores += index_mask.unsqueeze(2)

scores = scores.softmax(dim=-1)
x = torch.einsum("bsht,bthd->bshd", scores, v)
else: # MQA decode
if self.dequant_wkv_b is None and self.wkv_b.scale is not None:
self.dequant_wkv_b = weight_dequant(self.wkv_b.weight, self.wkv_b.scale)
wkv_b = self.wkv_b.weight if self.dequant_wkv_b is None else self.dequant_wkv_b
wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale

# indexer
topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
index_mask = torch.full((bsz, 1, end_pos), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
scores += index_mask.unsqueeze(2)

scores = scores.softmax(dim=-1)
x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
x = self.wo(x.flatten(2))
return x

cuda性能分析工具nsys使用

安装 nsys 命令行工具

一般这个工具是随着 cuda toolkit 一起安装的。如果没有安装,可以查看网站:https://docs.nvidia.com/nsight-systems/InstallationGuide/index.html
ubuntu可以用如下安装:

1
2
3
4
5
6
apt update
apt install -y --no-install-recommends gnupg
echo "deb http://developer.download.nvidia.com/devtools/repos/ubuntu$(source /etc/lsb-release; echo "$DISTRIB_RELEASE" | tr -d .)/$(dpkg --print-architecture) /" | tee /etc/apt/sources.list.d/nvidia-devtools.list
apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
apt update
apt install nsight-systems-cli

安装后可以尝试查看:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
@└────> # nsys status -e
Timestamp counter supported: Yes

CPU Profiling Environment Check
Root privilege: disabled
Linux Kernel Paranoid Level = 4
Linux Distribution = Ubuntu
Linux Kernel Version = 5.15.0-105-generic: OK
Linux perf_event_open syscall available: Fail
Sampling trigger event available: Fail
Intel(c) Last Branch Record support: Not Available
CPU Profiling Environment (process-tree): Fail
CPU Profiling Environment (system-wide): Fail

See the product documentation at https://docs.nvidia.com/nsight-systems for more information,
including information on how to set the Linux Kernel Paranoid Level.

在可以使用 nsys 工具后,可以使用 nsys 来查看一些 kernel 的性能。
举个例子,我们有以下 cuda 代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#include <bits/stdc++.h>
#include <cuda.h>
#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include <time.h>
#include <sys/time.h>

#define THREAD_PER_BLOCK 256

// baseline
__global__ void reduce0(float* d_in, float* d_out) {
__shared__ float sdata[THREAD_PER_BLOCK];

// each thread loads one element from global to shared mem
unsigned int tid = threadIdx.x;
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
sdata[tid] = d_in[i];
__syncthreads();

// do reduction in shared mem
for (unsigned int s = 1; s < blockDim.x; s *= 2) {
if (tid % (2 * s) == 0) {
sdata[tid] += sdata[tid + s];
}
__syncthreads();
}

// write result for this block to global mem
if (tid == 0) {
d_out[blockIdx.x] = sdata[0];
}
}

// bank conflict
__global__ void reduce1(float* d_in, float* d_out) {
__shared__ float sdata[THREAD_PER_BLOCK];

// each thread loads one element from global to shared mem
unsigned int tid = threadIdx.x;
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
sdata[tid] = d_in[i];
__syncthreads();

// do reduction in shared mem
for (unsigned int s = 1; s < blockDim.x; s *= 2) {
int index = 2 * s * tid;
if (index < blockDim.x) {
sdata[index] += sdata[index + s];
}
__syncthreads();
}

// write result for this block to global mem
if (tid == 0) {
d_out[blockIdx.x] = sdata[0];
}
}

bool check(float* out, float* res, int n) {
for (int i = 0; i < n; i++) {
if (out[i] != res[i]) {
return false;
}
}
return true;
}

int main() {
const int N = 32 * 1024 * 1024;
float* a = (float*)malloc(N * sizeof(float));
float* d_a;
cudaMalloc((void**)&d_a, N * sizeof(float));

int block_num = N / THREAD_PER_BLOCK;
float* out = (float*)malloc((N / THREAD_PER_BLOCK) * sizeof(float));
float* d_out;
cudaMalloc((void**)&d_out, (N / THREAD_PER_BLOCK) * sizeof(float));
float* res = (float*)malloc((N / THREAD_PER_BLOCK) * sizeof(float));

for (int i = 0; i < N; i++) {
a[i] = 1;
}

for (int i = 0; i < block_num; i++) {
float cur = 0;
for (int j = 0; j < THREAD_PER_BLOCK; j++) {
cur += a[i * THREAD_PER_BLOCK + j];
}
res[i] = cur;
}

cudaMemcpy(d_a, a, N * sizeof(float), cudaMemcpyHostToDevice);

dim3 Grid(N / THREAD_PER_BLOCK, 1);
dim3 Block(THREAD_PER_BLOCK, 1);

reduce0<<<Grid, Block>>>(d_a, d_out);
cudaMemcpy(out, d_out, block_num * sizeof(float), cudaMemcpyDeviceToHost);
if (check(out, res, block_num)) {
printf("the ans is right\n");
} else {
printf("the ans is wrong\n");
for (int i = 0; i < block_num; i++) {
printf("%lf ", out[i]);
}
printf("\n");
}

reduce1<<<Grid, Block>>>(d_a, d_out);
cudaMemcpy(out, d_out, block_num * sizeof(float), cudaMemcpyDeviceToHost);
if (check(out, res, block_num)) {
printf("the ans is right\n");
} else {
printf("the ans is wrong\n");
for (int i = 0; i < block_num; i++) {
printf("%lf ", out[i]);
}
printf("\n");
}

cudaFree(d_a);
cudaFree(d_out);
}

以上是一个 reduce_sum 的例子,有两个性能不同的核函数 reduce0 和 reduce1。可以使用命令行工具来看这两个工具的性能。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# 一个过渡方式,沿用之前 nvprof 的用法。
@└────> # nsys nvprof ./a.out

# 后续支持的方式
@└────> # nsys profile --stats=true ./a.out
...
[4/8] Executing 'osrtsum' stats report

Time (%) Total Time (ns) Num Calls Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
-------- --------------- --------- ------------- ------------- --------- ------------- ------------- ----------------------
49.8 6,560,704,940 41 160,017,193.7 100,127,248.0 2,946 3,656,490,088 561,147,066.6 poll
40.5 5,331,726,596 1,591 3,351,179.5 34,054.0 1,200 405,442,288 30,016,785.5 ioctl
4.7 614,593,088 60 10,243,218.1 10,295,123.0 1,532,148 14,425,405 2,880,533.4 waitpid
2.6 341,072,143 60 5,684,535.7 5,470,557.0 576,832 10,285,481 1,599,344.7 fork
2.4 312,589,312 113 2,766,277.1 14,265.0 4,724 267,144,344 25,109,089.4 open64
0.0 5,684,320 144 39,474.4 11,704.5 1,108 3,826,795 317,952.4 fopen
0.0 2,034,643 38 53,543.2 10,234.5 3,945 1,236,239 198,417.6 mmap64
0.0 607,662 10 60,766.2 56,248.0 42,036 112,182 19,745.1 sem_timedwait
0.0 405,925 123 3,300.2 2,289.0 1,006 77,642 7,014.0 fclose
0.0 391,343 4 97,835.8 80,675.0 58,047 171,946 53,486.5 pthread_create
0.0 154,726 19 8,143.5 4,985.0 1,004 47,713 10,904.9 mmap
0.0 82,936 1 82,936.0 82,936.0 82,936 82,936 0.0 pthread_cond_wait
0.0 78,328 8 9,791.0 4,877.5 2,074 40,677 12,827.6 munmap
0.0 71,358 7 10,194.0 9,879.0 3,968 14,551 3,520.2 open
0.0 51,719 3 17,239.7 13,793.0 3,372 34,554 15,874.2 fread
0.0 42,644 29 1,470.5 1,295.0 1,000 5,372 801.5 fcntl
0.0 41,205 1 41,205.0 41,205.0 41,205 41,205 0.0 fgets
0.0 36,755 15 2,450.3 2,072.0 1,083 7,000 1,497.0 read
0.0 33,701 12 2,808.4 2,360.0 1,384 6,014 1,224.1 write
0.0 30,013 3 10,004.3 11,976.0 5,640 12,397 3,785.5 pipe2
0.0 26,410 2 13,205.0 13,205.0 10,057 16,353 4,451.9 socket
0.0 12,546 2 6,273.0 6,273.0 5,777 6,769 701.4 fwrite
0.0 10,654 2 5,327.0 5,327.0 4,160 6,494 1,650.4 pthread_cond_broadcast
0.0 10,412 1 10,412.0 10,412.0 10,412 10,412 0.0 pthread_mutex_trylock
0.0 9,158 1 9,158.0 9,158.0 9,158 9,158 0.0 connect
0.0 5,650 1 5,650.0 5,650.0 5,650 5,650 0.0 bind
0.0 3,085 1 3,085.0 3,085.0 3,085 3,085 0.0 listen

[5/8] Executing 'cudaapisum' stats report

Time (%) Total Time (ns) Num Calls Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
-------- --------------- --------- ------------- ------------- ----------- ------------- ------------- ----------------------
77.1 1,961,601,754 2 980,800,877.0 980,800,877.0 674,626,715 1,286,975,039 432,995,652.3 cudaMalloc
22.2 564,147,432 2 282,073,716.0 282,073,716.0 275,230,367 288,917,065 9,677,957.0 cudaFree
0.7 17,105,648 3 5,701,882.7 657,298.0 434,348 16,014,002 8,931,253.0 cudaMemcpy
0.0 415,480 2 207,740.0 207,740.0 42,497 372,983 233,688.9 cudaLaunchKernel
0.0 1,368 1 1,368.0 1,368.0 1,368 1,368 0.0 cuModuleGetLoadingMode

[6/8] Executing 'gpukernsum' stats report

Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) GridXYZ BlockXYZ Name
-------- --------------- --------- --------- --------- -------- -------- ----------- ---------------- -------------- -------------------------
62.5 556,575 1 556,575.0 556,575.0 556,575 556,575 0.0 131072 1 1 256 1 1 reduce0(float *, float *)
37.5 334,111 1 334,111.0 334,111.0 334,111 334,111 0.0 131072 1 1 256 1 1 reduce1(float *, float *)

[7/8] Executing 'gpumemtimesum' stats report

Time (%) Total Time (ns) Count Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Operation
-------- --------------- ----- ------------ ------------ ---------- ---------- ----------- ------------------
99.7 15,820,690 1 15,820,690.0 15,820,690.0 15,820,690 15,820,690 0.0 [CUDA memcpy HtoD]
0.3 46,176 2 23,088.0 23,088.0 23,040 23,136 67.9 [CUDA memcpy DtoH]

[8/8] Executing 'gpumemsizesum' stats report

Total (MB) Count Avg (MB) Med (MB) Min (MB) Max (MB) StdDev (MB) Operation
---------- ----- -------- -------- -------- -------- ----------- ------------------
134.218 1 134.218 134.218 134.218 134.218 0.000 [CUDA memcpy HtoD]
1.049 2 0.524 0.524 0.524 0.524 0.000 [CUDA memcpy DtoH]

可以看见,步骤 4 中调用的是 osrtsum(OS Runtime Summary),关注操作系统层面的性能数据。而如果在命令行中加入

1
2
3
4
5
6
7
8
9
10
11
@└────> # nsys profile --stats=true --trace=cuda,nvtx,cudnn,cublas ./a.out
...
[4/7] Executing 'cudaapisum' stats report

Time (%) Total Time (ns) Num Calls Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
-------- --------------- --------- ------------- ------------- -------- ----------- ------------- ----------------------
94.6 282,480,873 2 141,240,436.5 141,240,436.5 269,662 282,211,211 199,362,781.2 cudaMalloc
5.2 15,504,039 3 5,168,013.0 826,866.0 541,742 14,135,431 7,767,320.2 cudaMemcpy
0.1 392,989 2 196,494.5 196,494.5 96,115 296,874 141,958.1 cudaFree
0.1 318,607 2 159,303.5 159,303.5 28,009 290,598 185,678.5 cudaLaunchKernel
0.0 1,184 1 1,184.0 1,184.0 1,184 1,184 0.0 cuModuleGetLoadingMode

则使用的是 cudaapisum(CUDA API Summary),关注的是 CUDA API 层面的性能数据。
除此之外,可以看到 gpukernsum 中核函数执行时间。reduce1 函数性能优于 reduce0,他们的执行次数,执行时间最大值最小值和平均值。

若算子是用 pybind 绑定,用 python 调用的,可以使用 torch 的函数来只监控该算子。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
# warmup
for i in range(10):
_ = torch.matmul(x, y) # x 和 y 是矩阵,这里就不展开了

# start profiling
torch.cuda.cudart().cudaProfilerStart()

### benchmarking
for i in range(100): # 测试 100 次
torch.cuda.nvtx.range_push("your_ops_name")
_ = your_ops(x, y)
torch.cuda.nvtx.range_pop()
torch.cuda.synchronize()

# stop profiling
torch.cuda.cudart().cudaProfilerStop()

安装 Nsight Systems 可视化工具

在进行完命令行分析后,会生成一个报告文件,结尾是 .nsys-rep。这个文件可以下载下来,丢进 nsight-system 可视化软件,在软件中可以看到更加详细的数据以及程序执行的时间线。
1从上图可以看出 cuda hardware 的函数执行时间情况和 cpu 侧的执行时间情况。下面是核函数的发射时间,上面 device 侧是核函数实际执行时间,和 gpukernsum 统计的时间一致。
2上图是在软件中可以看见的较为详细的统计数据,和命令行结果一致。

模型打印

已Llama-7B hugging face版本为例:

1
2
3
4
5
6
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
device = torch.device("cuda:{}".format(gpu))
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code = True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code = True).half().to(device)
print(model)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(32000, 4096)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaFlashAttention2(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=4096, bias=False)
(v_proj): Linear(in_features=4096, out_features=4096, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
(up_proj): Linear(in_features=4096, out_features=11008, bias=False)
(down_proj): Linear(in_features=11008, out_features=4096, bias=False)
(act_fn): SiLUActivation()
)
(input_layernorm): LlamaRMSNorm()
(post_attention_layernorm): LlamaRMSNorm()
)
)
(norm): LlamaRMSNorm()
)
(lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)

从结构可以看出来,模型参数量为 32,0004,096+32(4,0964,0964+4,09611,0083)+4,09632,000=6,738,149,37632,000 * 4,096 + 32 * (4,096 * 4,096 * 4 + 4,096 * 11,008 * 3) + 4,096 * 32,000 = 6,738,149,376。所以约为7B7B

模型图解

以输入为 10 个 token 为例:
1

attribute ((visibility(“”)))

是 gcc 的编译器指令,用于设置在 shared object 中所修饰的符号对外的可见性。该修饰对 .a 文件不生效,只对 .so 库生效。

attribute ((visibility(“default”)))

该修饰用于修饰符号的可见性为默认对外可见。意思是通过该符号修饰的函数可以在 so 文件外访问到。

func.cpp:

1
2
3
__attribute__ ((visibility("default"))) void func1(int a) {
cout << a << endl;
}

main.cpp:

1
2
3
4
5
6
#include <iostream>
using namespace std;
extern void func1(int a);
int main() {
func1(10);
}
1
2
@└────> g++ func.cpp --shared -fPIC -o libfunc.so
@└────> g++ main.cpp -L./ -lfunc

之后发现是可以编译成功的。因为该符号是可见的。

1
2
@└────> nm libfunc.so | grep func
0000000000001179 T _Z5func1i

大写的 T 表示定义在 text 段,并且可被外部引用。如果你是通过编译 .o 文件再链接为 .so 文件的,还可以使用 readelf -s 查看 .o 文件的可见性.

attribute ((visibility(“hidden”)))

该修饰用于修饰符号的可见性为默认对外不可见。意思是通过该符号修饰的函数不可以在 so 文件外访问到,只能在 so 文件内部访问到。

func.cpp:

1
2
3
__attribute__ ((visibility("hidden"))) void func2(int a) {
cout << a << endl;
}

main.cpp:

1
2
3
4
5
6
#include <iostream>
using namespace std;
extern void func2(int a);
int main() {
func2(10);
}
1
2
3
4
5
@└────> g++ func.cpp --shared -fPIC -o libfunc.so
@└────> g++ main.cpp -L./ -lfunc
/usr/bin/ld: /tmp/cc7GABC5.o: in function `main':
fstream.cpp:(.text+0xe): undefined reference to `func2(int)'
collect2: error: ld returned 1 exit status

之后发现是可以编译失败,因为符号不可见.

1
2
@└────> nm libfunc.so | grep func
00000000000011ef T _Z5func2i

其他

1
2
__attribute__ ((visibility("internal")))
__attribute__ ((visibility("protected")))

上述两种一样是用于修饰符号, internal 对外不可见,而 protected 对外可见。
此外,在编译 so 文件时可以通过指定 -fvisibility=xxx 来指定默认的没有给出修饰的符号属性。
如:

1
@└────> gcc -fPIC -shared -o libtest.so -fvisibility=hidden test.c

这样在 test.c 中没用经过修饰的符号对外都不可见,而修饰为 default 的依旧对外可见。

概念

在 ELF 文件中,查看可以获得它的节的名字。其中有几个带有 plt 和 got 的节。

在此处,给出各节的定义如下:

  • .got:Global Offset Table,全局偏移表。这是链接器为外部符号填充的实际偏移表。
  • .plt:Procedure Linkage Table,程序链接表。他有两个作用,要么在 .got.plt 中拿到链接地址跳转,要么触发链接器去寻找地址。
  • .got.plt:是 .got 的一部分(但是是两个不同的节),是 got 专门为 plt 准备的节,包含了 plt 表需要的地址。(新版 gcc 可能将他叫为 .plt.got)
  • .rela.plt:程序链接表的重定位表,记录所有全局函数的动态链接信息,用于在程序加载时修正 plt 表中的跳转指针,使它们指向正确的地址。

实验

接下来将使用 gdb 一步一步跟着汇编走完动态链接的过程。

准备工作

实验代码如下:

1
2
3
4
5
6
#include <stdio.h>
int main() {
puts("hello");
printf("hello");
return 0;
}

查看节的地址与大小:

1
2
3
4
5
6
@└────> # objdump -h plt | grep -E "plt|got"
plt: file format elf64-x86-64
9 .rela.plt 00000030 0000000000400468 0000000000400468 00000468 2**3
11 .plt 00000030 00000000004004c0 00000000004004c0 000004c0 2**4
20 .got 00000020 0000000000600fe0 0000000000600fe0 00000fe0 2**3
21 .got.plt 00000028 0000000000601000 0000000000601000 00001000 2**3

查看需要动态链接的符号:

1
2
3
4
5
6
7
8
9
10
11
12
13
@└────> # readelf -r plt

Relocation section '.rela.dyn' at offset 0x408 contains 4 entries:
Offset Info Type Sym. Value Sym. Name + Addend
000000600fe0 000100000006 R_X86_64_GLOB_DAT 0000000000000000 _ITM_deregisterTMClone + 0
000000600fe8 000400000006 R_X86_64_GLOB_DAT 0000000000000000 __libc_start_main@GLIBC_2.2.5 + 0
000000600ff0 000500000006 R_X86_64_GLOB_DAT 0000000000000000 __gmon_start__ + 0
000000600ff8 000600000006 R_X86_64_GLOB_DAT 0000000000000000 _ITM_registerTMCloneTa + 0

Relocation section '.rela.plt' at offset 0x468 contains 2 entries:
Offset Info Type Sym. Value Sym. Name + Addend
000000601018 000200000007 R_X86_64_JUMP_SLO 0000000000000000 puts@GLIBC_2.2.5 + 0
000000601020 000300000007 R_X86_64_JUMP_SLO 0000000000000000 printf@GLIBC_2.2.5 + 0

反汇编查看 plt 相关函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
@└────> # objdump -d plt
Disassembly of section .plt:

00000000004004c0 <.plt>:
4004c0: ff 35 42 0b 20 00 pushq 0x200b42(%rip) # 601008 <_GLOBAL_OFFSET_TABLE_+0x8>
4004c6: ff 25 44 0b 20 00 jmpq *0x200b44(%rip) # 601010 <_GLOBAL_OFFSET_TABLE_+0x10>
4004cc: 0f 1f 40 00 nopl 0x0(%rax)

00000000004004d0 <puts@plt>:
4004d0: ff 25 42 0b 20 00 jmpq *0x200b42(%rip) # 601018 <puts@GLIBC_2.2.5>
4004d6: 68 00 00 00 00 pushq $0x0
4004db: e9 e0 ff ff ff jmpq 4004c0 <.plt>

00000000004004e0 <printf@plt>:
4004e0: ff 25 3a 0b 20 00 jmpq *0x200b3a(%rip) # 601020 <printf@GLIBC_2.2.5>
4004e6: 68 01 00 00 00 pushq $0x1
4004eb: e9 d0 ff ff ff jmpq 4004c0 <.plt>

开始

  1. 首先断点到 puts 函数,查看调用处:
1
2
3
4
5
6
7
8
9
10
11
12
13
@(gdb) disassemble main
Dump of assembler code for function main:
0x00000000004005d6 <+0>: push %rbp
0x00000000004005d7 <+1>: mov %rsp,%rbp
=> 0x00000000004005da <+4>: mov $0x400698,%edi
0x00000000004005df <+9>: callq 0x4004d0 <puts@plt>
0x00000000004005e4 <+14>: mov $0x400698,%edi
0x00000000004005e9 <+19>: mov $0x0,%eax
0x00000000004005ee <+24>: callq 0x4004e0 <printf@plt>
0x00000000004005f3 <+29>: mov $0x0,%eax
0x00000000004005f8 <+34>: pop %rbp
0x00000000004005f9 <+35>: retq
End of assembler dump.

可以看到,调用处实际上是使用 call 指令走到 puts 的代码段。下面的 printf 也是如出一辙。

  1. 查看 puts@plt 的汇编指令
1
2
3
4
5
6
@(gdb) disassemble
Dump of assembler code for function puts@plt:
=> 0x00000000004004d0 <+0>: jmpq *0x200b42(%rip) # 0x601018 <puts@got.plt>
0x00000000004004d6 <+6>: pushq $0x0
0x00000000004004db <+11>: jmpq 0x4004c0
End of assembler dump.

可以看到,在汇编中,他首先要跳转到 0x601018 地址的位置。这个地址内容是个全局变量,实际上根据节的地址位置和大小可以判断,是处于 .got.plt 的位置内( 0x601000 ~ 0x601028)。所以可以认为,在 .got.plt 中,存在了 puts 函数的地址。

  1. 查看 .got.plt
1
2
3
4
5
@(gdb) x/16x 0x601018
0x601018 <puts@got.plt>: 0x004004d6 0x00000000 0x004004e6 0x00000000
0x601028: 0x00000000 0x00000000 0x00000000 0x00000000
0x601038: 0x00000000 0x00000000 0x00000000 0x00000000
0x601048: 0x00000000 0x00000000 0x00000000 0x00000000

查看表中内容,发现跳转的地址是 0x4004d6,这不就是我们跳转之前的下一个地址吗!(puts@plt 的第二条指令) 同理,printf 函数也是如此(0x4004e6)。这是因为,之前没有调用过 puts 函数,第一次查找的时候,.got.plt 表中找不到函数的地址,那就先返回继续执行去调用链接器获取地址。

  1. 准备调用链接器
1
2
3
4
00000000004004d0 <puts@plt>:
4004d0: ff 25 42 0b 20 00 jmpq *0x200b42(%rip) # 601018 <puts@GLIBC_2.2.5>
4004d6: 68 00 00 00 00 pushq $0x0
4004db: e9 e0 ff ff ff jmpq 4004c0 <.plt>

首先 pushq $0x0,这个是在 got.plt 中的编号,如 puts 是 0,printf 是 1。这个参数是给后续链接器使用的。然后跳到了 .plt 的位置执行(0x4004c0)。可以看到,printf@plt 函数最后也是跳到这个位置执行。

  1. 调用链接器
1
2
3
4
00000000004004c0 <.plt>:
4004c0: ff 35 42 0b 20 00 pushq 0x200b42(%rip) # 601008 <_GLOBAL_OFFSET_TABLE_+0x8>
4004c6: ff 25 44 0b 20 00 jmpq *0x200b44(%rip) # 601010 <_GLOBAL_OFFSET_TABLE_+0x10>
4004cc: 0f 1f 40 00 nopl 0x0(%rax)

首先 push 了 0x601008 到栈中,这是 .got.plt 表中的一个地址。之后跳转到 0x601010 所存储的地址去执行相应的代码。不难看出,0x601010 也是存储在 .got.plt 表中的。查看一下存储的内容:

1
2
3
4
@(gdb) x/10x 0x601010
0x601010: 0xf7de64a0 0x00007fff 0x004004d6 0x00000000
0x601020 <printf@got.plt>: 0x004004e6 0x00000000 0x00000000 0x00000000
0x601030: 0x00000000 0x00000000

可以看到,是让我们跳转到 0x00007ffff7de64a0 去执行相应的代码。那么这块代码是什么呢?

1
2
3
4
5
@(gdb) info sharedlibrary
From To Syms Read Shared Object Library
0x00007ffff7dd0fa0 0x00007ffff7df2cd4 Yes (*) /lib64/ld-linux-x86-64.so.2
0x00007ffff7a2cb90 0x00007ffff7b798ad Yes (*) /lib64/libc.so.6
(*): Shared library is missing debugging information.

可以看到,该地址是 ld-linux-x86-64.so 加载的位置。说明执行的是链接器的代码。

1
2
3
4
5
6
7
8
9
10
1: x/5i $pc
=> 0x7ffff7de64a0 <_dl_runtime_resolve_xsavec>: endbr64
0x7ffff7de64a4 <_dl_runtime_resolve_xsavec+4>: push %rbx
0x7ffff7de64a5 <_dl_runtime_resolve_xsavec+5>: mov %rsp,%rbx
0x7ffff7de64a8 <_dl_runtime_resolve_xsavec+8>: and $0xffffffffffffffc0,%rsp
0x7ffff7de64ac <_dl_runtime_resolve_xsavec+12>:
sub 0x21616d(%rip),%rsp # 0x7ffff7ffc620 <_rtld_local_ro+384>
@(gdb) bt
#0 0x00007ffff7de64a0 in _dl_runtime_resolve_xsavec () from /lib64/ld-linux-x86-64.so.2
#1 0x00000000004005e4 in main () at plt.c:3

可以看到这里代码执行的是 ld 中的 _dl_runtime_resolve_xsavec 函数是第一次函数调用时用于查找函数符号的,并且在结尾处会直接去调用找到的函数符号(本文中为 puts 函数)。

  1. 写回 .got.plt 表
    在 puts 上打个断点,这样继续的话就是执行完 _dl_runtime_resolve_xsavec 还未执行 puts 的状态了。
1
2
3
4
5
6
7
@(gdb) bt
#0 0x00007ffff7a7d8c0 in puts () from /lib64/libc.so.6
#1 0x00000000004005e4 in main () at plt.c:3
@(gdb) x/10x 0x601018
0x601018 <puts@got.plt>: 0xf7a7d8c0 0x00007fff 0x004004e6 0x00000000
0x601028: 0x00000000 0x00000000 0x00000000 0x00000000
0x601038: 0x00000000 0x00000000

可以看到,此时,got.plt 表中的地址已经被写为 puts 函数实际的地址了(0x00007ffff7a7d8c0 在 0x00007ffff7a2cb90 ~ 0x00007ffff7b798ad 范围内,属于 /lib64/libc.so.6),这样下次调用 puts 就不用再次调用链接器了。

题外话

其实看一下 .got.plt 表的内容,会发现明明 puts 是第一个需要被链接的函数,为什么第一个却不是它呢?

1
2
3
4
5
6
@(gdb) x/10x 0x601000
0x601000: 0x0000000000600e10 0x00007ffff7ffe1d0
0x601010: 0x00007ffff7de64a0 0x00007ffff7a7d8c0
0x601020 <printf@got.plt>: 0x00000000004004e6 0x0000000000000000
0x601030: 0x0000000000000000 0x0000000000000000
0x601040: 0x0000000000000000 0x0000000000000000

puts 地址实际上是 got[3]:0x00007ffff7a7d8c0,前面还有 3 项。其中:

  • got[0]:0x0000000000600e10 自身模块 dynamic 段地址
1
2
@(gdb) info symbol 0x0000000000600e10
_DYNAMIC in section .dynamic of /root/xxx/plt
  • got[1]:0x00007ffff7ffe1d0 本模块的 link_map 的地址。编译期间会初始化为 0。link_map 是一个双向链表的入口,链接进程所有加载的动态库。当链接器查找符号时,通过遍历该链表找到对应的符号。

  • got[2]:0x00007ffff7de64a0 _dl_runtime_resolve_xsavec 的地址。

1
2
@(gdb) info symbol 0x00007ffff7de64a0
_dl_runtime_resolve_xsavec in section .text of /lib64/ld-linux-x86-64.so.2

_dl_runtime_resolve 格式:

1
2
3
4
//调用形式为:
_dl_runtime_resolve((link_map*)(got[1]), 0);
// 第二个参数 0,为 <puts@plt>:中的 pushq $0x0;
// 同理如果是 printf,就是<printf@plt>:中 pushq $0x1;

总结

虚拟地址空间内流程图:
1

第二次调用:
2

为什么要使用内嵌汇编?

内嵌汇编通常用于在程序中实现一些高效、精确的操作。例如,在嵌入式平台上运行的程序,如果需要代码占用内存更小、程序运行的效率更高或需要准确地操作寄存器时,嵌入汇编会是不错的选择。

基本语法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
asm("assembly code"        /* 汇编代码 */
:output_operand /* 输出参数列表 */
:input_operand /* 输入参数列表 */
:clobbered_operand /* 被改变的操作对象列表 */
);

// 举例
static int value_assignment(int input) {
int ret = 0;
asm volatile(
"movl %1, %0\n" // 超过一条指令就要用 \n 来分割,排版整齐还要加 \t
:"=r"(ret)
:"r"(input)
);
return ret;
}

被改变的操作对象列表

在被改变的参数列表 clobbered_operand 中有一个比较有用的标识符:memory。指定 memory,相当于对编译器形成了一个内存读写的屏障,保证在内联汇编执行前,编译器将某些寄存器里的值刷新进内存,同时在内联汇编执行后,编译器重新加载相关变量的值
所以我们可以见到这样的代码:

1
asm volatile ("" ::: "memory");

作为内存屏障,保证编译器的优化不会跨过这道屏障。加上 volatile 告诉编译器不要优化汇编。

修饰符

修饰符一般跟在参数列表前面。

修饰符 含义
= 只写,常用于修饰所有输出操作数
只读
+ 可读可写
r 可以是任意通用寄存器存储其值
m 一个有效的内存地址
i 是立即数
% 被修饰的操作数可以和下一个互换
& 只能做输出,一般和 “=” 一起使用,如 “=&r(val)”
x 只能做输入

占位符

%0 表示输入和输出列表合并的第 1 个操作数,%1 表示第 2 个,以此类推。

硬件结构

现代机器都是多个处理器,每个处理器有自己的 cache。这个结构如下所示:
1
可以看到,每个 CPU 都有自己的缓存,之后再写到内存中。并且由于编译器的优化,你写的代码可能和你执行的代码顺序有所不同。他们优化的规则是:保证对于一个单核情况下,执行结果不会发生变化。但是多线程就不一定了。

那么在多线程情况下,如何协调这些 CPU 缓存的数据一致性就成了一个问题。

常见优化

再谈保证数据的一致性之前,先谈谈编译器能做的优化。

重排 Reordering

编译器和 CPU 都会发生重排,为了提升代码的效率。采用乱序执行、流水线、分支预测以及多级缓存等方法来提升程序性能。编译器会基于这些规则来提升自己代码的速度,所以就会对指令进行优化。例子如下:

1
2
3
4
5
6
7
8
9
10
11
12
int a = 0;
int b = 0;

void fun() {
a = b + 1; // L5
b = 1; // L6
}

int main() {
fun();
return 0;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
@└────> # gcc 1.c -O0 -g
@└────> # objdump -d a.out
0000000000400536 <fun>:
400536: 55 push %rbp
400537: 48 89 e5 mov %rsp,%rbp
40053a: 8b 05 e4 0a 20 00 mov 0x200ae4(%rip),%eax # 601024 <b>
400540: 83 c0 01 add $0x1,%eax
400543: 89 05 d7 0a 20 00 mov %eax,0x200ad7(%rip) # 601020 <__TMC_END__>
400549: c7 05 d1 0a 20 00 01 movl $0x1,0x200ad1(%rip) # 601024 <b>
400550: 00 00 00
400553: 90 nop
400554: 5d pop %rbp
400555: c3 retq

@└────> # gcc 1.c -O2 -g
@└────> # objdump -d a.out
0000000000400560 <fun>:
400560: 8b 05 ba 0a 20 00 mov 0x200aba(%rip),%eax # 601020 <__TMC_END__>
400566: c7 05 b0 0a 20 00 01 movl $0x1,0x200ab0(%rip) # 601020 <__TMC_END__>
40056d: 00 00 00
400570: 83 c0 01 add $0x1,%eax
400573: 89 05 ab 0a 20 00 mov %eax,0x200aab(%rip) # 601024 <a>
400579: c3 retq
40057a: 66 0f 1f 44 00 00 nopw 0x0(%rax,%rax,1)
  1. 对于 O0 等级的优化,执行顺序是 L5->L6。
  2. 但是对于 O2 等级的优化,执行顺序是 L6->L5,但是结果是不影响的。

为什么要这么做呢?因为 CPU 读取数据从 cache 中读取。如果不优化的话,先读 b,再读 a 的时候可能把 b 的缓存换出去了,那么再写 b 的时候还需要把 b 换进来。但是如果优化了,就是读 b,写 b,再写 a,就不存在缓存的换入换出了。

插入 Invention

假设有如下代码:

1
2
3
for (int i = 0; i < n; ++i) {
x[i] = y[i] + z[i];
}

可能优化成如下:

1
2
3
4
5
for (int i = 0; i < n; ++i) {
__builtin_prefetch(&y[i + 16]);
__builtin_prefetch(&z[i + 16]);
x[i] = y[i] + z[i];
}

预读取这些数据来减少缓存未命中次数。

删除 Removal

删除很好理解了,删除没用的变量赋值。

1
2
3
4
5
int x = 1;
int y = 2;
int z = x + y;
x = 3;
y = 4;

优化后:

1
2
3
4
int x;
int y;
x = 3;
y = 4;

关系术语

sequence-before

sequence-before 是对一个线程内,求值顺序关系的描述:

  • A sequence-before B,先对 A 求值,再对 B 求值。
  • A not sequence-before B,并且 B not sequence-before A,那么 A 和 B 谁先求值是未知的。

synchronizes-with

描述的是不同线程内的执行关系。在两个线程分别执行时,即使线程 A 先执行,线程 B 后执行,A 中写了某个共享变量,由于指令重排或者写到了 cache寄存器没来得及写入内存导致 B 读到了错误的值。

  • A synchronizes-with B,在线程 A 中的写操作结果对线程 B 可见。

happens-before

是 sequence-before 的扩展,包括了不同线程的关系。

  • A happens-before B,那么不但 A 先于 B 执行,并且 A 的结果对 B 可见。
    • 同线程:和 sequence-before 一样。
    • 不同线程:和 synchronizes-with 一样。

内存序

C++11 中引入了 6 种内存序:

1
2
3
4
5
6
7
8
typedef enum memory_order {
memory_order_relaxed,
memory_order_consume,
memory_order_acquire,
memory_order_release,
memory_order_acq_rel,
memory_order_seq_cst
} memory_order;
内存序类型 用于读/写 含义
memory_order_relaxed 读/写 仅要求原子性内存一致性
memory_order_consume 读操作所在线程该操作后面的和该变量 有依赖关系的 读写操作不会被优化到先于该操作执行
memory_order_acquire 读操作所在线程该操作后面的读写操作不会被优化到先于该操作执行
memory_order_release 写操作所在线程该操作前面的读写操作不会被优化到后于该操作执行
memory_order_acq_rel 读/写 是 memory_order_acquire 和 memory_order_release 组成的双向屏障,上下皆不能跨过该指令
memory_order_seq_cst 读/写 双向屏障,并且该线程所有原子指令并且也指定为 memory_order_seq_cst 的都已全局内存修改顺序为参照

值得一提的是,若一个原子变量在一个线程中施加了 memory_order_release,但是在其他线程中没有使用 memory_order_acquire 或 memory_order_consume 读取,那么他就不会具备 memory_order_release 所赋予的屏障功能。(即只有被观测才会起作用,读操作也是如此)

2
如上图所示,就像加锁一样会构成临界区。但是外面的变量可以移入临界区,却不能移出去,所以称 memory_order_acquire 和 memory_order_release 如同单向屏障一般。

内存模型

一言以蔽之,引入内存模型的原因,有以下几个原因:

  1. 编译器优化:在某些情况下,即使是简单的语句,也不能保证是原子操作。
  2. CPU out-of-order:CPU 为了提升计算性能,可能会调整指令的执行顺序。
  3. CPU Cache 不一致:在 CPU Cache 的影响下,在某个 CPU 下执行了指令,不会立即被其它 CPU 所看到。

从上面的内存序中,按照访问控制的角度可以分为三种模型:

  1. Sequential Consistency 模型
  2. Acquire-Release 模型
  3. Relax 模型

其中,Sequential Consistency 模型约束最强,Acquire-Release 次之,Relax 模型最弱。

Sequential Consistency 模型

对应 memory_order_seq_cst 内存序。Sequential Consistency 模型有以下特点:

  • 每个线程的执行顺序与代码顺序严格一致
  • 线程的执行顺序可能会交替进行,但是从单个线程的角度来看,仍然是顺序执行

例如:

1
2
3
4
5
6
7
8
9
x = y = 0;

thread1:
x = 1;
r1 = y;

thread2:
y = 1;
r2 = x;

那么可能的执行顺序为:

可能性 第一步 第二步 第三步 第四步
1 x = 1 r1 = y y = 1 r2 = x
2 y = 1 r2 = x x = 1 r1 = y
3 x = 1 y = 1 r1 = y r2 = x
4 x = 1 r2 = x y = 1 r1 = y
5 y = 1 x = 1 r1 = y r2 = x
6 y = 1 x = 1 r2 = x r1 = y

std::atomic 默认值都是使用 memory_order_seq_cst,保证不出错。但是相对的,限制了 CPU 并行处理的能力,会降低效率。这个模型的所有线程都参考全局的内存修改顺序。因此,我们可认为所有变量的读写都直接从内存进行,从而完全不用考虑 Cache,Store Buffer 这些因素。

Acquire-Release 模型

对应 memory_order_consume、memory_order_acquire、memory_order_release、memory_order_acq_rel 内存序。对于一个原子变量 A,对 A 的写操作(Release)和读操作(Acquire)之间进行同步,并建立排序约束关系,即对于写操作(release)X,在写操作 X 之前的所有读写指令都不能放到写操作 X 之后;对于读操作(acquire)Y,在读操作 Y 之后的所有读写指令都不能放到读操作 Y 之前。

Relax 模型

对应的是 memory_order_relaxed 内存序。其对于内存序的限制最小,也就是说这种方式只能保证当前的数据访问是原子操作(不会被其他线程的操作打断),但是对内存访问顺序没有任何约束,也就是说对不同的数据的读写可能会被重新排序。

本文用以记录常用汇编指令以供快速查找回忆,仅限于 X86_64 的 AT&T 格式。

语法格式

1. 引用寄存器前加 %。如

1
mov    %rsp, %rbp

2. 指令长度后缀

对于访问内存的数据,指令后加上 b w l q,操作 1 2 4 8 字节。如

1
2
3
4
movb   $0x1,0x201c3f(%rip)
nopw %cs:0x0(%rax,%rax,1)
movl $0x5,-0xc(%rbp)
movq $0x400b30,-0x18(%rbp)

3. 立即数前加 $。16 进制数用 0x 开头。如

1
2
movl   $1, %eax
mov $0x0,%eax

4. 注释可以用 ! 开头,也可以用 ;

5. 操作数顺序

从源操作数到目的操作数,如下将 %rsp 寄存器中的数传给 %rbp 寄存器。

1
mov    %rsp,%rbp

6. 数据声明

命令 数据类型
.ascii 文本字符串
.asciz 以空字符串结尾的文本字符串
.byte 字节值
.double 双精度浮点数
.float 单精度浮点数
.single 单精度浮点数同上
.int 32位整数
.long 32位整数同上
.octa 16字节整数
.quad 8字节整数
.short 16位整数
.comm 声明未初始化的数据的通用内存区域
.lcomm 声明未初始化的数据的本地通用内存区域

7. 文件组成

命令 作用
.org 定义当前汇编位置
.globl 让段全局可见
.text 存放代码指令正文段
.bss 存放未初始化的全局和静态变量,运行时该区域初始化为 0
.rodata read only data
.data 可读可写的数据段

8. 寻址方式

  • 直接寻址:把某个地址上的值放到寄存器中
1
mov    $0x8000,%eax
  • 间址寻址:把寄存器上的值所代表的地址所指向的值放到寄存器中
1
2
movl   $0x8000,%ebx  
movl (%ebx),%eax ; 间址寻址, 把地址 0x8000(在寄存器 %ebx 中)上的值放到 %eax 中
  • 基址寻址:以寄存器里的数值作为基址,加上一个常数得到最终地址,把地址上的值放到寄存器中
1
2
movl   $0x8000,%eax  
movl 4(%eax),%ebx ; 基址寻址, 把地址 0x8004(0x8000+4)上的值放到 %eax 中
  • 变址寻址:以两个寄存器里的数值之和加上一个常数得到最终地址,把地址上的值放到寄存器中
1
2
3
4
movl   $0x8000,%eax
movl $0x4,%ebx
movl (%eax,%ebx),%ecx ; 变址寻址, 把地址 0x8004(0x8000+4)上的值放到 %ecx 中
movl 4(%eax,%ebx),%ecx ; 变址寻址, 把地址 0x8008(0x8000+4+4)上的值放到 %ecx 中
  • 比例变址寻址:以一个寄存器里的数值加上另一个寄存器里的数字,乘以一个比例因子(1,2,4,8)再加上一个常数得到最终地址,把地址上的值放到寄存器中
1
2
3
4
5
6
movl   $0x2000,%eax   
movl $0x2,%ebx
movl (,%eax,4),%ecx ; 比例变址寻址, 把地址 0x8000(0 + 0x2000*4)上的值放到 %ecx 中
movl 6(,%eax,4), %ecx ; 比例变址寻址, 把地址 0x8006(0 + 0x2000*4 + 6)上的值放到 %ecx 中
movl (%ebx,%eax,4),%ecx ; 比例变址寻址, 把地址 0x8002(0x2 + 0x2000*4)上的值放到 %ecx 中
movl 6(%ebx,%eax,4),%ecx ; 比例变址寻址, 把地址 0x8008(0x2 + 0x2000*4 + 6)上的值放到 %ecx 中

常见指令

1. mov 用于将源操作数移动到目的操作数

1
mov    %rsp,%rbp      ; %rbp = %rsp

2. add 用于将源操作数加给目的操作数

1
addl   %eax,%ebx      ; %ebx = %ebx + %eax

3. sub 用于将两个数相减

1
subl   %eax,%ebx      ; %ebx = %ebx - %eax

4. inc 用于加一

1
incl   %eax           ; %eax = %eax + 1

5. dec 用于减一

1
decl   %eax           ; %eax = %eax - 1

6. push 用于将数据压入栈

1
pushl  %eax           ; 入栈,%esp = %esp - 0x4, %esp = %eax 

7. pop 用于将数据出栈

1
popl   %eax           ; 出栈,%eax = %esp, %esp = %esp + 0x4

8. jmp 跳转

1
2
3
4
5
6
7
8
9
10
11
jmp    label          ; 无条件跳转为 label, %rip = label
je label ; 相等 ZF = 1, %rip = label
jne label ; 不相等 ZF = 0, %rip = label
jg label ; 大于 %rip = label
jge label ; 大于等于 %rip = label
jl label ; 小于 %rip = label
jle label ; 小于等于 %rip = label
ja label ; 无符号比较 大于 %rip = label
jae label ; 无符号比较 大于等于 %rip = label
jb label ; 无符号比较 小于 %rip = label
jbe label ; 无符号比较 小于等于 %rip = label

9. mul 乘法

1
2
imull  %eax,%ebx      ; %ebx = %eax * %ebx  用于有符号数
mull %eax,%ebx ; %ebx = %eax * %ebx 用于无符号数

10. div 除法

1
2
idivl  %ebx           ; %edx = %eax % %ebx, %eax = %eax / %ebx  用于有符号数
divl %ebx ; %edx = %eax % %ebx, %eax = %eax / %ebx 用于无符号数

11. and 按位与

1
andl   %eax,%ebx      ; %ebx = %ebx & %eax

12. or 按位或

1
orl    %eax,%ebx      ; %ebx = %ebx | %eax

13. xor 按位异位

1
xorl   %eax,%ebx      ; %ebx = %eax ^ %ebx

14. shl 和 sal 位左移

1
2
shll   $1,%eax        ; %eax = %eax << 1  逻辑左移,填充 0
sall $1,%eax ; %eax = %eax << 1 算数左移,填充 0

15. shr 和 sar 位右移

1
2
shrl   $1,%eax        ; %eax = %eax >> 1  逻辑右移,填充 0
sarl $1,%eax ; %eax = %eax >> 1 算数右移,填充 符号位

16. lea 装载有效地址

1
leal   8(%ebx),%eax   ; %eax = 8 + %ebx 可理解为 %eax = &(*(%ebx)) + 8

17. call 函数调用

1
call   func_name      ; 将下一条指令的 %rip push 到栈中,之后 %rip = func_name 

18. ret 函数返回

1
ret                   ; 将函数返回地址的下一条要执行指令的值赋值给 %rip,push %rip

19. test 与运算并设置标志寄存器

1
testl  %eax,%ebx      ; %eax & %ebx,不会改变这两个寄存器值,改变标志寄存器零标志位(ZF)、符号标志位(SF)、奇偶标志位(PF)和进位标志位(CF),但不会影响溢出标志位(OF)

20. cmp 比较操作数大小

1
cmpl   %eax,%ebx      ; 根据 %ebx - %eax 的值来改变零标志位(ZF)、符号标志位(SF)、奇偶标志位(PF)、进位标志位(CF)和溢出标志位(OF)

21. rep 重复执行指令直到某一条件

1
2
repz   movsb          ; 重复执行 movsb 直到 ZF = 0
repne scasb ; 重复执行 scasb 直到 ZF = 1

22. lock 锁定总线

1
lock addl $1,(%eax)   ; 锁定总线,并使 *(%eax) = *(%eax) + 1,因为总线是锁定的,不会被其他处理器打断

23. xadd 交换两个操作数值,使他们相加

1
xaddl  %eax,%ebx      ; tmp = %eax,%eax = %ebx,%ebx = tmp + %ebx 交换两个数,并将和写到 %ebx

24. nop 空操作

1
nop                   ; 什么都不做,充当占位符或者插入延迟

25. hlt 使处理器暂停直到收到中断信号

1
hlt                   ; 使处理器进入暂停状态,直到发生外部中断。它通常用于操作系统内核中,以降低功耗和发热量。只有特权级别为 0 (内核态) 才能使用,否则会导致异常

26. xchg 交换两个操作数的值

1
xchgl  %eax,%ebx       ; tmp = %eax,%eax = %ebx,%ebx = tmp

27. cld 清除方向寄存器(DF)

1
cld                    ; 清除方向寄存器,使 %rdi 递增

28. movsb 移动字符串

1
movsb                  ; 以 %rsi 为源地址,%rdi 为目的地址,将字符以一个字节拷贝。每次执行 movsb,%rsi 和 %rdi 以方向标志寄存器(DF)自动递增或递减

29. scasb 查找字符

1
scasb                  ; 将被查找字符放到 %al 中,与 %rdi 地址的字符串依次比较,根据比较结果设置标志寄存器

30. cli 禁用所有中断

1
2
cli                    ; 禁用所用中断
hlt ; 使处理器保持暂停状态,直到中断被重新启用

问题

今天调试代码的时候看到地址的时候突然感到奇怪:我记得我之前看到的代码地址空间好多都是 0x400xxx 开头的,怎么这次的地址空间是 0x5562b845axxx 呢?是什么导致了这个差异?

我换了地址空间为 0x400xxx 开头的机器,准备了相同的代码,在两台不同的机器上编译:

1
2
3
4
#include <stdio.h>
int main() {
printf("%p\n", main);
}

这个简单的程序可以打出 main 函数的地址。经测试,在不同的机上打出的结果有很大差异。

1
2
3
4
5
@└────> # ./a.out 
0x5562b845a649

@└────> # ./b.out
0x400596

答案

经查阅资料,这个问题是 Linux 的 ASLR (Address Space Layout Randomization)导致的。这项技术会在装载时,装载到随机地址,防止黑客利用固定地址注入恶意代码。对于 b.out,没有使用该技术。所以 b.out 的代码段虚拟地址一直是 0x400000 开头。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
@└────> # readelf -h b.out 
ELF Header:
Magic: 7f 45 4c 46 02 01 01 00 00 00 00 00 00 00 00 00
Class: ELF64
Data: 2's complement, little endian
Version: 1 (current)
OS/ABI: UNIX - System V
ABI Version: 0
Type: EXEC (Executable file) // 这里是 EXEC
Machine: Advanced Micro Devices X86-64
Version: 0x1
Entry point address: 0x4004b0 // 这里是 _start 的绝对地址
Start of program headers: 64 (bytes into file)
Start of section headers: 15608 (bytes into file)
Flags: 0x0
Size of this header: 64 (bytes)
Size of program headers: 56 (bytes)
Number of program headers: 9
Size of section headers: 64 (bytes)
Number of section headers: 30
Section header string table index: 29

可以看到,对于 b.out,他的文件类型是 Executable file,_start 的地址是 0x400xxx 开头。这种就是没有使用 ASLR 技术的。而对于 a.out,结果如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
@└────> # readelf -h a.out 
ELF Header:
Magic: 7f 45 4c 46 02 01 01 00 00 00 00 00 00 00 00 00
Class: ELF64
Data: 2's complement, little endian
Version: 1 (current)
OS/ABI: UNIX - System V
ABI Version: 0
Type: DYN (Shared object file) // 这里是 DYN
Machine: Advanced Micro Devices X86-64
Version: 0x1
Entry point address: 0x560 // 这里是 _start 的相对地址
Start of program headers: 64 (bytes into file)
Start of section headers: 12744 (bytes into file)
Flags: 0x0
Size of this header: 64 (bytes)
Size of program headers: 56 (bytes)
Number of program headers: 9
Size of section headers: 64 (bytes)
Number of section headers: 31
Section header string table index: 30

对于 a.out,文件类型为 Shared object file,而且 _start 的地址是个相对地址。就是这个导致的这个差异。每次装载 a.out 时,代码会被加载到随机的位置。可以看到,每次运行,得到的地址都不同。

1
2
3
4
5
6
@└────> # ./a.out 
0x559536d9d649
@└────> # ./a.out
0x559a7a6df649
@└────> # ./a.out
0x55ca5dbd4649

发生根因

之所以发生这个原因,是因为操作系统版本导致的。低版本操作系统默认不使用 ASLR。想要在不同的操作系统上复现这两个方式也很简单:

1
@└────> # gcc 1.c -fPIC -pie

这种方式编译出来的就是使用了 ASLR 技术的。其中 -pie 的意思是 position-independent executable,位置无关的可执行文件。编译时还需要加上 -fPIC (Position-Independent Code)生成位置无关代码。而

1
@└────> # gcc 1.c -no-pie

方式编出来的就是固定地址。有些工具必须使用 -no-pie 才可以使用。这样固定的情况也比较好调试,因为虚拟地址固定。

Linux 中常用的文件描述符

  • 0 文件描述符,表示标准输入。
  • 1 文件描述符,表示标准输出。
  • 2 文件描述符,表示标准错误。

标准情况下,这些文件描述符和以下设备关联:

  • 0 文件描述符关联键盘,并返回给前端。
  • 1 正确返回值,返回给前端。
  • 2 错误返回值,返回给前端。

> 符号

在 shell 中,我们经常使用 > 符号,把输出重定位到一个文件。例如:

1
cat /proc/xxx/maps > memory.txt

以上输出是把某个进程的内存布局重定向到一个文件。其中,> 是 1> 的简写,实际意思是把标准输出重定向到后面的文件。这样屏幕上就不会有打印了,打印会重定向到文件中。

>& 符号

本质上,>& 符号不是一个符号。我们经常见到 2>&1 符号,实际意义是,将标准错误重新定位到标准输出。那为什么要加个 & 呢?因为不加 & 的话操作系统不会认为你是想把标准错误重定位给标准输出,而是想重定向到一个叫 “1” 的文件。所以 &1 表示 1 输出通道。举例,strace 命令可以查看系统调用,这个结果是输出到标准错误的。

1
strace ls > log 2>&1

将标准输出重定向到 log 文件,并将标准错误重定向到标准输出。这样标准错误也会被重定向到 log 文件。

&> 符号

&> 意思是把标准错误和标准输出都重定向到某个文件。

1
strace ls &> log

写起来比较简单,且省力。