从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

424次阅读
没有评论

西风 发自 凹非寺

量子位 | 公众号 QbitAI

让大神Andrej Karpathy一键三连❤️(点赞+转发+评论),一个教你从头开始实现Llama3的代码库爆火。

X上转赞收藏量超6.8k,GitHub揽星2k+。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

火就火在,它教你从头用Meta开源的权重进行推理,详细解释和展开了注意力机制中多个头的矩阵乘法、位置编码以及所有中间层

换句话说,他解释了每行代码都在干啥。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

Karpathy看后直呼打造者Nishant Aklecha(后文暂称“纳哥”)是个有品的人:

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

完全展开后,比起模块相互嵌套和调用时,更容易理解每一步具体在做什么。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

网友们对其也是赞不绝口,纷纷致敬:

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

话不多说,一起来看纳哥是如何手把手教的。

(量子位在不改变原意的基础上,进行了编译整理)

从头实现llama3

在运行纳哥提供的文件前,大伙儿需要预先下载Meta官方提供的Llama3模型权重。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

纳哥表示自己没搞分词器,推荐用Karpathy的现成简洁版BPE代码。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

PS:

“字节级(byte-level)”BPE算法,在UTF-8编码的字符串上运行,广泛应用于大模型分词。Karpathy提供的这个代码库包含两个分词器,都能在给定文本上训练分词器的词汇表和合并规则、将文本编码为token、将token解码为文本。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

读取模型文件的方式通常取决于model classes的编写方式以及class中变量的命名。但由于纳哥是从头开始实现Llama3,所以将逐个张量地读取文件内容。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

通过此配置可以推断出模型的结构和参数信息,例如模型包含的Transformer层数、多头注意力块中的头数,以及词汇表的大小等细节。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

将文本转换为token时,纳哥使用tiktoken作为分词器。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

接下来,纳哥展示了在代码中将token转换为高维的嵌入表示。这是代码库中唯一使用内置神经网络模块的部分。

[17×1]的token矩阵变成了[17×4096]的嵌入矩阵。也就是说,每个token被转换为一个长度为4096的嵌入向量,总共有17个这样的嵌入向量。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

然后,纳哥对嵌入进行RMS归一化。经过这一步后,嵌入的形状不会改变,只有数值被归一化了。纳哥强调需要一个norm_eps,避免意外将RMS值设为0导致除以0的错误。

以下是公式:

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

构建Transformer的第一层,进行归一化处理,从模型字典中访问layer.0(即第一层)。归一化之后,张量的形状仍然是[17×4096],与嵌入时相同,但数值已被归一化。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

跟着纳哥从头实现注意力机制,加载Transformer第一层的注意力头。

从模型中加载query、key、value和output向量时,它们的形状分别是 [4096×4096]、[1024×4096]、[1024×4096] 和 [4096×4096]。

纳哥表示乍一看有点奇怪,因为理想情况是每个注意力头的q、k、v和o向量是独立的。而代码作者将它们捆绑在一起,是为了方便并行计算注意力头的矩阵乘法。

把所有这些向量解包开来:

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

下一步,纳哥将从多个注意力头中解包query,解包后的形状是[32x128x4096],32是Llama3中的注意力头数量,128是query向量的大小,4096是token嵌入的大小。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

在这里,纳哥访问了第一层第一个注意力头的query权重矩阵,query权重矩阵的大小是[128×4096]。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

将query权重矩阵与token嵌入相乘,获得每个token的query向量。结果的形状为[17×128],有17个token,每个token对应一个长度为128的query向量。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

接下来需要位置编码。

现在已经为prompt中的每个token生成了query向量,但每个单独的query向量并不知道它在prompt中的具体位置。

例如,query:“the answer to the ultimate question of life, the universe, and everything is ”(生命、宇宙和一切的终极问题的答案是)。

在这个prompt中,使用了三次”the”,需要根据它们在prompt中的位置,使这三个”the”token的query向量有所不同(每个向量的大小为[1×128])。

通过使用RoPE(旋转位置嵌入)来进行这些旋转操作。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

上一步中,纳哥将query向量分成对,并对每一对应用一个旋转角度偏移。

由此,得到的向量大小为 [17x64x2],这是将长度为128的query向量对每个prompt中的token分成64对。这64对中的每一对都会根据m*(theta) 进行旋转,其中m是要旋转query的token的位置。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

使用复数的点积来旋转一个向量:

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

现在每个token的query元素都有一个复数(角度变化向量),可以将query向量(之前分成的对)转换为复数,然后通过点积根据位置旋转query向量。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

获得旋转后的向量后,可以通过将复数重新视为实数来得到成对的query向量。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

旋转后的对现在已经合并,有一个新的query向量(旋转后的query向量),其形状为[17×128],其中17是token的数量,128是query向量的维度。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

key与query几乎相同。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

纳哥表示自己不会详细讲解key的数学原理,只需要记住以下几点:

key生成的key向量维度也是128;key的权重只有query的四分之一,这是因为key的权重在同一时间内被4个头共享,来减少计算量;key也会旋转添加位置信息,原因与query相同。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

此时,纳哥已经为每个token获得了旋转后的query和key。每个query和key现在的形状都是[17×128]。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

下一步,纳哥将对query矩阵和key矩阵进行相乘操作。这样做会生成一个评分矩阵,将每个token关联起来。这些评分描述了每个token的query与每个token的key之间的相关性,这就是自注意力机制。

注意力评分矩阵(qk_per_token)的形状为[17×17],其中17是prompt中的token数量。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

接下来需要对query key评分进行掩码处理。在Llama3的训练过程中,未来token的qk评分是被掩码的,只通过过去的token来预测token。

因此,在推理时,要将未来的token评分设置为0。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

接下来是value,接近注意力机制的最后一步。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

这些评分(0-1)用于确定每个token使用多少value矩阵。

和key一样,value的权重也在每4个注意力头之间共享,所以下面value权重矩阵的形状是[8x128x4096]。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

第一层,第一个注意力头的value权重矩阵如下所示:

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

然后是value向量。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

使用value权重来获取每个token的注意力值,矩阵的大小是[17×128],其中17是prompt中的token数量,128是每个token的value向量的维度。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

注意力:与每个token的value相乘后得到的注意力向量的形状为[17×128]。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

现在有了第一层第一个头的注意力value。然后纳哥运行一个循环,对第一层的每个头执行与上面的计算完全相同的数学运算。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

然后得到了第一层所有32个头的qkv_attention矩阵,接下来将所有注意力得分合并成一个大小为[17×4096]的大矩阵。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

对于第0层注意力机制的最后步骤,其一是将注意力得分矩阵与权重矩阵相乘。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

这是一个简单的线性层,所以只需进行矩阵乘法。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

现在得到了注意力机制后的嵌入value变化,应该被添加到原始的token嵌入中。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

对嵌入增量进行归一化处理,然后通过嵌入增量运行一个前馈神经网络。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

在Llama3中,加载前馈权重并实现前馈网络。使用了一种名为SwiGLU的前馈网络,这种网络结构在模型需要的时候,能够有效地增加非线性。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

现在完成了第一层之后每个token的新嵌入。现在只剩下31层了,只需通过一个循环来完成。

纳哥表示可以将这个编辑后的嵌入想象成包含了第一层中所有查询信息的嵌入。随着层数的增加,每一层都会对输入的信息进行越来越复杂的处理,直到最终得到一个能够全面了解下一个需要预测的token的嵌入。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

之前做的所有事情,对每一层都重复一次。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

然后得到了最终的嵌入,这是模型对下一个token的最优预测。这个嵌入的形状与常规的token嵌入相同,为[17×4096],其中17是token的数量,4096是嵌入的维度。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

最后,将嵌入解码成token值。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

使用输出解码器将最终的嵌入转换成一个token。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

接下来看纳哥使用最后一个token的嵌入来预测下一个value,希望预测的结果是42。

因为根据《银河系漫游指南》一书中的说法,42是“生命、宇宙及一切的终极问题的答案”。大多数LLM在这里都会回答42,这将验证整个代码的正确性。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

模型预测下一个token的编号为2983。这个编号对应数字42吗?

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+
从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

OK,结束。

“让研究变得更加触手可及”

简单介绍一下Nishant Aklecha。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

Nishant Aklecha是构建和改进定制语言模型平台Glaive AI的研究员,曾任职于摩根士丹利,负责训练和微调大语言模型。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

此外,他还和朋友一同创立了一个研究实验室,名为A10(AAAAAAAAAA)。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

他们的目标可以总结成一句话:让研究变得更加触手可及。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

除了放出这个代码库,Nishant Aklecha可谓好人做到底。

网友想更好地理解这个代码库的内容,Nishant直接一个YouTube视频甩了过来:

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

之前Nishant Aklecha还曾写过一篇Blog,详解了潜在一致性模型(LCM),同样收获了不少好评。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

啥也不说了,感兴趣的家人们赶紧码住吧。

从零复现Llama3代码库爆火,大神Kapathy一键三连,GitHub狂揽2k+

GitHub链接:https://github.com/naklecha/llama3-from-scratch

参考链接:
[1]https://x.com/naklecha/status/1792244347225641338
[2]https://naklecha.notion.site/explained-latent-consistency-models-13a9290c0fd3427d8d1a1e0bed97bde2
[3]https://www.youtube.com/watch?v=o29P0Kpobz0&t=530s
[4]https://www.youtube.com/watch?v=eMlx5fFNoYc

Read More 

正文完
可以使用微信扫码关注公众号(ID:xzluomor)
post-qrcode
 
评论(没有评论)
Generated by Feedzy