大模型中常用的注意力机制GQA详解以及Pytorch代码实现
liuian 2025-05-08 19:41 49 浏览
分组查询注意力 (Grouped Query Attention) 是一种在大型语言模型中的多查询注意力 (MQA) 和多头注意力 (MHA) 之间进行插值的方法,它的目标是在保持 MQA 速度的同时实现 MHA 的质量。
这篇文章中,我们将解释GQA的思想以及如何将其转化为代码。
GQA是在论文 GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints paper.中提出,这是一个相当简单和干净的想法,并且建立在多头注意力之上。
GQA
标准多头注意层(MHA)由H个查询头、键头和值头组成。每个头都有D个维度。Pytorch的代码如下:
from torch.nn.functional import scaled_dot_product_attention
# shapes: (batch_size, seq_len, num_heads, head_dim)
query = torch.randn(1, 256, 8, 64)
key = torch.randn(1, 256, 8, 64)
value = torch.randn(1, 256, 8, 64)
output = scaled_dot_product_attention(query, key, value)
print(output.shape) # torch.Size([1, 256, 8, 64])对于每个查询头,都有一个对应的键。这个过程如下图所示:
而GQA将查询头分成G组,每组共享一个键和值。可以表示为:
使用可视化的表示就能非常清楚的了解GQA的工作原理,就像我们上面说的那样,GQA是一个相当简单和干净的想法
Pytorch代码实现
让我们编写代码将这种将查询头划分为G组,每个组共享一个键和值。我们可以使用einops库有效地执行对张量的复杂操作。
首先,定义查询、键和值。然后设置注意力头的数量,数量是随意的,但是要保证num_heads_for_query % num_heads_for_key = 0,也就是说要能够整除。我们的定义如下:
import torch
# shapes: (batch_size, seq_len, num_heads, head_dim)
query = torch.randn(1, 256, 8, 64)
key = torch.randn(1, 256, 2, 64)
value = torch.randn(1, 256, 2, 64)
num_head_groups = query.shape[2] // key.shape[2]
print(num_head_groups) # each group is of size 4 since there are 2 kv_heads为了提高效率,交换seq_len和num_heads维度,einops可以像下面这样简单地完成:
from einops import rearrange
query = rearrange(query, "b n h d -> b h n d")
key = rearrange(key, "b s h d -> b h s d")
value = rearrange(value, "b s h d -> b h s d")然后就是需要在查询矩阵中引入”分组“的概念。
from einops import rearrange
query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)
print(query.shape) # torch.Size([1, 4, 2, 256, 64])上面的代码我们将二维重塑为二维:对于我们定义的张量,原始维度8(查询的头数)现在被分成两组(以匹配键和值中的头数),每组大小为4。
最后最难的部分是计算注意力的分数。但其实它可以在一行中通过insum操作完成的
from einops import einsum, rearrange
# g stands for the number of groups
# h stands for the hidden dim
# n and s are equal and stands for sequence length
scores = einsum(query, key, "b g h n d, b h s d -> b h n s")
print(scores.shape) # torch.Size([1, 2, 256, 256])scores张量和上面的value张量的形状是一样的。我们看看到底是怎么操作的
einsum帮我们做了两件事:
1、一个查询和键的矩阵乘法。在我们的例子中,这些张量的形状是(1,4,2,256,64)和(1,2,256,64),所以沿着最后两个维度的矩阵乘法得到(1,4,2,256,256)。
2、对第二个维度(维度g)上的元素求和——如果在指定的输出形状中省略了维度,einsum将自动完成这项工作,这样的求和是用来匹配键和值中的头的数量。
最后是注意分数与值的标准乘法:
import torch.nn.functional as F
scale = query.size(-1) ** 0.5
attention = F.softmax(similarity / scale, dim=-1)
# here we do just a standard matrix multiplication
out = einsum(attention, value, "b h n s, b h s d -> b h n d")
# finally, just reshape back to the (batch_size, seq_len, num_kv_heads, hidden_dim)
out = rearrange(out, "b h n d -> b n h d")
print(out.shape) # torch.Size([1, 256, 2, 64])这样最简单的GQA实现就完成了,只需要不到16行python代码:
最后再简单提一句MQA:多查询注意(MQA)是另一种简化MHA的流行方法。所有查询将共享相同的键和值。原理图如下:
可以看到,MQA和MHA都可以从GQA推导出来。具有单个键和值的GQA相当于MQA,而具有与头数量相等的组的GQA相当于MHA。
GQA的好处是什么?
GQA是最佳性能(MQA)和最佳模型质量(MHA)之间的一个很好的权衡。
下图显示,使用GQA,可以获得与MHA几乎相同的模型质量,同时将处理时间提高3倍,达到MQA的性能。这对于高负载系统来说可能是必不可少的。
作者:Max Shap
相关推荐
- usb打印机改wifi打印机(usb打印机改无线网络打印机)
-
首先要把打印机通过USB端口连接到路由器上,连接成功后路由器上的USB指示灯会亮。然后在需要使用网络打印机的电脑上安装打印机的驱动程序,这样才能够正常使用打印服务器连接的打印机。登录路由器,在左侧的系...
- windows7没pdf打印机(win7系统自带的打印pdf找不到了)
-
建议安装Acrobat9,并安装9.1.3的AdobeReader/Acrobat的更新,去官网搜索即可,如果现有版本是9.1.0,则9.1.2和9.1.3的更新均需要安装.我实验的结果时9.0...
- 有两台iphone一台忘记密码(有两台iphone一台忘记锁屏密码)
-
iphone的锁屏密码输入错误次数过多,显示iphone已停用。解决办法:第一步:电脑上装好iTunes,并打开。第二步:关手机,插上数据线,注意只插手机这一端,先不接电脑。第三步:按住手机上的Hom...
- 快用苹果助手官网进不去(快用苹果助手怎么下载不了)
-
要在指定的网址上登录下载,苹果手机没有自动授信不能下载
- 复制快捷键ctrl+c(复制快捷键ctrl+c还有什么)
-
ctrl+c:复制;ctrl+v:粘贴,其他快捷键如下:Ctrl+Z撤消操作Ctrl+Y:恢复操作Delete(或Ctrl+D):删除所选的项目,将其移至回收站Shift+Delet...
- 校园网wifi免认证软件(校园网统一身份认证平台)
-
这个不存在犯法不犯法的问题,也就是说学校的网络是给你便捷使用的,反正都是给你使用的,你如何登录都没有任何的关系,其次就是你自己办的网的话,你有权利随意的更改,没办网的话那你就用学校的。1这是不道德和...
- 如何查看windows激活密钥(查看windows激活密钥命令)
-
可以按照以下步骤查看Windows系统的激活密钥:1.首先打开命令提示符,可通过在搜索栏中输入"cmd",然后右键管理员身份打开。2.在打开的命令提示符窗口中输入指令:slmgr/d...
- dlink路由器(dlink路由器无法连接网络)
-
设置D-Link无线路由器无线桥接的具体步骤如下:1、将电脑与路由器的任意lan口连接,打开浏览器输入192.168.1.1,进入路由器管理页面。点击lan口设置,将lan口ip改为192.168.2...
- c5game开箱网(c5game开箱网是正规的吗)
-
苹果c5game开箱操作很简单,首先进入c5game网站,选择打开自己的背包,然后找到自己想要开箱的物品,点击开箱按钮即可。在开箱过程中,会弹出一个开箱界面,按照界面提示进行操作,等待开箱过程结束即可...
- ps5官网(playstation 官网)
-
在官网买ps5需要玩家收到预购邀请才可以。索尼决定遴选出一批忠实玩家,率先向其提供PS5实机预定服务,数量有限,先到先得。玩家只需在PlayStation.com网站完成注册手续。若有幸等到预购邀请电...
- 电脑上dat文件用什么打开(电脑上dat文件怎么打开)
-
、打开电脑,找到“我的电脑”然后再打开硬盘C就可以看到相应的dat文件。2、硬盘C里面可以找到很多的dat文件,只是他们的文件拓展名不一样。3、然后在我的电脑当中输入“dat”就会弹出许多与dat相关...
- 一周热门
- 最近发表
- 标签列表
-
- python判断字典是否为空 (50)
- crontab每周一执行 (48)
- aes和des区别 (43)
- bash脚本和shell脚本的区别 (35)
- canvas库 (33)
- dataframe筛选满足条件的行 (35)
- gitlab日志 (33)
- lua xpcall (36)
- blob转json (33)
- python判断是否在列表中 (34)
- python html转pdf (36)
- 安装指定版本npm (37)
- idea搜索jar包内容 (33)
- css鼠标悬停出现隐藏的文字 (34)
- linux nacos启动命令 (33)
- gitlab 日志 (36)
- adb pull (37)
- python判断元素在不在列表里 (34)
- python 字典删除元素 (34)
- vscode切换git分支 (35)
- python bytes转16进制 (35)
- grep前后几行 (34)
- hashmap转list (35)
- c++ 字符串查找 (35)
- mysql刷新权限 (34)
