Encoder 和 Decoder 的数学区别
1. 主要区别概览
输入序列 x1,x2,…,xT
逐步更新隐藏状态 ht=f(ht−1,xt)
最后一个隐藏状态 hT 或上下文向量 c
上下文向量 c + 目标序列的前一 token yt′−1
逐步更新隐藏状态 st′=g(st′−1,yt′−1,c)
生成的 token yt′,通过 softmax 计算概率分布
2. Encoder(编码器)的数学公式
2.1 输入
编码器的输入是一个序列:
x=(x1,x2,…,xT) 其中,每个 xt 是输入的 token(例如单词的向量表示)。
2.2 递归计算
编码器使用递归公式计算隐藏状态:
ht=f(ht−1,xt) 其中:
f 是 RNN/LSTM/GRU 的转换函数。
最简单的 RNN 计算方式:
ht=tanh(Whht−1+Wxxt+b) LSTM 计算方式(包含遗忘门、输入门、输出门):
ft=σ(Wf[ht−1,xt]+bf) it=σ(Wi[ht−1,xt]+bi) ot=σ(Wo[ht−1,xt]+bo) c~t=tanh(Wc[ht−1,xt]+bc) ct=ft⊙ct−1+it⊙c~t ht=ot⊙tanh(ct) 最终,编码器的隐藏状态 hT 形成上下文向量 c:
c=m(h1,h2,…,hT) 其中 m 可能是直接取最后一个隐藏状态(简单情况)或某种加权平均(如注意力机制)。
3. Decoder(解码器)的数学公式
3.1 输入
解码器的输入是:
目标序列的前一个 token yt′−1。
3.2 递归计算
解码器使用自己的隐藏状态 st′ 进行更新:
st′=g(st′−1,yt′−1,c) 其中:
g 是解码器的转换函数(通常与编码器的 f 结构类似)。
RNN 版本:
st′=tanh(Wsst′−1+Wyyt′−1+Wcc+b) LSTM 版本:
ft′=σ(Wf[st′−1,yt′−1,c]+bf) it′=σ(Wi[st′−1,yt′−1,c]+bi) ot′=σ(Wo[st′−1,yt′−1,c]+bo) c~t′=tanh(Wc[st′−1,yt′−1,c]+bc) ct′=ft′⊙ct′−1+it′⊙c~t′ st′=ot′⊙tanh(ct′) 3.3 输出生成
解码器的隐藏状态 st′ 用于计算当前 token yt′ 的概率:
P(yt′∣yt′−1,…,y1,c)=softmax(Wost′+bo) softmax 计算每个可能单词的概率分布,并选择最可能的单词作为当前输出。
4. Encoder 和 Decoder 之间的核心数学区别
方面
编码器(Encoder)
解码器(Decoder)
输入序列 x=(x1,x2,...,xT)
编码器的上下文向量 c + 目标 token yt′−1
ht=f(ht−1,xt)
st′=g(st′−1,yt′−1,c)
最终隐藏状态 hT(即上下文向量 c)
数学核心区别:
输入不同:
编码器的输入是整个源语言句子 x1,x2,...,xT。
解码器的输入是**编码器输出的上下文向量 c + 目标序列的前一个 token yt′−1。
隐藏状态计算方式不同:
编码器:ht=f(ht−1,xt),仅依赖过去信息(单向)或前后信息(双向)。
解码器:st′=g(st′−1,yt′−1,c),依赖于编码器输出 + 先前已生成的 token。
输出不同:
编码器输出最终隐藏状态 hT(或经过注意力机制的上下文向量 c),用作整个输入的表示。
解码器使用 softmax 生成目标序列中的 token,逐步预测下一个单词。
5. 总结
编码器 负责 处理整个输入序列,并生成一个固定长度的上下文向量 c,用来表示整个输入序列的信息。
解码器 逐步接收上下文向量 c,并依赖前一个时间步的输出,一步步生成目标序列。
数学层面上:
编码器是一个标准的递归网络,输入的是 源语言序列,只依赖于 过去信息(或前后信息,若是双向)。
解码器依赖编码器的输出 + 先前已生成的 token,是一个 自回归模型,用 softmax 生成下一个 token。
h, x, s, y到底是啥
在 序列到序列(seq2seq) 任务中,编码器的最终隐藏状态 hT 被称为上下文变量或上下文向量(context vector)。它是整个输入序列的信息压缩,并传递给解码器,以生成目标输出。
例子:机器翻译
我们以 英语到法语翻译 为例:
1. 词向量表示
假设:
这些 xt 是输入单词的词向量(word embeddings)。
2. 编码器(Encoder)
编码器是一个 循环神经网络(RNN, LSTM, GRU),它逐步读取输入单词,并更新隐藏状态(hidden state) ht:
输入第一个单词 "I":
h1=f(x1,h0) 其中 h0 是初始隐藏状态(通常设为全零)。
输入第二个单词 "love":
h2=f(x2,h1) 输入第三个单词 "you":
h3=f(x3,h2)
最终,编码器的最后一个隐藏状态 hT 作为上下文向量:
它浓缩了整个输入句子的意思,并作为解码器的输入。
3. 解码器(Decoder)
解码器是另一个 RNN,它接收 上下文向量 ( h_T ) 作为初始状态,并逐步生成目标序列("Je t'aime"):
生成第一个单词 "Je":
s1=g(y0,hT) 其中 y0 是起始标记 <SOS>。
生成第二个单词 "t'":
s2=g(y1,s1) 其中 y1 是 "Je"。
生成第三个单词 "aime":
s3=g(y2,s2) 其中 y2 是 "t'"。
最终,解码器输出 "Je t'aime",完成翻译。
4. 关键概念总结
5. 图示
Last updated