Skip to content

Commit fadbb66

Browse files
committed
ADD entry what is Grouped-Query Attention
1 parent 50badd4 commit fadbb66

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed
Loading
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# 什么是 Grouped-Query Attention?
2+
---------------------------
3+
4+
- by @karminski-牙医
5+
6+
![gqa](./assets/images/grouped-query-attention.png)
7+
8+
Grouped-Query Attention(分组查询注意力)是 Transformer 架构的改进型注意力机制,在多头注意力(MHA)和多查询注意力(MQA)之间取得平衡。通过分组共享键值投影,在保持模型容量的同时显著降低计算资源消耗。
9+
10+
### 工作原理
11+
12+
给定输入向量 $Q$(查询)、$K$(键)和 $V$(值),GQA 将查询头分组处理:
13+
14+
$$
15+
\text{GroupedQuery}(Q, K, V) = \text{Concat}(\text{group}_1, \ldots, \text{group}_g)W^O
16+
$$
17+
18+
每个组内共享键值投影:
19+
20+
$$
21+
\text{group}_i = \text{Attention}(QW_i^Q, KW_{\lfloor i/m \rfloor}^K, VW_{\lfloor i/m \rfloor}^V)
22+
$$
23+
24+
其中:
25+
- $g$ 为分组数(通常 $g \ll h$)
26+
- $m = h/g$ 每组包含的头数
27+
- $W_i^Q \in \mathbb{R}^{d_{model}\times d_k}$ 为每组独立的查询投影
28+
- $W_j^K, W_j^V \in \mathbb{R}^{d_{model}\times d_k}$ 为组间共享的键值投影
29+
30+
### 核心机制
31+
32+
- **动态分组策略**
33+
- 将 $h$ 个查询头划分为 $g$ 个组
34+
- 每组包含 $m=h/g$ 个查询头共享同一组键值投影
35+
- 通过线性投影实现特征空间的分组耦合
36+
37+
- **参数效率**
38+
总参数量为:
39+
$$
40+
\underbrace{hd_kd_{model}}_{\text{查询投影}} + \underbrace{2gd_kd_{model}}_{\text{键值投影}} = (h + 2g)d_kd_{model}
41+
$$
42+
相比 MHA 减少 $3hd_kd_{model} - (h+2g)d_kd_{model}$ 参数
43+
44+
### 优点
45+
46+
- **显存优化**:键值缓存显存占用降至 MHA 的 $g/h$,例如 8 头分组为 2 组时显存减少 75%
47+
- **质量保留**:PaLM 2 实验显示,GQA(g=8)与 MHA 相比在质量指标上差异小于 0.5%
48+
- **灵活扩展**:通过调整分组数 $g$ 实现质量与效率的连续调节:
49+
- $g=h$ 时退化为标准 MHA
50+
- $g=1$ 时等价于 MQA
51+
52+
### 缺点
53+
54+
- **分组调优成本**:需要实验确定最佳分组数,不同任务/架构可能有不同最优配置
55+
- **投影偏差风险**:共享键值投影可能限制不同组的特征多样性
56+
- **实现复杂度**:需要管理分组投影的矩阵运算,可能引入额外的张量变换开销
57+
58+
### 性能对比
59+
60+
GQA 论文中使用的T5 Large 和 XXL 模型在多头注意力、5% 训练的 T5-XXL 模型在多查询和分组查询注意力下,在摘要数据集 CNN/Daily Mail、arXiv、PubMed、MediaSum 和 MultiNews,翻译数据集 WMT,以及问答数据集 TriviaQA 上的推理时间和平均开发集性能比较。
61+
62+
| 模型 | 推理时间 (s) | 平均 | CNN | arXiv | PubMed | MediaSum | MultiNews | WMT | TriviaQA |
63+
|--------------|--------------|------|------|-------|--------|----------|-----------|------|----------|
64+
| MHA-Large | 0.37 | 46.0 | 42.9 | 44.6 | 46.2 | 35.5 | 46.6 | 27.7 | 78.2 |
65+
| MHA-XXL | 1.51 | 47.2 | 43.8 | 45.6 | 47.5 | 36.4 | 46.9 | 28.4 | 81.9 |
66+
| MQA-XXL | 0.24 | 46.6 | 43.0 | 45.0 | 46.9 | 36.1 | 46.5 | 28.5 | 81.3 |
67+
| GQA-8-XXL | 0.28 | 47.1 | 43.5 | 45.4 | 47.7 | 36.3 | 47.2 | 28.4 | 81.6 |
68+
69+
70+
## Refs
71+
[GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints](https://arxiv.org/pdf/2305.13245)

0 commit comments

Comments
 (0)