快轉到主要內容
Background Image

多頭注意力機制 (Multi-Head Attention) 核心筆記

多頭注意力機制 (Multi-Head Attention) 核心筆記:拆解與並行
#

一、 運作原理
#

多頭注意力的關鍵在於轉置 (Transpose) 操作。transpose(1, 2) 讓張量形狀從 (b, tokens, heads, dim) 變為 (b, heads, tokens, dim)。這個操作的目的是:在邏輯上,將一個 token 的嵌入向量資訊分派給不同的注意力頭;在計算上,則將所有頭的運算集合起來,在張量層面進行並行處理,以高效地捕捉輸入序列中不同子空間的特徵關係。


二、 實作步驟詳解:從混雜到有序的並行
#

「多頭」並不是真的有 num_heads 個獨立的 for 迴圈在跑。它的「多」是體現在張量的維度上,並透過高效的矩陣運算來實現並行處理 (Parallel Processing)

1. 準備階段:生成潛在資訊 (Projection with nn.Linear)

  • 程式碼: keys = self.W_key(x)
  • 輸入形狀: (b, num_tokens, d_in)
  • 輸出形狀: (b, num_tokens, d_out)
  • 目的: 將每個輸入 token 的 d_in 維向量,投影到一個更高維(或相同維度)的 d_out 空間。可以想像 d_out 這個長向量中,已經混雜地包含了未來所有注意力頭所需要的全部資訊。例如,如果 d_out=512,這 512 個維度裡可能同時包含了語法、語意、位置等多方面的潛在特徵。

2. 邏輯分組:定義「頭」的邊界 (Reshape with .view)

  • 程式碼: keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
  • 形狀變化: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
  • 目的: 這是第一次在邏輯上體現「多頭」 的地方。此操作不改變數據本身,只改變解讀它的方式。
  • 比喻: 就像我們宣告,一條長度為 512 的繩子 (d_out),現在起要被視為 8 (num_heads) 段長度為 64 (head_dim) 的短繩拼接而成。至此,對於每一個 token,我們都定義了它在 8 個不同「頭」(或稱子空間)中的獨立表徵。

3. 組織隊形:為並行計算而轉置 (Restructure with .transpose)

  • 程式碼: keys = keys.transpose(1, 2)
  • 形狀變化: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
  • 目的: 這是實現高效並行計算的關鍵。它重新組織了數據的排列順序。
    • 轉置前: 資料以 “Token” 為主進行組織。先看第 1 個 token 在 8 個頭裡的樣子,再看第 2 個 token 在 8 個頭裡的樣子…
    • 轉置後: 資料以 “Head” 為主進行組織。先看第 1 個頭對所有 token 的樣子,再看第 2 個頭對所有 token 的樣子…
  • 為何如此重要?
    • 後續的矩陣乘法 queries @ keys.transpose(2, 3) 會在最後兩個維度上運算。
    • 經過 transpose(1, 2) 後,num_heads 維度被推到了前面,與 b (批次大小) 一樣,被 PyTorch 當作批次維度 (Batch Dimension) 來處理。
    • 這使得一個簡單的矩陣乘法指令,就能同時完成所有 num_heads 個頭部的注意力計算,無需使用 for 迴圈,從而極大地發揮了 GPU 的並行計算優勢。

三、 總結
#

多頭注意力機制是一個「先整合,再拆分,後重組」的過程:

  1. 整合 (nn.Linear): 將輸入 token 投影到一個富含資訊的空間。
  2. 拆分 (.view + .transpose): 在邏輯上將資訊拆成多份給不同的「頭」,並在記憶體佈局上重組它們,使其適合並行處理。
  3. 重組 (.view + out_proj): 將所有頭計算出的上下文向量拼接起來,並透過一個最終的線性層進行資訊融合,得出最終輸出。

相關文章

揭秘 LLM 大型語言模型的訓練過程:一場精密的植物栽培之旅
·1 分鐘
本文簡潔的介紹了 LLM 大型語言模型的訓練過程,並以植物栽培之旅為比喻,讓讀者更容易理解。