多頭注意力機制 (Multi-Head Attention) 核心筆記:拆解與並行#
一、 運作原理#
多頭注意力的關鍵在於轉置 (Transpose) 操作。transpose(1, 2) 讓張量形狀從 (b, tokens, heads, dim) 變為 (b, heads, tokens, dim)。這個操作的目的是:在邏輯上,將一個 token 的嵌入向量資訊分派給不同的注意力頭;在計算上,則將所有頭的運算集合起來,在張量層面進行並行處理,以高效地捕捉輸入序列中不同子空間的特徵關係。
二、 實作步驟詳解:從混雜到有序的並行#
「多頭」並不是真的有 num_heads 個獨立的 for 迴圈在跑。它的「多」是體現在張量的維度上,並透過高效的矩陣運算來實現並行處理 (Parallel Processing)。
1. 準備階段:生成潛在資訊 (Projection with nn.Linear)
- 程式碼:
keys = self.W_key(x) - 輸入形狀:
(b, num_tokens, d_in) - 輸出形狀:
(b, num_tokens, d_out) - 目的: 將每個輸入 token 的
d_in維向量,投影到一個更高維(或相同維度)的d_out空間。可以想像d_out這個長向量中,已經混雜地包含了未來所有注意力頭所需要的全部資訊。例如,如果d_out=512,這 512 個維度裡可能同時包含了語法、語意、位置等多方面的潛在特徵。
2. 邏輯分組:定義「頭」的邊界 (Reshape with .view)
- 程式碼:
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) - 形狀變化:
(b, num_tokens, d_out)->(b, num_tokens, num_heads, head_dim) - 目的: 這是第一次在邏輯上體現「多頭」 的地方。此操作不改變數據本身,只改變解讀它的方式。
- 比喻: 就像我們宣告,一條長度為 512 的繩子 (
d_out),現在起要被視為 8 (num_heads) 段長度為 64 (head_dim) 的短繩拼接而成。至此,對於每一個 token,我們都定義了它在 8 個不同「頭」(或稱子空間)中的獨立表徵。
3. 組織隊形:為並行計算而轉置 (Restructure with .transpose)
- 程式碼:
keys = keys.transpose(1, 2) - 形狀變化:
(b, num_tokens, num_heads, head_dim)->(b, num_heads, num_tokens, head_dim) - 目的: 這是實現高效並行計算的關鍵。它重新組織了數據的排列順序。
- 轉置前: 資料以 “Token” 為主進行組織。先看第 1 個 token 在 8 個頭裡的樣子,再看第 2 個 token 在 8 個頭裡的樣子…
- 轉置後: 資料以 “Head” 為主進行組織。先看第 1 個頭對所有 token 的樣子,再看第 2 個頭對所有 token 的樣子…
- 為何如此重要?
- 後續的矩陣乘法
queries @ keys.transpose(2, 3)會在最後兩個維度上運算。 - 經過
transpose(1, 2)後,num_heads維度被推到了前面,與b(批次大小) 一樣,被 PyTorch 當作批次維度 (Batch Dimension) 來處理。 - 這使得一個簡單的矩陣乘法指令,就能同時完成所有
num_heads個頭部的注意力計算,無需使用for迴圈,從而極大地發揮了 GPU 的並行計算優勢。
- 後續的矩陣乘法
三、 總結#
多頭注意力機制是一個「先整合,再拆分,後重組」的過程:
- 整合 (
nn.Linear): 將輸入 token 投影到一個富含資訊的空間。 - 拆分 (
.view+.transpose): 在邏輯上將資訊拆成多份給不同的「頭」,並在記憶體佈局上重組它們,使其適合並行處理。 - 重組 (
.view+out_proj): 將所有頭計算出的上下文向量拼接起來,並透過一個最終的線性層進行資訊融合,得出最終輸出。
