一起读读大模型源码:浅谈LLAMA2核心函数generate源码

1,167次阅读
没有评论

今天是2023年11月6日,星期一,北京,天气晴。我们今天来看看Llama2的生成函数,在谈LLAMA2的generate源码之前,先介绍Temperature超参数及sample_top_p的原理。供大家一起参考。

一、Temperature

Temperature 是一个超参数,可用于控制生成语言模型中生成文本的随机性和创造性。用于调整模型的softmax输出层中预测词的概率。

softmax函数:

Temperature 参数(T)添加到softmax函数:

Temperature参数通常设置为 0.1 到 1.0 之间(T=1时形变为标准的Softmax函数),下图分别显示了在5:0.5和5:0.1时的图像(紫线为softmax,黑线为添加T参数的softmax),可以看到:

  • 当T值更大时,函数图像会变的更加的平缓,预测词的概率被拉平,这意味着所有词被选择的可能性更大。这会产生更有创意和多样化的文本,因为模型更有可能生成不寻常或意想不到的词。

  • 当T值更小时,函数图像会变的更加的陡峭,预测词的概率会变尖锐,这意味着选择最有可能的词的概率更高。这会产生更保守和可预测的文本,因为模型不太可能生成意想不到或不寻常的词。

一起读读大模型源码:浅谈LLAMA2核心函数generate源码

=5:0.5

一起读读大模型源码:浅谈LLAMA2核心函数generate源码

=5:0.1

小结:Temperature 参数是文本生成模型中用于控制生成文本的随机性和创造性的一个重要的超参数。

二、sample_top_p

一起读读大模型源码:浅谈LLAMA2核心函数generate源码

平缓和陡峭的概率分布图-文献【2】

采样意味着根据当前条件概率分布随机选择输出词 ,使用采样方法时文本生成本身不再是确定性的。对单词序列进行采样时的大问题: 模型通常会产生不连贯的乱码。在LLAMA2中,缓解这一问题的方式是通过top_p(也称:nucleus sampling)

def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    # 归一化
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    # multinomial为多项式抽样函数
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

sample_top_p函数的作用:每个时间步,按照字出现的概率由高到底排序,当概率之和大于top-p的时候,就不取后面的样本了。然后对取到的这些字的概率重新归一化后,进行采样。这样做的好处是,既保证了质量,又增加了适当的随机性。

三、核心函数generate()

这一块直接在代码中进行注释:

def generate(
        self,
        prompt_tokens: List[List[int]],  # 输入的提示
        max_gen_len: int,  # 最大生成长度
        temperature: float = 0.6,  # 影响生成文本的随机性
        top_p: float = 0.9,  # 用于决定采样过程中保留的 token 集合的概率阈值
        logprobs: bool = False,  # 是否返回每个 token 的对数概率
        echo: bool = False,  # 是否返回输入的提示
)
 -> Tuple[List[List[int]], Optional[List[List[float]]]]:

    # ---------------------------初始化长度为 total_len tokens张量,并填充 pad_id----------------------------------
    params = self.model.params
    bsz = len(prompt_tokens)
    assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

    min_prompt_len = min(len(t) for t in prompt_tokens)
    max_prompt_len = max(len(t) for t in prompt_tokens)
    assert max_prompt_len <= params.max_seq_len
    total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

    pad_id = self.tokenizer.pad_id
    tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=“cuda”)
    # 将prompt_tokens中的token复制到tokens张量中。
    for k, t in enumerate(prompt_tokens):
        tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=“cuda”)
    if logprobs:
        # 创建一个与tokens相同形状的token_logprobs张量,并用0填充
        token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

    prev_pos = 0
    eos_reached = torch.tensor([False] * bsz, device=“cuda”)
    input_text_mask = tokens != pad_id
    # ————————————————————-

    for cur_pos in range(min_prompt_len, total_len):
        # 调用模型的forward方法获取logits
        logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        if logprobs:
            # 计算token level的logprobs
            token_logprobs[:, prev_pos + 1: cur_pos + 1] = -F.cross_entropy(
                input=logits.transpose(12),
                target=tokens[:, prev_pos + 1: cur_pos + 1],
                reduction=“none”,
                ignore_index=pad_id,
            )
        # 根据温度参数和top_p参数对logits进行softmax和采样,得到下一个token
        if temperature > 0:
            # sample_top_p函数对probs进行采样
            probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
            next_token = sample_top_p(probs, top_p)
        else:
            # 将logits中概率最大的token作为下一个token。
            next_token = torch.argmax(logits[:, -1], dim=-1)

        next_token = next_token.reshape(-1)
        # only replace token if prompt has already been generated
        next_token = torch.where(
            input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
        )
        # tokens张量更新
        tokens[:, cur_pos] = next_token
        eos_reached |= (~input_text_mask[:, cur_pos]) & (
                next_token == self.tokenizer.eos_id
        )
        prev_pos = cur_pos
        # 检查是否已经生成了所有的eos token,如果是则停止生成
        if all(eos_reached):
            break

    if logprobs:
        # token_logprobs列表化
        token_logprobs = token_logprobs.tolist()
    out_tokens, out_logprobs = [], []
    for i, toks in enumerate(tokens.tolist()):
        # cut to max gen len
        # 对于 tokens 张量中的每一行(即每一个生成的序列),如果 echo 参数为假,则去掉提示部分
        start = 0 if echo else len(prompt_tokens[i])
        toks = toks[start: len(prompt_tokens[i]) + max_gen_len]
        probs = None
        if logprobs:
            probs = token_logprobs[i][start: len(prompt_tokens[i]) + max_gen_len]
        # cut to eos tok if any
        # 存在结束标记,则去掉结束标记之后的部分
        if self.tokenizer.eos_id in toks:
            eos_idx = toks.index(self.tokenizer.eos_id)
            toks = toks[:eos_idx]
            probs = probs[:eos_idx] if logprobs else None
        out_tokens.append(toks)
        out_logprobs.append(probs)
    # 返回生成的tokens和对数概率(如果logprobs参数为真)
    return (out_tokens, out_logprobs if logprobs else None)

总结

本文介绍了Temperature以及sample_top_p的原理,并且阅读了LLAMA2的核心生成函数的源码。关于更多细节实现,请关注llama源码。

参考文献

【1】https://github.com/facebookresearch/llama/blob/main/llama/generation.py

【2】The Curious Case of Neural Text Degeneration

关于我们

老刘,刘焕勇,NLP开源爱好者与践行者,主页:https://liuhuanyong.github.io。


老刘说NLP,将定期发布语言资源、工程实践、技术总结等内容,欢迎关注。


对于想加入更优质的知识图谱、事件图谱、大模型AIGC实践、相关分享的,可关注公众号,在后台菜单栏中点击会员社区->会员入群加入。

 

Read More 

正文完
可以使用微信扫码关注公众号(ID:xzluomor)
post-qrcode
 0
评论(没有评论)

文心AIGC

2023 年 11 月
 12345
6789101112
13141516171819
20212223242526
27282930  
文心AIGC
文心AIGC
人工智能ChatGPT,AIGC指利用人工智能技术来生成内容,其中包括文字、语音、代码、图像、视频、机器人动作等等。被认为是继PGC、UGC之后的新型内容创作方式。AIGC作为元宇宙的新方向,近几年迭代速度呈现指数级爆发,谷歌、Meta、百度等平台型巨头持续布局
文章搜索
热门文章
潞晨尤洋:日常办公没必要上私有模型,这三类企业才需要 | MEET2026

潞晨尤洋:日常办公没必要上私有模型,这三类企业才需要 | MEET2026

潞晨尤洋:日常办公没必要上私有模型,这三类企业才需要 | MEET2026 Jay 2025-12-22 09...
面向「空天具身智能」,北航团队提出星座规划新基准丨NeurIPS’25

面向「空天具身智能」,北航团队提出星座规划新基准丨NeurIPS’25

面向「空天具身智能」,北航团队提出星座规划新基准丨NeurIPS’25 鹭羽 2025-12-13 22:37...
钉钉又发新版本!把 AI 搬进每一次对话和会议

钉钉又发新版本!把 AI 搬进每一次对话和会议

钉钉又发新版本!把 AI 搬进每一次对话和会议 梦晨 2025-12-11 15:33:51 来源:量子位 A...
商汤Seko2.0重磅发布,合作短剧登顶抖音AI短剧榜No.1

商汤Seko2.0重磅发布,合作短剧登顶抖音AI短剧榜No.1

商汤Seko2.0重磅发布,合作短剧登顶抖音AI短剧榜No.1 十三 2025-12-15 14:13:14 ...
MEET2026挤爆了,AI圈今年最该听的20+场演讲&对谈都在这

MEET2026挤爆了,AI圈今年最该听的20+场演讲&对谈都在这

MEET2026挤爆了,AI圈今年最该听的20+场演讲&对谈都在这 西风 2025-12-11 15:...
最新评论
ufabet ufabet มีเกมให้เลือกเล่นมากมาย: เกมเดิมพันหลากหลาย ครบทุกค่ายดัง
tornado crypto mixer tornado crypto mixer Discover the power of privacy with TornadoCash! Learn how this decentralized mixer ensures your transactions remain confidential.
ดูบอลสด ดูบอลสด Very well presented. Every quote was awesome and thanks for sharing the content. Keep sharing and keep motivating others.
ดูบอลสด ดูบอลสด Pretty! This has been a really wonderful post. Many thanks for providing these details.
ดูบอลสด ดูบอลสด Pretty! This has been a really wonderful post. Many thanks for providing these details.
ดูบอลสด ดูบอลสด Hi there to all, for the reason that I am genuinely keen of reading this website’s post to be updated on a regular basis. It carries pleasant stuff.
Obrazy Sztuka Nowoczesna Obrazy Sztuka Nowoczesna Thank you for this wonderful contribution to the topic. Your ability to explain complex ideas simply is admirable.
ufabet ufabet Hi there to all, for the reason that I am genuinely keen of reading this website’s post to be updated on a regular basis. It carries pleasant stuff.
ufabet ufabet You’re so awesome! I don’t believe I have read a single thing like that before. So great to find someone with some original thoughts on this topic. Really.. thank you for starting this up. This website is something that is needed on the internet, someone with a little originality!
ufabet ufabet Very well presented. Every quote was awesome and thanks for sharing the content. Keep sharing and keep motivating others.
热评文章
预见未来:96位前沿先锋超万字核心观点总结,抢抓未来产业新高地

预见未来:96位前沿先锋超万字核心观点总结,抢抓未来产业新高地

预见未来:96位前沿先锋超万字核心观点总结,抢抓未来产业新高地 henry 2025-12-11 10:27:...
Meta公开抄阿里Qwen作业,还闭源了…

Meta公开抄阿里Qwen作业,还闭源了…

Meta公开抄阿里Qwen作业,还闭源了… Jay 2025-12-11 11:48:25 来源:量子位 Ja...
MEET2026挤爆了,AI圈今年最该听的20+场演讲&对谈都在这

MEET2026挤爆了,AI圈今年最该听的20+场演讲&对谈都在这

MEET2026挤爆了,AI圈今年最该听的20+场演讲&对谈都在这 西风 2025-12-11 15:...
钉钉又发新版本!把 AI 搬进每一次对话和会议

钉钉又发新版本!把 AI 搬进每一次对话和会议

钉钉又发新版本!把 AI 搬进每一次对话和会议 梦晨 2025-12-11 15:33:51 来源:量子位 A...