
Reading LLM Source Code (4): The Mistral Series

Source link: https://zhuanlan.zhihu.com/p/688059726

https://github.com/sliderSun

sliderSun: Soul-searching questions about word2vec

sliderSun: Questions about parameter counting in CNN, RNN, LSTM, Transformer, and BERT

sliderSun: The whys of the Transformer

sliderSun: The whys of BERT

sliderSun: An interview collection on Transformer, BERT-like models, contrastive learning, and ChatGPT

sliderSun: A knowledge feast: exploring the wondrous world of LLMs, Sora, and LWM

sliderSun: Reading LLM Source Code (1): ChatGLM

sliderSun: Reading LLM Source Code (2): Baichuan

sliderSun: Reading LLM Source Code (3): LLaMA

sliderSun: Reading LLM Source Code (4): The Mistral Series

sliderSun: Reading LLM Source Code (5): Gemma

sliderSun: LangChain: the magician of the code world, a source-code walkthrough of the technical black hole

Mixture of Experts (MoE) represents a paradigm shift in neural network architecture. It consists of multiple "expert" networks, each designed to handle a specific type of data or task, overseen by a "gating network" that dynamically routes each input to the most suitable expert.

Core components

  • Expert layer (experts): each expert network is designed to handle the data or tasks in its area of specialization.
  • Gating network (router): the decision-making core of the MoE architecture. It determines which expert in the expert layer is best suited to process a given input.

The expert layer and the gating network work together so that the right task is handled by the right expert. The gating network dynamically routes inputs to the most suitable experts, and the two components are trained jointly.

The figure above shows a high-level view of an MoE layer embedded in a language model. In essence, an MoE layer contains multiple feed-forward sub-networks, called "experts", each of which may specialize in a different aspect of the data. The gating network highlighted in the figure decides which combination of experts is activated for a given input. This conditional activation lets the network increase its capacity dramatically without a corresponding increase in compute.

Function of the MoE layer

In practice, the gating network evaluates the input (denoted G(x) in the diagram) and selects a sparse set of experts to process it. This selection is modulated by the gating network's output, which effectively determines each expert's "vote", or contribution, to the final output. For example, as shown in the diagram, only two experts may be selected to compute the output for each input token, concentrating compute where it is needed most and making the process more efficient.
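
To make the routing step concrete, here is a minimal sketch (the logit values and sizes are made up, and this is not the Mixtral code itself) of how top-2 gating turns router logits into per-expert mixing weights for a single token:

import torch
import torch.nn.functional as F

num_experts, top_k = 8, 2

# Router logits for one token, one score per expert (arbitrary values)
gate_logits = torch.tensor([0.1, 2.3, -0.5, 1.7, 0.0, -1.2, 0.4, 0.9])

# Keep only the top-2 experts and renormalize their scores with a softmax
weights, selected_experts = torch.topk(gate_logits, top_k)
weights = F.softmax(weights, dim=-1)

print(selected_experts.tolist())  # [1, 3]
print(weights.tolist())           # ~[0.65, 0.35]: the two experts' mixing weights
# The token's output is weights[0] * expert_1(x) + weights[1] * expert_3(x)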

An MoE layer replaces every other Transformer feed-forward layer. (a) The encoder of a standard Transformer is a stack of self-attention and feed-forward layers interleaved with residual connections and layer normalization. (b) Replacing every other feed-forward layer with an MoE layer yields the MoE Transformer encoder. (c) When scaling to multiple devices, the MoE layer is sharded across devices, while all other layers are replicated.

In the Mixtral 8x7B model, the MoE layer is implemented as follows:

from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

# MoeArgs is a small config object defined elsewhere in the repo;
# this layer only reads its num_experts_per_tok field.


class MoeLayer(nn.Module):
    def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs):
        super().__init__()
        assert len(experts) > 0
        # experts is a list of (8) Llama-style FFNs,
        # i.e. Linear + SiLU + Linear
        self.experts = nn.ModuleList(experts)
        # gate is also a Linear layer; its weight has shape [hidden_dim, num_experts]
        self.gate = gate
        self.args = moe_args

    def forward(self, inputs: torch.Tensor):
        # The gate linearly maps the input to num_experts logits,
        # i.e. [bsz*seq_len, hidden_dim] -> [bsz*seq_len, num_experts]
        gate_logits = self.gate(inputs)
        # Top-k selection:
        #   weights          -> the top-k logit values
        #   selected_experts -> the top-k expert indices
        weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok)
        # Softmax over the selected logits so each token's weights sum to 1
        weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype)
        results = torch.zeros_like(inputs)
        for i, expert in enumerate(self.experts):
            # Find which tokens routed to expert i, and at which top-k slot
            batch_idx, nth_expert = torch.where(selected_experts == i)
            # Run expert i on those tokens and accumulate the weighted outputs
            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(
                inputs[batch_idx]
            )
        return results
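
A minimal usage sketch of the layer above; the toy FFN experts, the MoeArgs stand-in, and all sizes here are illustrative and not the real Mixtral configuration:

import torch
import torch.nn as nn
from dataclasses import dataclass

@dataclass
class MoeArgs:  # stand-in for the config object used by the official code
    num_experts: int = 8
    num_experts_per_tok: int = 2

hidden_dim, ffn_dim = 16, 32  # toy sizes
experts = [
    nn.Sequential(nn.Linear(hidden_dim, ffn_dim), nn.SiLU(), nn.Linear(ffn_dim, hidden_dim))
    for _ in range(8)
]
gate = nn.Linear(hidden_dim, 8, bias=False)

moe = MoeLayer(experts, gate, MoeArgs())
tokens = torch.randn(5, hidden_dim)  # 5 flattened tokens of shape [bsz*seq_len, hidden_dim]
out = moe(tokens)
print(out.shape)                     # torch.Size([5, 16])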

Attention mechanism

Mistral uses two techniques, GQA and SWA, to speed up attention. In standard attention, Q attends to K and V of shape [bsz, n_heads, seq_len, head_dim], where seq_len is the total number of tokens processed so far. GQA works on the head dimension, letting several groups of query heads share one set of KV heads; SWA works on the seq_len dimension, so that Q no longer attends "directly" to the KV of all seq_len positions but only to the KV within the sliding window.

Grouped-query Attention (GQA)

The original multi-head attention is split into groups, and the query heads within each group share one pair of K and V heads.

GQA-1 is equivalent to MQA: all heads form a single group and share one pair of K and V.

GQA-H is equivalent to MHA: the heads are split into H groups (H being the original number of attention heads), which changes nothing, so GQA sits between MQA and MHA.
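
A small sketch of the grouping (toy sizes; Mistral-7B itself uses 32 query heads and 8 KV heads): each KV head is simply repeated so that n_heads // n_kv_heads query heads share it, analogous to the repeat_kv helper called in the Attention module later in this post:

import torch

n_heads, n_kv_heads, head_dim, seq_len = 8, 2, 4, 5  # toy sizes
repeats = n_heads // n_kv_heads                      # 4 query heads share each KV head

k = torch.randn(seq_len, n_kv_heads, head_dim)       # keys projected to only n_kv_heads heads
# Repeat each KV head so its shape lines up with the query heads
k_expanded = torch.repeat_interleave(k, repeats, dim=1)
print(k_expanded.shape)                              # torch.Size([5, 8, 4])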

Sliding Window Attention (SWA)

The following illustrates the case where the sliding window size is 3.

Prompting stage

During LLM inference, SWA can be enforced in the prompting stage with a mask operation, as follows:

if input_ids.shape[1] > 1:
    # seqlen is n in the prompt stage and 1 in the generation stage
    seqlen = input_ids.shape[1]
    # The mask is only needed in the prompt stage of inference.
    # Start from an all-ones square matrix
    tensor = torch.full(
        (seqlen, seqlen),
        dtype=h.dtype,
        fill_value=1,
        device=h.device,
    )
    # Zero out everything strictly above the diagonal (causal mask)
    mask = torch.tril(tensor, diagonal=0).to(h.dtype)
    # make the mask banded to account for sliding window
    # Note: diagonal should arguably be (-self.args.sliding_window + 1) for the window
    # size to equal self.args.sliding_window; this looks like a small bug in the
    # official code.
    mask = torch.triu(mask, diagonal=-self.args.sliding_window)
    mask = torch.log(mask)
"""
tensor = torch.ones((5, 5))
sliding_window = 3

mask = torch.tril(tensor, diagonal=0)
# make the mask banded to account for sliding window
# mask = torch.triu(mask, diagonal=-sliding_window)  # actually allows sliding_window + 1 columns per row
mask = torch.triu(mask, diagonal=-sliding_window+1)  
mask = torch.log(mask)
print(mask)
tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [-inf, 0., 0., 0., -inf],
        [-inf, -inf, 0., 0., 0.]])

mask = torch.tril(tensor, diagonal=0)
# make the mask banded to account for sliding window
# mask = torch.triu(mask, diagonal=-sliding_window)  # actually allows sliding_window + 1 columns per row
mask = torch.triu(mask, diagonal=-sliding_window)  
mask = torch.log(mask)
print(mask)
tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [-inf, 0., 0., 0., 0.]])

"""

Generation stage

In the generation stage, the mask is of no use because generation is autoregressive (one token at a time), so Mistral instead relies on a RotatingBufferCache (its KV cache) to enforce the window.

Without a sliding window, the KV cache only ever grows as autoregressive inference proceeds. With SWA, KV pairs that fall outside the window no longer need to be cached, so a rotating-replacement strategy is used: entries are stored in a circular fashion, evicting the far-away keys and keeping the nearby ones.

For example, with window size W=4, when the 5th token needs to be cached it simply overwrites the 1st token, so the KV cache is bounded by the window size instead of growing without limit.

# The cache is a rotating buffer
# positions[-self.sliding_window:] takes the last w position indices, modulo w
# [None, :, None, None] expands the shape to [1, w, 1, 1]
scatter_pos = (positions[-self.sliding_window:] % self.sliding_window)[None, :, None, None]
# repeat to shape [bsz, w, kv_heads, head_dim]
scatter_pos = scatter_pos.repeat(bsz, 1, self.n_kv_heads, self.args.head_dim)
# src is xk[:, -w:], so src.shape = [bsz, w, kv_heads, head_dim];
# scatter_ writes src into the cache at the slots given by scatter_pos
self.cache_k[:bsz].scatter_(dim=1, index=scatter_pos, src=xk[:, -self.sliding_window:])
self.cache_v[:bsz].scatter_(dim=1, index=scatter_pos, src=xv[:, -self.sliding_window:])
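
A toy, self-contained sketch of the same rotation idea, reduced to a 1-D cache of token ids (the window size and token values are made up for illustration):

import torch

W = 4                                    # sliding window size
cache = torch.zeros(W, dtype=torch.long)

# Feed tokens 1..7 one at a time; each new token overwrites slot position % W
for pos in range(7):
    token_id = pos + 1
    cache.scatter_(dim=0, index=torch.tensor([pos % W]), src=torch.tensor([token_id]))
    print(pos, cache.tolist())

# From the 5th token (pos=4) onward the oldest entries are overwritten:
# 3 [1, 2, 3, 4]   4 [5, 2, 3, 4]   5 [5, 6, 3, 4]   6 [5, 6, 7, 4]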

SWA does restrict each token's query to attending only to tokens within a fixed window (the window size), but the propagation of information through the network is not limited to that window: it also flows across the stacked Transformer layers.

class RotatingBufferCache:
    """
    This is an example that implements a less naive rotating buffer cache, allowing for variable length sequences.
    Allocated cache is rectangular which is wasteful (see PagedAttention for better mechanisms)
    """
    def __init__(self, n_layers: int, max_batch_size: int, sliding_window: int, n_kv_heads: int, head_dim: int):

        self.sliding_window = sliding_window
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim

        self.cache_k = torch.empty((
            n_layers,
            max_batch_size,
            sliding_window,
            n_kv_heads,
            head_dim
        ))
        self.cache_v = torch.empty((
            n_layers,
            max_batch_size,
            sliding_window,
            n_kv_heads,
            head_dim
        ))
        # holds the valid length for each batch element in the cache
        self.kv_seqlens = None

    def get_view(self, layer_id: int, metadata: RotatingCacheInputMetadata) -> CacheView:
        return CacheView(self.cache_k[layer_id], self.cache_v[layer_id], metadata, self.kv_seqlens)

    def reset(self):
        self.kv_seqlens = None

    def init_kvseqlens(self, batch_size: int):
        self.kv_seqlens = torch.zeros((batch_size,), device=self.device, dtype=torch.long)

    @property
    def device(self):
        return self.cache_k.device

    def to(self, device: torch.device, dtype: torch.dtype):
        self.cache_k = self.cache_k.to(device=device, dtype=dtype)
        self.cache_v = self.cache_v.to(device=device, dtype=dtype)

        return self

    def update_seqlens(self, seqlens: List[int]):
        self.kv_seqlens += torch.tensor(seqlens, device=self.device, dtype=torch.long)

    def get_input_metadata(self, seqlens: List[int]) -> RotatingCacheInputMetadata:
        """
            input = seqlens [5,7,2] // seqpos [0, 1, 3] // sliding_window 3
            --> only cache last 3 tokens in each sequence
            - to_cache_mask = [0 0 1 1 1 | 0 0 0 0 1 1 1 | 1 1]
            - cached_elements = [3 | 3 | 2]
            --> absolute positions are used for rope
            - positions = [0 1 2 3 4 | 1 2 3 4 5 6 7 | 3 4]
            --> cache positions are positions cache_masked, modulo sliding_window + batch_idx * sliding_window
            - cache_positions = [2 0 1 | 5 3 4 | 6 7]
        """
        if self.kv_seqlens is None:
            self.init_kvseqlens(len(seqlens))
        assert len(seqlens) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(seqlens)}, did you forget to reset cache?"
        seqpos = self.kv_seqlens.tolist()

        assert len(seqlens) > 0, seqlens
        masks = [
            [x >= seqlen - self.sliding_window for x in range(seqlen)]
            for seqlen in seqlens
        ]
        to_cache_mask = torch.tensor(sum(masks, []), device=self.device, dtype=torch.bool)
        cached_elements = torch.tensor([sum(mask) for mask in masks], device=self.device, dtype=torch.long)
        positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to(device=self.device, dtype=torch.long)
        batch_idx = torch.tensor(sum([[i]*seqlen for i, seqlen in enumerate(seqlens)], []), device=self.device, dtype=torch.long)
        cache_positions = positions % self.sliding_window + batch_idx * self.sliding_window

        first_prefill = seqpos[0] == 0
        subsequent_prefill = any(seqlen > 1 for seqlen in seqlens)
        if first_prefill:
            assert all([pos == 0 for pos in seqpos]), (seqpos)
            mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(self.sliding_window)
        elif subsequent_prefill:
            mask = BlockDiagonalMask.from_seqlens(
                q_seqlen=seqlens,
                kv_seqlen=[s + cached_s.clamp(max=self.sliding_window).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)]
            ).make_local_attention_from_bottomright(self.sliding_window)
        else:
            mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
                q_seqlen=seqlens,
                kv_padding=self.sliding_window,
                kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=self.sliding_window).tolist()
            )

        return RotatingCacheInputMetadata(
            positions=positions,
            to_cache_mask=to_cache_mask,
            cached_elements=cached_elements,
            cache_positions=cache_positions[to_cache_mask],
            prefill=first_prefill or subsequent_prefill,
            mask=mask,
            seqlens=seqlens,
        )
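
The worked example in the get_input_metadata docstring can be reproduced with a few lines of plain tensor code; this is only a check of the index arithmetic, not part of the repository:

import torch

seqlens, seqpos, sliding_window = [5, 7, 2], [0, 1, 3], 3

positions = torch.cat([torch.arange(pos, pos + sl) for pos, sl in zip(seqpos, seqlens)])
batch_idx = torch.tensor(sum([[i] * sl for i, sl in enumerate(seqlens)], []))
to_cache_mask = torch.tensor(sum([[x >= sl - sliding_window for x in range(sl)] for sl in seqlens], []))

cache_positions = positions % sliding_window + batch_idx * sliding_window
print(cache_positions[to_cache_mask].tolist())  # [2, 0, 1, 5, 3, 4, 6, 7]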

Source: https://zhuanlan.zhihu.com/p/659105978

Tokens outside the sliding window can still influence next-word prediction. At each attention layer, information can propagate forward by at most W tokens: after two attention layers it can propagate 2W tokens, and so on. For example, with a sequence length of 16K and a sliding window of 4K, information has propagated across the full sequence length after 4 layers.
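
The arithmetic behind that claim, as a rough sketch that ignores the off-by-one details of the window boundary:

import math

seq_len, window = 16 * 1024, 4 * 1024
layers_needed = math.ceil(seq_len / window)  # each layer extends the reach by at most W tokens
print(layers_needed)                         # 4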

Attention

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.n_heads: int = args.n_heads
        self.head_dim: int = args.head_dim
        self.n_kv_heads: int = args.n_kv_heads

        self.repeats = self.n_heads // self.n_kv_heads

        self.scale = self.args.head_dim**-0.5

        self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        cache: Optional[CacheView],
    ) -> torch.Tensor:
        seqlen_sum, _ = x.shape

        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(seqlen_sum, self.n_heads, self.head_dim)
        xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim)
        xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        if cache is None:
            # No cache: attend over the freshly computed keys/values only
            key, val = xk, xv
        elif cache.prefill:
            # Prefill: interleave the new prompt KV with any previously cached KV,
            # then write the new KV into the rotating buffer
            key, val = cache.interleave_kv(xk, xv)
            cache.update(xk, xv)
        else:
            # Decode: write the new KV into the rotating buffer and attend over the
            # whole buffer (at most sliding_window entries per sequence)
            cache.update(xk, xv)
            key, val = cache.key, cache.value
            key = key.view(
                seqlen_sum * cache.sliding_window, self.n_kv_heads, self.head_dim
            )
            val = val.view(
                seqlen_sum * cache.sliding_window, self.n_kv_heads, self.head_dim
            )

        # Repeat keys and values to match number of query heads
        key, val = repeat_kv(key, val, self.repeats, dim=1)

        # xformers requires (B=1, S, H, D)
        xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
        output = memory_efficient_attention(
            xq, key, val, None if cache is None else cache.mask
        )

        return self.wo(output.view(seqlen_sum, self.n_heads * self.head_dim))

Reverse-engineering Easter egg

孟繁续: Digging into the Mixtral-8x7B model
