Ryotta's Basic

MTP (Multi-Token Prediction) In-Depth Analysis

Multi-Token Prediction · Parallel heads / sequential modules · Self-Speculative Decoding · DeepSeek-V3

MTP (Multi-Token Prediction) is a technique that makes an LLM predict several upcoming tokens at each position instead of only the single next token. During training it provides denser supervision, which improves representation quality and data efficiency. During inference the extra predictions serve as multi-token drafts that speed up generation through speculative decoding.

MLA, GQA, and quantization shrink the KV cache, so they reduce "the size of the data to read". MTP instead reduces the number of times the weights are read during decoding. The two approaches attack different bottlenecks, so they combine well.

Key Concepts

  • Multi-token target: each position predicts two or more upcoming tokens.
  • Shared trunk: the prediction paths sit on top of a common Transformer body.
  • Parallel head / sequential module: implementations split into parallel heads or sequential modules.
  • Self-speculative decoding: the main model verifies the drafts produced by MTP.
  • Orthogonality: unlike KV compression, MTP reduces the number of decoding steps itself.

1. Core Idea - Predicting Several Tokens at Once

Conventional next-token prediction (NTP) predicts only the single token that immediately follows each position. MTP extends this to predict the next n tokens.

Figure 1. NTP predicts the next 1 token; MTP predicts the next n tokens.

Item | NTP | MTP
Training target | The next 1 token | The next several tokens
Output structure | Single head | Multiple heads or sequential modules
Training signal | Relatively sparse | Denser
Inference use | Plain autoregressive decoding | Drafts for speculative decoding
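
To make the difference in training targets concrete, the following is a minimal sketch of how the targets could be built from a token sequence. The function name `mtp_targets` and the choice to skip positions without n future tokens are assumptions made for illustration, not something taken from either paper.

```python
from typing import List, Tuple


def mtp_targets(tokens: List[int], n: int) -> List[Tuple[int, List[int]]]:
    """Pair each position's token with the n tokens that follow it.

    Positions near the end that lack n future tokens are skipped
    (a simplification chosen for this sketch).
    """
    pairs = []
    for t in range(len(tokens) - n):
        pairs.append((tokens[t], tokens[t + 1 : t + 1 + n]))
    return pairs


tokens = [10, 11, 12, 13, 14]
ntp = mtp_targets(tokens, n=1)  # NTP is simply the n = 1 case
mtp = mtp_targets(tokens, n=3)  # each position now carries three targets
print(ntp)
print(mtp)
```

With n = 1 every position has one target; with n = 3 the same position supervises three predictions, which is the "denser training signal" in the table.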

MTP targets two effects.

  • Better training signal - when one position predicts several future tokens, supervision becomes denser, which tends to improve performance on code and reasoning-style tasks.
  • Faster inference - the predicted tokens are used as a draft, reducing the number of forward passes of the main model.

Meta's multi-token prediction paper reports +12% on HumanEval and +17% on MBPP for a 13B model relative to comparable next-token models, and states that models trained with 4-token prediction can be up to 3x faster at inference.

2. Two MTP Architectures

MTP implementations fall into two main families. Meta uses parallel independent heads, while DeepSeek-V3 uses sequential modules.

Figure 2. Meta's parallel independent heads and DeepSeek-V3's sequential modules.

Implementation | Structure | Characteristics
Meta | Shared trunk + n independent heads | Simple to implement, with almost no training-time overhead
DeepSeek-V3 | D sequential MTP modules | Keeps the causal chain; aims at both an auxiliary training signal and reuse at inference

Meta - Parallel Independent Heads

  • n independent output heads are placed on top of the shared trunk.
  • The heads predict t+1, t+2, ..., t+n in parallel.
  • Dependencies between heads are not modeled directly, but the structure is simple.
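
The parallel-head layout above can be sketched in a few lines of plain Python. This is a toy illustration under stated assumptions: each head is reduced to a single weight matrix, the per-head losses are summed, and the names `parallel_head_loss` and `softmax_cross_entropy` are hypothetical helpers rather than code from the Meta paper.

```python
import math
from typing import List, Sequence


def softmax_cross_entropy(logits: Sequence[float], target: int) -> float:
    """Numerically stable cross-entropy of one target under softmax(logits)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[target]


def parallel_head_loss(
    h: Sequence[float],
    heads: Sequence[Sequence[Sequence[float]]],
    future_tokens: Sequence[int],
) -> float:
    """Sum the losses of n independent heads that all read the same trunk state h.

    heads[i] is a (vocab x hidden) weight matrix; future_tokens[i] is x_{t+i+1}.
    """
    loss = 0.0
    for weights, target in zip(heads, future_tokens):
        logits = [sum(w * x for w, x in zip(row, h)) for row in weights]
        loss += softmax_cross_entropy(logits, target)
    return loss


# Toy setup: hidden size 2, vocabulary of 3 tokens, n = 2 heads.
h = [1.0, -1.0]
head_1 = [[0.5, 0.0], [0.0, 0.5], [0.0, 0.0]]
head_2 = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
loss = parallel_head_loss(h, [head_1, head_2], future_tokens=[0, 2])
print(loss)
```

Every head consumes the same `h` and never sees another head's output, which is what "dependencies between heads are not modeled directly" means in practice.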

DeepSeek-V3 - Sequential Modules

  • Each module consists of one Transformer layer plus a shared embedding and a shared output head.
  • The k-th module combines the hidden state from the previous step with the embedding of the actual future token to predict the token after it.
  • The modules serve as an auxiliary training objective; at inference they can be discarded and the main model still works on its own.

Conceptual flow

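# Training-time flow at position t (future tokens are the ground truth)
h = trunk(x_1:t)              # hidden state from the shared main model
loss = 0
for k in 1..D:                # D sequential MTP modules
  e = Emb(x_{t+k})            # shared embedding of the actual token at offset k
  h = M_k(concat(h, e))       # k-th module combines the previous hidden state with e
  p = OutHead(h)              # shared output head
  loss += CE(p, x_{t+k+1})    # module k is trained to predict the token at offset k+1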
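
The conceptual flow can be turned into a small runnable toy. Everything here is a stand-in chosen for illustration: `one_hot` plays the role of Emb, `mix` plays the role of a module M_k, the identity function plays the role of OutHead, and list concatenation plays the role of concat. It shows only the chaining pattern, not DeepSeek-V3's actual layers.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]
VOCAB = 4


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Numerically stable cross-entropy of one target under softmax(logits)."""
    m = max(logits)
    return m + math.log(sum(math.exp(v - m) for v in logits)) - logits[target]


def sequential_mtp_loss(
    h: Vector,
    tokens: Sequence[int],
    t: int,
    depth: int,
    emb: Callable[[int], Vector],
    modules: Sequence[Callable[[Vector], Vector]],
    out_head: Callable[[Vector], Vector],
) -> float:
    """Chain `depth` modules from position t (0-based) and sum their losses."""
    loss = 0.0
    for k in range(1, depth + 1):
        e = emb(tokens[t + k])        # ground-truth token at offset k
        h = modules[k - 1](h + e)     # h + e is list concatenation, i.e. concat(h, e)
        loss += cross_entropy(out_head(h), tokens[t + k + 1])
    return loss


def one_hot(token: int) -> Vector:
    """Toy stand-in for the shared embedding."""
    return [1.0 if i == token else 0.0 for i in range(VOCAB)]


def mix(v: Vector) -> Vector:
    """Toy stand-in for M_k: average the two halves of the concatenated input."""
    half = len(v) // 2
    return [0.5 * (a + b) for a, b in zip(v[:half], v[half:])]


tokens = [0, 1, 2, 3, 0]
h0 = one_hot(tokens[0])  # pretend trunk output for the prefix ending at t = 0
loss = sequential_mtp_loss(h0, tokens, t=0, depth=2, emb=one_hot,
                           modules=[mix, mix], out_head=lambda h: h)
print(loss)
```

Unlike the parallel heads, the hidden state `h` is threaded from one module to the next, so the prediction at depth k depends on what happened at depth k-1.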

3. Inference Acceleration - Speculative Decoding

MTP accelerates inference in the form of speculative decoding. The lightweight MTP path drafts several tokens, and the heavy main model verifies them in a single forward pass.

Figure 3. Draft generation, batched verification, and the accept/reject flow.

Stage | Role
Draft generation | The MTP modules quickly produce several candidate tokens
Verification | The main model checks the candidates in a single forward pass
Accept/reject | The matching prefix is accepted, and generation resumes from the first mismatch

Why It Gets Faster

  • Decoding is mostly memory-bound.
  • A single forward pass, which reads the weights once, can verify several tokens.
  • The more of the draft that is correct, the fewer main-model forward passes are needed.

The Meta paper reports that 4-token prediction can be up to 3x faster. The DeepSeek-V3 Technical Report states that the MTP objective can be used for speculative decoding.

4. Effects, Adoption, and Relationship to Other Techniques

Figure 4. Reported effects, related techniques, and the relationship to KV-reduction techniques.

Source | Observation
Meta MTP paper | HumanEval +12%, MBPP +17%; up to 3x speedup with 4-token prediction
DeepSeek-V3 Technical Report | Adopts the MTP objective and explains that it can be used for speculative decoding
DeepSeek-V3 README | The released model includes the MTP module weights

Related Techniques

  • Medusa - attaches several decoding heads to accelerate inference.
  • EAGLE - performs speculative decoding with a draft that preserves the causal chain.
  • Block-wise parallel decoding - a classic idea in the same self-speculative family as MTP.
  • SGLang - the DeepSeek-V3 README notes that MTP support is in progress.

Relationship to KV-Reduction Techniques

Axis | What is reduced | Examples
KV reduction | The size of the data to read | MLA, GQA, KV quantization
MTP/speculative | The number of reads | Multi-token prediction, Medusa, EAGLE

The two touch different bottlenecks. It is therefore possible to combine them, for example using MLA to shrink the KV cache and MTP to reduce the number of decoding steps. DeepSeek-V3 is a representative case that uses MLA and MTP together.

Pros and Cons

Pros | Cons
Performance gains can be expected on code and generation tasks | Extra heads or modules must be designed
Connects directly to speculative decoding | Gains can shrink when draft quality is low
Can be combined with KV compression techniques | The prediction depth (n, k) must be chosen well
Can also serve as an auxiliary training objective | Kernel and serving complexity can grow depending on the implementation

Related Technologies

Document/technology | Connection
MLA Analysis | A representative architecture for shrinking the KV cache
PagedAttention Analysis | KV placement management; a reference point for comparing orthogonality with MTP
FlashAttention Analysis | A representative example of decoding/attention kernel optimization
DeepSeek-V3 Technical Report | The MTP objective combined with MoE and MLA
Multi-token prediction paper | The basic MTP experiments and self-speculative decoding
Block-wise parallel decoding | The foundational idea of the self-speculative decoding family
Medusa | Multi-head inference acceleration

Key Takeaways

MTP predicts several tokens together rather than just the next one, which makes the training signal denser, and it reuses those predictions as drafts for speculative decoding. Meta's paper reports +12% on HumanEval and +17% on MBPP for a 13B model along with up to 3x faster inference, and DeepSeek-V3 adopts the MTP objective to target both training and inference.

Implementations split into parallel independent heads and sequential modules. Used together with MLA, as in DeepSeek-V3, MTP makes it possible to pursue KV cache savings and decoding acceleration at the same time. The key point is that techniques that reduce "the size of the data to read" and techniques that reduce "the number of reads" are orthogonal to each other.