MTP Analysis
MTP (Multi-Token Prediction) ์ฌ์ธต ๋ถ์
Multi-Token Prediction ยท ๋ณ๋ ฌ head / ์์ฐจ module ยท Self-Speculative Decoding ยท DeepSeek-V3
MTP(Multi-Token Prediction)๋ LLM์ด ๋งค ์์น์์ ๋ค์ ํ ํฐ 1๊ฐ๋ง ์์ธกํ๋ ๋์ , ๋ค์ ์ฌ๋ฌ ํ ํฐ์ ํจ๊ป ์์ธกํ๋๋ก ๋ง๋๋ ๊ธฐ๋ฒ์ ๋๋ค. ํ์ต์์๋ ๋ ๋ฐ๋ ๋์ supervision์ ์ ๊ณตํด ํํ๋ ฅ๊ณผ ๋ฐ์ดํฐ ํจ์จ์ ๋์ด๊ณ , ์ถ๋ก ์์๋ ์ฌ๋ฌ ํ ํฐ ์ด์์ ๋ง๋ค์ด speculative decoding์ ๊ฐ์์ ํ์ฉํฉ๋๋ค.
KV ์บ์๋ฅผ ์ค์ด๋ MLAยทGQAยท์์ํ๊ฐ '์ฝ์ ๋ฐ์ดํฐ์ ํฌ๊ธฐ'๋ฅผ ์ค์ด๋ ์ ๊ทผ์ด๋ผ๋ฉด, MTP๋ ๋์ฝ๋ฉ์์ ๊ฐ์ค์น ์ฝ๊ธฐ ํ์๋ฅผ ์ค์ด๋ ์ ๊ทผ์ ๋๋ค. ๋์ ์๋ก ๋ค๋ฅธ ๋ณ๋ชฉ์ ๊ณต๋ตํ๋ฏ๋ก ํจ๊ป ์ฐ๊ธฐ ์ข์ต๋๋ค.
ํต์ฌ ๊ฐ๋
- Multi-token target: ํ ์์น์์ ๋ค์ 2๊ฐ ์ด์ ํ ํฐ์ ์์ธกํฉ๋๋ค.
- Shared trunk: ๊ณตํต Transformer ๋ณธ์ฒด ์์ ์์ธก ๊ฒฝ๋ก๋ฅผ ์น์ต๋๋ค.
- Parallel head / sequential module: ๊ตฌํ์ ๋ณ๋ ฌ head ๋๋ ์์ฐจ module๋ก ๋๋ฉ๋๋ค.
- Self-speculative decoding: MTP๊ฐ ๋ง๋ ์ด์์ ๋ฉ์ธ ๋ชจ๋ธ์ด ๊ฒ์ฆํฉ๋๋ค.
- Orthogonality: KV ์์ถ๊ณผ ๋ฌ๋ฆฌ ๋์ฝ๋ฉ ํ์ ์์ฒด๋ฅผ ์ค์ ๋๋ค.
1. ํต์ฌ ์์ด๋์ด - ํ ๋ฒ์ ์ฌ๋ฌ ํ ํฐ ์์ธก
๊ธฐ์กด next-token prediction(NTP)์ ํ ์์น์์ ๋ฐ๋ก ๋ค์ ํ ํฐ 1๊ฐ๋ง ์์ธกํฉ๋๋ค. MTP๋ ์ด๋ฅผ ํ์ฅํด ๋ค์ n๊ฐ ํ ํฐ์ ์์ธกํฉ๋๋ค.
๊ทธ๋ฆผ 1. NTP๋ ๋ค์ ํ ํฐ 1๊ฐ๋ฅผ, MTP๋ ๋ค์ n๊ฐ๋ฅผ ์์ธกํ๋ค.
| ํญ๋ชฉ | NTP | MTP |
|---|---|---|
| ํ์ต ๋ชฉํ | ๋ค์ ํ ํฐ 1๊ฐ | ๋ค์ ์ฌ๋ฌ ํ ํฐ |
| ์ถ๋ ฅ ๊ตฌ์กฐ | ๋จ์ผ head | ์ฌ๋ฌ head ๋๋ ์์ฐจ module |
| ํ์ต ์ ํธ | ์๋์ ์ผ๋ก ํฌ์ํจ | ๋ ๋ฐ๋ ๋์ |
| ์ถ๋ก ํ์ฉ | ๊ธฐ๋ณธ autoregressive | speculative decoding ์ด์ |
MTP๊ฐ ๋ ธ๋ฆฌ๋ ํจ๊ณผ๋ ๋ ๊ฐ์ง์ ๋๋ค.
- ๋ ๋์ ํ์ต ์ ํธ - ํ ํ ํฐ์ด ์ฌ๋ฌ ๋ฏธ๋ ํ ํฐ์ ์์ธกํ๋ฉด supervision์ด ์กฐ๋ฐํด์ ธ ์ฝ๋์ ์ถ๋ก ํ ๊ณผ์ ์์ ์ฑ๋ฅ์ด ์ข์์ง๋๋ค.
- ๋ ๋น ๋ฅธ ์ถ๋ก - ์์ธกํ ์ฌ๋ฌ ํ ํฐ์ ์ด์์ผ๋ก ์จ์ ๋ฉ์ธ ๋ชจ๋ธ์ forward ํ์๋ฅผ ์ค์ ๋๋ค.
Meta์ Multi-token prediction ๋ ผ๋ฌธ์ 13B ๋ชจ๋ธ์์ HumanEval +12%, MBPP +17%๋ฅผ ๋ณด๊ณ ํ๊ณ , 4-token prediction ๋ชจ๋ธ์ ์ถ๋ก ์์ ์ต๋ 3๋ฐฐ ๋น ๋ฅผ ์ ์๋ค๊ณ ์ค๋ช ํฉ๋๋ค.
2. ๋ ๊ฐ์ง MTP ์ํคํ ์ฒ
MTP ๊ตฌํ์ ํฌ๊ฒ ๋ ๊ฐ๋์ ๋๋ค. Meta๋ ๋ณ๋ ฌ ๋ ๋ฆฝ head๋ฅผ, DeepSeek-V3๋ ์์ฐจ module์ ์ฌ์ฉํฉ๋๋ค.
๊ทธ๋ฆผ 2. Meta์ ๋ณ๋ ฌ ๋ ๋ฆฝ head์ DeepSeek-V3์ ์์ฐจ module.
| ๊ตฌํ | ๊ตฌ์กฐ | ํน์ง |
|---|---|---|
| Meta | ๊ณต์ trunk + ๋ ๋ฆฝ head n๊ฐ | ๊ตฌํ์ด ๋จ์ํ๊ณ ํ์ต ์๊ฐ ์ค๋ฒํค๋๊ฐ ๊ฑฐ์ ์์ |
| DeepSeek-V3 | ์์ฐจ MTP module D๊ฐ | ์ธ๊ณผ ์ฌ์ฌ์ ์ ์งํ๋ฉฐ ํ์ต ๋ณด์กฐ์ ์ถ๋ก ์ฌํ์ฉ์ ํจ๊ป ๋ ธ๋ฆผ |
Meta - ๋ณ๋ ฌ ๋ ๋ฆฝ head
- ๊ณต์ trunk ์์ n๊ฐ์ ๋ ๋ฆฝ ์ถ๋ ฅ head๋ฅผ ๋ก๋๋ค.
- ๊ฐ head๋ t+1, t+2, ..., t+n์ ๋ณ๋ ฌ๋ก ์์ธกํฉ๋๋ค.
- head ๊ฐ ์์กด์ฑ์ ์ง์ ๋ชจ๋ธ๋งํ์ง๋ ์์ง๋ง ๊ตฌ์กฐ๊ฐ ๋จ์ํฉ๋๋ค.
DeepSeek-V3 - ์์ฐจ module
- ๊ฐ module์ 1๊ฐ์ Transformer ์ธต๊ณผ ๊ณต์ ์๋ฒ ๋ฉ, ๊ณต์ ์ถ๋ ฅ head๋ก ๊ตฌ์ฑ๋ฉ๋๋ค.
- k๋ฒ์งธ module์ ์ ๋จ๊ณ hidden state์ ์ค์ ๋ฏธ๋ ํ ํฐ ์๋ฒ ๋ฉ์ ํจ๊ป ์จ์ ๋ค์ ํ ํฐ์ ์์ธกํฉ๋๋ค.
- ํ์ต์ฉ ๋ณด์กฐ objective๋ก ์ฐ๋ฉฐ, ์ถ๋ก ์์๋ module์ ๋ฒ๋ ค๋ ๋ฉ์ธ ๋ชจ๋ธ์ด ๋์ํฉ๋๋ค.
๊ฐ๋ ์ ํ๋ฆ
h = trunk(x_1:t)
for k in 1..D:
e = Emb(x_{t+k})
h = M_k(concat(h, e))
p = OutHead(h)
loss += CE(p, x_{t+k+1})
3. ์ถ๋ก ๊ฐ์ - Speculative Decoding
MTP์ ์ถ๋ก ๊ฐ์์ speculative decoding ํํ๋ก ์ด๋ค์ง๋๋ค. ๊ฐ๋ฒผ์ด MTP๊ฐ ์ฌ๋ฌ ํ ํฐ์ ์ด์์ผ๋ก ๋ง๋ค๊ณ , ๋ฌด๊ฑฐ์ด ๋ฉ์ธ ๋ชจ๋ธ์ด ๊ทธ๊ฒ๋ค์ ํ ๋ฒ์ forward๋ก ๊ฒ์ฆํฉ๋๋ค.
๊ทธ๋ฆผ 3. ์ด์ ์์ฑ, ์ผ๊ด ๊ฒ์ฆ, ์๋ฝ/๊ฑฐ์ ํ๋ฆ.
| ๋จ๊ณ | ์ญํ |
|---|---|
| ์ด์ ์์ฑ | MTP module์ด ์ฌ๋ฌ ํ๋ณด ํ ํฐ์ ๋น ๋ฅด๊ฒ ๋ง๋ฆ |
| ๊ฒ์ฆ | ๋ฉ์ธ ๋ชจ๋ธ์ด ํ ๋ฒ์ forward๋ก ํ๋ณด๋ฅผ ํ์ธ |
| ์๋ฝ/๊ฑฐ์ | ๋ง๋ prefix๋ ์๋ฝํ๊ณ , ํ๋ฆฐ ์ง์ ๋ถํฐ ๋ค์ ์์ฑ |
์ ๋นจ๋ผ์ง๋๊ฐ
- ๋์ฝ๋ฉ์ ๋์ฒด๋ก ๋ฉ๋ชจ๋ฆฌ ๋ฐ์ด๋์ ๋๋ค.
- ๊ฐ์ ๊ฐ์ค์น๋ฅผ ์ฝ๋ ํ ๋ฒ์ forward๋ก ์ฌ๋ฌ ํ ํฐ์ ๊ฒ์ฆํ ์ ์์ต๋๋ค.
- ์ด์์ด ๋ง๋ ๋งํผ ๋ฉ์ธ ๋ชจ๋ธ forward ํ์๊ฐ ์ค์ด๋ญ๋๋ค.
Meta ๋ ผ๋ฌธ์ 4-token prediction์ด ์ต๋ 3๋ฐฐ ๋น ๋ฅผ ์ ์๋ค๊ณ ๋ณด๊ณ ํฉ๋๋ค. DeepSeek-V3 Technical Report๋ MTP objective๊ฐ speculative decoding์ ํ์ฉ๋ ์ ์๋ค๊ณ ๋ช ์ํฉ๋๋ค.
4. ํจ๊ณผยท์ฑํ๊ณผ ๋ค๋ฅธ ๊ธฐ๋ฒ๊ณผ์ ๊ด๊ณ
๊ทธ๋ฆผ 4. ๋ณด๊ณ ๋ ํจ๊ณผ์ ๊ด๋ จ ๊ธฐ๋ฒ, KV ์ ๊ฐ ๊ธฐ๋ฒ๊ณผ์ ๊ด๊ณ.
| ์ถ์ฒ | ๊ด์ฐฐ |
|---|---|
| Meta MTP ๋ ผ๋ฌธ | HumanEval +12%, MBPP +17%, 4-token prediction ์ต๋ 3x ๊ฐ์ |
| DeepSeek-V3 Technical Report | MTP objective๋ฅผ ์ฑํํ๊ณ speculative decoding์ ํ์ฉ ๊ฐ๋ฅํ๋ค๊ณ ์ค๋ช |
| DeepSeek-V3 README | ๋ชจ๋ธ ๋ฐฐํฌ๋ณธ์ MTP module weights๊ฐ ํฌํจ๋จ |
๊ด๋ จ ๊ธฐ๋ฒ
- Medusa - ์ฌ๋ฌ decoding head๋ฅผ ๋ถ์ฌ ์ถ๋ก ์ ๊ฐ์ํ๋ ๋ฐฉ์์ ๋๋ค.
- EAGLE - ์ธ๊ณผ ์ฌ์ฌ์ ์ ์งํ๋ draft๋ฅผ ์จ์ speculative decoding์ ์ํํฉ๋๋ค.
- Block-wise parallel decoding - MTP์ ๊ฐ์ self-speculative ๊ณ์ด์ ๊ณ ์ ์ ์์ด๋์ด์ ๋๋ค.
- SGLang - DeepSeek-V3 README์์๋ MTP ์ง์์ด ์งํ ์ค์ด๋ผ๊ณ ์๋ดํฉ๋๋ค.
KV ์ ๊ฐ ๊ธฐ๋ฒ๊ณผ์ ๊ด๊ณ
| ์ถ | ์ค์ด๋ ๋์ | ์ |
|---|---|---|
| KV ์ ๊ฐ | ์ฝ์ ๋ฐ์ดํฐ์ ํฌ๊ธฐ | MLA, GQA, KV ์์ํ |
| MTP/speculative | ์ฝ๋ ํ์ | Multi-token prediction, Medusa, EAGLE |
๋์ ์๋ก ๋ค๋ฅธ ๋ณ๋ชฉ์ ๊ฑด๋๋ฆฝ๋๋ค. ๋ฐ๋ผ์ MLA๋ก KV๋ฅผ ์ค์ด๊ณ , MTP๋ก ๋์ฝ๋ฉ ํ์๋ฅผ ์ค์ด๋ ์์ ์กฐํฉ์ด ๊ฐ๋ฅํฉ๋๋ค. DeepSeek-V3๋ MLA์ MTP๋ฅผ ํจ๊ป ์ฌ์ฉํ ๋ํ ์ฌ๋ก์ ๋๋ค.
์ฅ๋จ์
| ์ฅ์ | ๋จ์ |
|---|---|
| ์ฝ๋์ ์์ฑ ๊ณผ์ ์์ ์ฑ๋ฅ ํฅ์์ ๊ธฐ๋ํ ์ ์๋ค | ์ถ๊ฐ head ๋๋ module ์ค๊ณ๊ฐ ํ์ํ๋ค |
| speculative decoding์ ๋ฐ๋ก ์ฐ๊ฒฐ๋๋ค | ์ด์ ํ์ง์ด ๋ฎ์ผ๋ฉด ์ด๋์ด ์ค ์ ์๋ค |
| KV ์์ถ ๊ธฐ๋ฒ๊ณผ ํจ๊ป ์ธ ์ ์๋ค | ์์ธก ๊น์ด n, k๋ฅผ ์ ์ก์์ผ ํ๋ค |
| ํ์ต ๋ณด์กฐ objective๋ก๋ ์ธ ์ ์๋ค | ๊ตฌํ ๋ฐฉ์์ ๋ฐ๋ผ ์ปค๋/์๋น ๋ณต์ก๋๊ฐ ๋ ์ ์๋ค |
๊ด๋ จ ๊ธฐ์
| ๋ฌธ์/๊ธฐ์ | ์ฐ๊ฒฐ์ |
|---|---|
| MLA Analysis | KV ์บ์๋ฅผ ์ค์ด๋ ๋ํ์ ์ํคํ ์ฒ |
| PagedAttention Analysis | KV ๋ฐฐ์น ๊ด๋ฆฌ์ MTP์ ์ง๊ต์ฑ ๋น๊ต ๊ธฐ์ค |
| FlashAttention Analysis | ๋์ฝ๋ฉ/์ดํ ์ ์ปค๋ ์ต์ ํ์ ๋ํ ์ฌ๋ก |
| DeepSeek-V3 Technical Report | MTP objective์ MoE/MLA ์กฐํฉ |
| Multi-token prediction paper | MTP์ ๊ธฐ๋ณธ ์คํ๊ณผ self-speculative decoding |
| Block-wise parallel decoding | self-speculative decoding ๊ณ์ด์ ๊ธฐ๋ฐ ์์ด๋์ด |
| Medusa | ๋ค์ค head ๊ธฐ๋ฐ ์ถ๋ก ๊ฐ์ |
ํต์ฌ ์ ๋ฆฌ
MTP๋ ๋ค์ ํ ํฐ 1๊ฐ๊ฐ ์๋๋ผ ์ฌ๋ฌ ํ ํฐ์ ํจ๊ป ์์ธกํด ํ์ต ์ ํธ๋ฅผ ์กฐ๋ฐํ๊ฒ ๋ง๋ค๊ณ , ๊ทธ ๊ฒฐ๊ณผ๋ฅผ speculative decoding์ ์ด์์ผ๋ก ์ฌํ์ฉํ๋ ๊ธฐ๋ฒ์ ๋๋ค. Meta์ ๋ ผ๋ฌธ์ 13B ๋ชจ๋ธ์์ HumanEval +12%, MBPP +17%์ ์ต๋ 3๋ฐฐ ์ถ๋ก ๊ฐ์์ ๋ณด๊ณ ํ๊ณ , DeepSeek-V3๋ MTP objective๋ฅผ ์ฑํํด ํ์ต๊ณผ ์ถ๋ก ๋ ๋ค๋ฅผ ๋ ธ๋ฆฝ๋๋ค.
๊ตฌํ์ ๋ณ๋ ฌ ๋ ๋ฆฝ head์ ์์ฐจ module๋ก ๋๋๋ฉฐ, DeepSeek-V3์ฒ๋ผ MLA์ ํจ๊ป ์ฐ๋ฉด KV ์บ์ ์ ๊ฐ๊ณผ ๋์ฝ๋ฉ ๊ฐ์์ ๋์์ ๋ ธ๋ฆด ์ ์์ต๋๋ค. ์์ง๋ '์ฝ์ ๋ฐ์ดํฐ ํฌ๊ธฐ'๋ฅผ ์ค์ด๋ ๊ธฐ๋ฒ๊ณผ '์ฝ๋ ํ์'๋ฅผ ์ค์ด๋ ๊ธฐ๋ฒ์ด ์๋ก ์ง๊ตํ๋ค๋ ์ ์ ๋๋ค.