KV Cache Quantization Analysis
KV Cache ์์ํ ์ฌ์ธต ๋ถ์
KV Cache Quantization ยท Per-channel Key / Per-token Value ยท KIVI ยท KVQuant ยท FP8/INT8 ยท Mixed Precision
KV ์บ์ ์์ํ๋ LLM ์ถ๋ก ์์ Key/Value ์บ์์ ์ ๋ฐ๋๋ฅผ ๋ฎ์ถฐ(FP16 โ INT8/INT4/INT2 ๋๋ FP8) ๋ฉ๋ชจ๋ฆฌ ์ฉ๋๊ณผ ๋์ญํญ์ ์ ๊ฐํ๋ ๊ธฐ๋ฒ์ ๋๋ค. KV ์บ์๋ ์ํ์คยท๋ฐฐ์น์ ๋น๋กํด ์ปค์ ธ ๋ชจ๋ธ ๊ฐ์ค์น๋งํผ์ด๋ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฐจ์งํ๊ณ , ๋์ฝ๋ฉ์ ์ด ์บ์๋ฅผ ์ฝ๋ ์๋์ ์ข์ฐ๋๋ฏ๋ก, KV๋ฅผ ์๊ฒ ๋ง๋๋ ๊ฒ์ ๊ณง ๋ ํฐ ๋ฐฐ์นยท๋ ๊ธด ๋ฌธ๋งฅยท๋ ๋น ๋ฅธ ๋์ฝ๋ฉ์ ์๋ฏธํฉ๋๋ค. KV ์บ์๋ฅผ ์ค์ด๋ ๋ฐฉ๋ฒ์ ๋นํธํญ, ์ ๋, ์ต๊ทผ ํ ํฐ ๋ณด์กด, ์ปค๋ ์ตํฉ์ ์กฐํฉ ๋ฌธ์ ์ด๊ธฐ๋ ํฉ๋๋ค. ๊ฐ์ 2๋นํธ๋ผ๋ KIVI์ KVQuant์ ํ์ง์ด ๋ค๋ฅธ ์ด์ ๋ Key/Value ๋ถํฌ์ RoPE ์ฒ๋ฆฌ ๋ฐฉ์์ด ๋ค๋ฅด๊ธฐ ๋๋ฌธ์ ๋๋ค.
์ค์ ์์๋ "๋ช ๋นํธ๊น์ง ๋ฎ์ถ ์ ์๋"๋ณด๋ค "์ด๋ค KV๋ฅผ ์ด๋ค ์ ๋๋ก ์ค์ผ ๊ฒ์ธ๊ฐ"๊ฐ ๋ ์ค์ํฉ๋๋ค. Key๋ ์ฑ๋ ์ด์์น์ ๋ฏผ๊ฐํ๊ณ , Value๋ ํ ํฐ๋ณ ์ค์ฐจ ๋์ ์ ๋ฏผ๊ฐํ๋ฉฐ, ํ๋ ์์ํฌ๋ ๋ค์ ์ญ์์ํ ์ค๋ฒํค๋์ ์ปค๋ ์ตํฉ ์ฌ๋ถ์ ๋ฐ๋ผ ์ฑ๋ฅ์ด ๊ฐ๋ฆฝ๋๋ค. ๋ณธ ๋ฌธ์๋ KV ์บ์์ ๋ฉ๋ชจ๋ฆฌ ๋ฌธ์ โ ์์ํ ๊ธฐ์ด โ KV๊ฐ ์ ํน๋ณํ๊ฐ(Key/Value ๋น๋์นญ) โ ์ฃผ์ ๊ธฐ๋ฒ(KIVIยทKVQuantยทFP8/INT8ยทํผํฉ ์ ๋ฐ๋) โ ์์คํ ํตํฉ โ ์ฅ๋จ์ โ ํธ๋ ์ด๋์คํยท๋ฉ๋ชจ๋ฆฌ ํฐ์ด๋ง ์ฐ๊ฒฐ์ ์์๋ก ๋ถ์ํฉ๋๋ค.
1. KV ์บ์ ๋ฉ๋ชจ๋ฆฌ ๋ฌธ์
KV ์บ์๋ ๊ณผ๊ฑฐ ํ ํฐ๋ค์ Key/Value๋ฅผ ์ ์ฅํด ์ดํ ์ ์ฌ๊ณ์ฐ์ ํผํ์ง๋ง, ๊ทธ ํฌ๊ธฐ๊ฐ ์ถ๋ก ์ ๋ณ๋ชฉ์ด ๋ฉ๋๋ค. ํฌ๊ธฐ๋ ๋ฐฐ์นยท์ํ์ค ๊ธธ์ดยท๋ ์ด์ดยทํค๋ ์์ ๋ชจ๋ ๋น๋กํ๋ฏ๋ก, ๊ธด ๋ฌธ๋งฅ์ด๋ ํฐ ๋ฐฐ์น์์๋ ๋ชจ๋ธ ๊ฐ์ค์น๋ฅผ ๋์ด์๊ธฐ๋ ํฉ๋๋ค.
์ฉ๋๊ณผ ๋์ญํญ
| ํญ๋ชฉ | ์์ํ ์ | ์์ํ ํ |
|---|---|---|
| ์ ์ฅ ์ฉ๋ | ์ํ์คยท๋ฐฐ์นยท๋ ์ด์ด์ ๋น๋กํด ๊ธ์ฆ | 8/4/3/2๋นํธ๋ก ์ถ์ |
| ์ฝ๊ธฐ ๋์ญํญ | ๋์ฝ๋ฉ๋ง๋ค ์ ์ฒด KV๋ฅผ ์ฝ์ | ์ฝ๊ธฐ๋ ๊ฐ์, dequant ๋น์ฉ ์ถ๊ฐ |
| ๋ฐฐ์น ํฌ๊ธฐ | HBM ์์ง์ผ๋ก ์ ํ | ๊ฐ์ GPU์ ๋ ๋ง์ ์์ฒญ ์์ฉ |
| ๋ฌธ๋งฅ ๊ธธ์ด | ๊ธด ๋ฌธ๋งฅ์์ ๊ธ๊ฒฉํ ์ฆ๊ฐ | ๋ ๊ธด ๋ฌธ๋งฅ ์๋น์ ์ ๋ฆฌ |
| ### ๋ ๊ฐ์ง ์ ๊ฐ โ ์ฉ๋๊ณผ ๋์ญํญ |
-
๋ฉ๋ชจ๋ฆฌ ์ฉ๋(capacity) โ KV๊ฐ ์์์ง๋ฉด ๊ฐ์ GPU์ ๋ ๊ธด ๋ฌธ๋งฅ, ๋ ํฐ ๋ฐฐ์น๋ฅผ ๋ด์ ์ ์์ต๋๋ค โ ์ฒ๋ฆฌ๋ ํฅ์(๊ฑฐ์ ํญ์ ์ด๋).
-
๋ฉ๋ชจ๋ฆฌ ๋์ญํญ(bandwidth) โ ๋์ฝ๋ฉ์ ๋งค ์คํ KV๋ฅผ ์ฝ๋ ๋ฉ๋ชจ๋ฆฌ ๋ฐ์ด๋ ์์ ์ด๋ฏ๋ก, ์ฝ์ KV๊ฐ ์์์ง๋ฉด ๋์ฝ๋ฉ์ด ๋นจ๋ผ์ง๋๋ค(๋จ ์ญ์์ํ ์ฐ์ฐ๊ณผ ์์๋ ์ ์์).
2. ์์ํ ๊ธฐ์ด
์์ํ๋ ์ฐ์์ ์ธ ์ค์๊ฐ(FP16/BF16)์ ๋ ์ ์ ๋นํธ์ ์ ์(INT8/INT4/INT2)๋ ์ ๋นํธ ๋ถ๋์์์ (FP8)์ผ๋ก ๋งคํํ๋ ๊ฒ์ ๋๋ค. ํต์ฌ์ scale๊ณผ zero-point๋ง ์์ผ๋ฉด ์ ์์ ๊ทผ์ฌ ์ค์๋ฅผ ์ค๊ฐ ์ ์๋ค๋ ์ ์ ๋๋ค.
๊ทธ๋ฆผ 1. ์ค์โ์ ์ ๊ฒฉ์ ๋งคํ, affine ์์ํ ์์, ๋นํธ๋ณ ๋ฉ๋ชจ๋ฆฌ ์ ๊ฐ
ํต์ฌ ๊ฐ๋
-
๊ท ๋ฑ(affine) ์์ํ โ q = round(x/s + z)๋ก ์์ํ, x_hat = sยท(qโz)๋ก ์ญ์์ํ. s๋ scale, z๋ zero-point.
-
๋์นญ vs ๋น๋์นญ โ ๋ถํฌ๊ฐ 0 ์ค์ฌ์ด ์๋๋ฉด zero-point zโ 0์ธ ๋น๋์นญ์ด ์ ๋ฆฌํฉ๋๋ค.
-
์ ๋(granularity) โ per-tensor(ํ ์ 1๊ฐ scale)๋ณด๋ค per-channel/per-token(ํยท์ด๋ณ scale)์ด ์ธ๋ฐํด ์ ํํ์ง๋ง, scale ์ ์ฅ ๊ณต๊ฐ์ด ๋์ด ์คํจ ์ ๊ฐ์ด ์ด๋ก ์น๋ณด๋ค ์ฝ๊ฐ ์์์ง๋๋ค(ํนํ INT4/INT8).
-
์ด์์น(outlier) โ ์์์ ํฐ ๊ฐ์ด ๊ฒฉ์ ๋ฒ์๋ฅผ ๋ํ ๋๋จธ์ง ๊ฐ์ ํด์๋๋ฅผ ๋จ์ด๋จ๋ฆฝ๋๋ค. KV ์์ํ์ ํต์ฌ ๋์ ์ ๋๋ค.
3. KV ์บ์๋ ์ ํน๋ณํ๊ฐ โ Key์ Value์ ๋น๋์นญ
KV ์บ์๋ ๊ฐ์ค์น์ ๋ฌ๋ฆฌ ์ ๋ ฅ๋ง๋ค ๋ถํฌ๊ฐ ๋ณํ๋ ๋ฐํ์ ํ์ฑ๊ฐ์ด๋ผ ์ด์์น๊ฐ ๋ง๊ณ , ๋ฌด์๋ณด๋ค Key ์บ์์ Value ์บ์์ ๋ถํฌ๊ฐ ์๋ก ๋ค๋ฆ ๋๋ค. ์ด ๋น๋์นญ์ ๋ฌด์ํ๊ณ ๋จ์ per-tensor INT4๋ก ์์ํํ๋ฉด ์ ํ๋๊ฐ ํฌ๊ฒ ๋ฌด๋์ง๋๋ค.
๊ทธ๋ฆผ 2. Key๋ ์ฑ๋ ์ด์์น โ per-channel, Value๋ โ per-token (KIVI์ ํต์ฌ ๋ฐ๊ฒฌ)
Key=per-channel, Value=per-token
-
Key ์บ์ โ ์์์ ๊ณ ์ ๋ ์ฑ๋์ด ๋งค์ฐ ํฐ ๊ฐ์ ๊ฐ์ง๋๋ค(์ฑ๋ ์ด์์น). ๊ทธ ์ฑ๋์ ์ฑ๋ ๋จ์๋ก ์์ํ(per-channel)ํด์ผ ์ด์์น๊ฐ ๋ค๋ฅธ ์ฑ๋์ ํด์๋๋ฅผ ๋ง์น์ง ์์ต๋๋ค.
-
Value ์บ์ โ ๋๋ ทํ ์ฑ๋ ์ด์์น๋ ์์ง๋ง, ์ดํ ์ ์ถ๋ ฅ์ด Value๋ค์ ๊ฐ์คํฉ(value mixer)์ด๋ฏ๋ก ํ ํฐ๋ณ๋ก ์ค์ฐจ๋ฅผ ๊ฐ๋๋ per-token ์์ํ๊ฐ ์์ ํฉ๋๋ค.
-
RoPE ๋ฌธ์ โ ํ์ ์์น ์ธ์ฝ๋ฉ์ด ์ฑ๋ ์์ ์์ด Key์ ์ด์์น ๊ตฌ์กฐ๋ฅผ ํํธ๋ฌ๋จ๋ฆฝ๋๋ค. KVQuant๋ RoPE ์ ์ฉ '์ '์ Key๋ฅผ per-channel ์์ํํด ์ด ๋ฌธ์ ๋ฅผ ํผํฉ๋๋ค.
-
์ต๊ทผ ํ ํฐ ๋ณด์กด โ ๊ฐ์ฅ ์ต๊ทผ์ Key/Value ์ผ๋ถ๋ฅผ full precision sliding window๋ก ๋จ๊ธฐ๋ฉด ์ด๋ ค์ด ์ถ๋ก ๊ณผ์ ์ ์ ํ๋๊ฐ ํ๋ณต๋ฉ๋๋ค. ์๋ฅผ ๋ค์ด KIVI ๋ ผ๋ฌธ์์ Llama-2-7B์ GSM8K๊ฐ naive 2๋นํธ์์ 13.50โ5.76์ผ๋ก ๋ฌด๋์ง์ง๋ง, full-precision ์์ฐจ๋ฅผ ๋ KIVI-2๋ 13.50โ12.74๋ก ๋๋ถ๋ถ ํ๋ณตํฉ๋๋ค.
4. ์ฃผ์ KV ์์ํ ๊ธฐ๋ฒ
๊ทธ๋ฆผ 3. FP8/INT8, KIVI, KVQuant, ํผํฉ ์ ๋ฐ๋์ ๋น๊ต
KIVI โ ๋น๋์นญ 2๋นํธ (ICML 2024, arXiv 2402.02750)
KV ์บ์ ๋ถํฌ ๋ถ์์์ ์ถ๋ฐํด, Key๋ per-channelยทValue๋ per-token์ผ๋ก ์์ํํ๋ ๋น๋์นญ 2๋นํธ ๊ธฐ๋ฒ์ ๋๋ค. ํ๋์ด ํ์ ์๊ณ (plug-and-play), ์ต๊ทผ ํ ํฐ์ full precision ์์ฐจ๋ก ์ ์งํ๋ฉฐ, ์ญ์์ํ๋ฅผ matmul๊ณผ ์ตํฉํ ํ๋์จ์ด ์นํ์ ๊ตฌํ์ ์ ๊ณตํฉ๋๋ค. ๋ ผ๋ฌธ์ ๋ฐ๋ฅด๋ฉด Llama-2ยทFalconยทMistral์์ ๊ฑฐ์ ๊ฐ์ ํ์ง์ ์ ์งํ๋ฉด์ (๊ฐ์ค์น ํฌํจ) ์ต๋ ๋ฉ๋ชจ๋ฆฌ๋ฅผ 2.6๋ฐฐ ์ค์ด๊ณ , ์ด๋ก์จ ์ต๋ 4๋ฐฐ ํฐ ๋ฐฐ์น์ 2.35~3.47๋ฐฐ ์ฒ๋ฆฌ๋์ ๋ฌ์ฑํฉ๋๋ค.
KVQuant โ 3๋นํธ (NeurIPS 2024, arXiv 2401.18079, UC Berkeley)
๋ค ๊ฐ์ง ๊ธฐ๋ฒ์ ๊ฒฐํฉํด sub-4-bit ์ ๋ฐ๋๋ฅผ ๊ฐ๋ฅ์ผ ํฉ๋๋ค: (i) per-channel Key ์์ํ, (ii) RoPE ์ ์ฉ ์ Key ์์ํ(pre-RoPE), (iii) ๋ฏผ๊ฐ๋ ๊ฐ์ค ๋น๊ท ๋ฑ ์์ํ(NUQ), (iv) per-vector dense-and-sparse(์ด์์น๋ฅผ ๋ฐ๋ก ๋ถ๋ฆฌ). LLaMAยทLlama-2/3ยทMistral์์ 3๋นํธ๋ก perplexity ์ ํ 0.1 ๋ฏธ๋ง์ ๋ฌ์ฑํ๋ฉฐ, LLaMA-7B ๊ธฐ์ค ๋จ์ผ A100-80GB์์ 100๋ง(1M) ํ ํฐ, 8-GPU์์ 1000๋ง(10M) ํ ํฐ ๋ฌธ๋งฅ์ ์๋นํ ์ ์๋ค๊ณ ๋ณด๊ณ ํฉ๋๋ค. ์ปค์คํ CUDA ์ปค๋๋ก ์ฝ 1.7๋ฐฐ speedup๋ ์ ์ํฉ๋๋ค.
FP8 / INT8 โ 8๋นํธ (ํ๋ ์์ํฌ ๊ธฐ๋ณธ)
-
FP8 โ OCP๊ฐ ์ ์ํ E4M3(4์ง์ยท3๊ฐ์, ยฑ240 ๋ฒ์, FP32 scale ํ์)์ E5M2(5์ง์ยท2๊ฐ์) ๋ ํ์. vLLM์ kv_cache_dtype="fp8"(E4M3/E5M2)์ ์ง์ํ๋ฉฐ, E4M3๋ ์ ํ๋ ์ ํ๊ฐ ๋์ฒด๋ก ๋ฏธ๋ฏธํฉ๋๋ค. NVIDIA Hopper/Ada, AMD MI300 ๋ฑ์ด ํ๋์จ์ด ๋ณํ์ ๊ฐ์ํฉ๋๋ค.
-
INT8 โ TensorRT-LLM์ FP8(E4M3)๊ณผ INT8 KV ์บ์๋ฅผ ๋ชจ๋ ์ง์ํฉ๋๋ค. ๋ค๋ง INT8/INT4๋ scale ์ ์ฅ์ด ์ถ๊ฐ๋ก ํ์ํด FP8๋ณด๋ค ๋ฉ๋ชจ๋ฆฌ ์ด๋์ด ์ฝ๊ฐ ์ค๊ณ , FP8์ ์ค์ผ์ผ ๋ถ๋ด์ด ์์ ๊ตฌํ์ด ๋จ์ํฉ๋๋ค.
ํผํฉ ์ ๋ฐ๋ (Mixed Precision)
์ค์ํ ํ ํฐยท์ด์์น๋ ๊ณ ์ ๋ฐ๋ก, ๋๋จธ์ง๋ ์ ๋นํธ๋ก ์ฐจ๋ฑ ์ ์ฉํ๋ ๋ฐฉํฅ์ ๋๋ค. MiKV๋ ์ค์ KV๋ ๊ณ ์ ๋ฐ๋ก ์ ์งํ๊ณ Q์ attention map์ ๋ถ๋์์์ ์ผ๋ก ๋๋ฉฐ, KVmix๋ ๋ ์ด์ด ์ค์๋(gradient ๊ธฐ๋ฐ)์ ๋ฐ๋ผ ๋นํธํญ์ ํ ๋นํด Key 2.19๋นํธยทValue 2.38๋นํธ ๊ฐ์ ๊ทน์ ๋นํธ์์ 4.9๋ฐฐ ์์ถยท5.3๋ฐฐ ์๋๋ฅผ ๋ณด๊ณ ํฉ๋๋ค. attention sink(์์ชฝ ํต์ฌ ํ ํฐ) ๋ณด์กด์ด ๊ณตํต ์ด์ ์ ๋๋ค.
์ ํ๋ vs ๋นํธํญ ์ ๋ฆฌ
| ๋นํธํญ | ๋ํ ๋ฐฉ์ | ์ ํ๋ ๊ฒฝํฅ | ๋ฉ๋ชจ๋ฆฌ ๊ฒฝํฅ |
|---|---|---|---|
| 8๋นํธ | FP8 / INT8 | ๊ฑฐ์ ๋ฌด์์ค | ์ฝ 2๋ฐฐ ์ ๊ฐ |
| 4๋นํธ | ๊ทธ๋ฃน/ํผํฉ ์์ํ | ๋์ฒด๋ก ์ํธ | ์ฝ 4๋ฐฐ ์ ๊ฐ |
| 3๋นํธ | KVQuant | < 0.1 perplexity ์ ํ ์์ค๊น์ง ๊ฐ๋ฅ | ๊ฐํ ์ ๊ฐ |
| 2๋นํธ | KIVI | residual window๊ฐ ์์ผ๋ฉด ์ค์ฉ์ | ๋งค์ฐ ๊ฐํ ์ ๊ฐ |
| 2๋นํธ ๋ฏธ๋ง | mixed precision | ์ค์ ํ ํฐ ๋ณด์กด์ด ํ์ | ์ ์ฑ ๋ณต์ก๋ ์ฆ๊ฐ |
5. ์์คํ ํตํฉ
์์ํ๋ ์ถ๋ก ํ์ดํ๋ผ์ธ์ ์ด๋์์, ์ด๋ค ๋น์ฉ์ผ๋ก ์ผ์ด๋๋์ง๊ฐ ์ค์ ์ด๋์ ์ข์ฐํฉ๋๋ค. ์ผ๋ฐ์ ์ผ๋ก ์ ์ฅ์ ์ ๋นํธ๋ก ํ๋ ์ดํ ์ ์ฐ์ฐ ์์ฒด๋ ๊ณ ์ ๋ฐ๋ก ์ํํฉ๋๋ค.
๊ทธ๋ฆผ 4. ๋์ฝ๋ฉ ์คํ ์์ ์์ํ/์ญ์์ํ ์์น์ ์ฉ๋ยท๋์ญํญ ๋ ์ธก๋ฉด
๊ตฌํยท์์คํ ๊ณ ๋ ค์ฌํญ
-
์์ํ ์์น โ ์ ํ ํฐ์ K,V๋ฅผ ๊ณ์ฐํ ์งํ ์์ํํด ์บ์์ ์ ์ฅํ๊ณ , ์ดํ ์ ์ง์ ์ ์ญ์์ํ(๋๋ ์ ๋นํธ matmul)ํฉ๋๋ค.
-
์ญ์์ํ ์ค๋ฒํค๋ โ dequant๋ฅผ matmul๊ณผ ์ตํฉ(fused kernel)ํ์ง ์์ผ๋ฉด ๋์ญํญ ์ ๊ฐ์ผ๋ก ์ป์ ์ง์ฐ ์ด๋์ด ์ฌ๋ผ์ง ์ ์์ต๋๋ค.
-
vLLM โ kv_cache_dtype="fp8". ํ์ฌ ์ฃผ ์ด๋์ ์ฉ๋(โ2๋ฐฐ ํ ํฐ)์ด๋ฉฐ, ์ง์ฐ ์ด๋์ ๋ฐฑ์๋ ์์กด์ ๋๋ค(๊ณผ๊ฑฐ์๋ fused dequant ๋ฏธ๊ตฌํ์ผ๋ก ์ง์ฐ ์ด๋์ด ์ ํ์ ์ด์๊ณ , FlashAttention-3 ๋ฐฑ์๋์์๋ ์ดํ ์ ๋ FP8๋ก ์ํ). LLM Compressor๋ก ๋ณด์ (calibration)๋ scale์ ์ฐ๋ฉด ํ์ง์ด ์ข์์ง๋๋ค.
-
TensorRT-LLM โ FP8ยทINT8 KV ์บ์ ์ง์. SqueezeBits ๋ฒค์น๋งํฌ์์ vLLM์ FP8๋ ์ฒ๋ฆฌ๋ ๊ฐ์ ์ด ๊ฑฐ์ ์์๋ ๋ฐ๋ฉด(ํ๋ฆฌํ ์์ฃผ์์ ์ํญ ์ ํ), TensorRT-LLM์ FP8ยทINT8์ ์ฒ๋ฆฌ๋ ํฅ์์ ๋ณด์์ต๋๋ค. ํ๋ ์์ํฌยท์ํฌ๋ก๋์ ๋ฐ๋ผ ๊ฒฐ๊ณผ๊ฐ ํฌ๊ฒ ๋ค๋ฆ ๋๋ค.
-
PagedAttention๊ณผ ์ง๊ต โ ๋ธ๋ก ์์ K,V๋ฅผ ์ ๋นํธ๋ก ๋ด์ผ๋ฉด ๋ธ๋ก๋น ๋ ๋ง์ ํ ํฐ์ด ๋ค์ด๊ฐ๋ฏ๋ก, ๋ ๊ธฐ๋ฒ์ ํจ๊ป ์ธ ์ ์์ต๋๋ค.
-
์ ํ๋ ํ๊ฐ โ perplexity์ ํจ๊ป ์ฅ๋ฌธ๋งฅ ๋ฒค์น๋งํฌ(LongBenchยทpasskey retrievalยทRULER), GSM8K ๊ฐ์ ์ถ๋ก ๊ณผ์ ๋ก ๊ฒ์ฆํฉ๋๋ค.
6. ์ฅ๋จ์
์ฅ์
-
๋ฉ๋ชจ๋ฆฌ ์ ๊ฐ ํจ๊ณผ๊ฐ ์ง์ ์ ์ ๋๋ค. FP8/INT8๋ง์ผ๋ก๋ ๋์ฒด๋ก ์ฝ 2๋ฐฐ, INT4/INT2 ๊ณ์ด์ ๊ทธ๋ณด๋ค ๋ ํฐ ์ ๊ฐ์ด ๊ฐ๋ฅํด ๊ธด ๋ฌธ๋งฅ๊ณผ ํฐ ๋ฐฐ์น๋ฅผ ๊ฐ์ GPU์ ๋ด๊ธฐ ์ฌ์์ง๋๋ค.
-
๋์ญํญ ๋ณ๋ชฉ ์ํ์ ๋์์ด ๋ฉ๋๋ค. ๋์ฝ๋ฉ ๋จ๊ณ๋ ๋งค ์คํ ๊ณผ๊ฑฐ KV๋ฅผ ๋ฐ๋ณตํด์ ์ฝ๊ธฐ ๋๋ฌธ์, KV ํฌ๊ธฐ๊ฐ ์ค๋ฉด HBM ์ฝ๊ธฐ๋๊ณผ ์คํ๋ก๋ฉ ์ ์ก๋๋ ํจ๊ป ์ค์ด๋ญ๋๋ค.
-
๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌ ๊ธฐ๋ฒ๊ณผ ๊ฒฐํฉํ๊ธฐ ์ฝ์ต๋๋ค. PagedAttention, KV offloading, ๋ฉ๋ชจ๋ฆฌ ํฐ์ด๋ง๊ณผ ์ง๊ต์ ์ด๋ผ ๊ธฐ์กด ์๋น ์คํ ์์ ๋ง๋ถ์ฌ ํจ๊ณผ๋ฅผ ๋์ ํ ์ ์์ต๋๋ค.
๋จ์
-
๋ฎ์ ๋นํธํญ์์๋ ์ ํ๋ ์ ํ ์ํ์ด ํฝ๋๋ค. ํนํ Key/Value ๋น๋์นญ, RoPE, ์ต๊ทผ ํ ํฐ ๋ณด์กด์ ๋ฌด์ํ๋ฉด 2๋นํธ ๊ทผ์ฒ์์ ํ์ง์ด ๊ธ๊ฒฉํ ๋ฌด๋์ง ์ ์์ต๋๋ค.
-
๊ตฌํ ๋ณต์ก๋๊ฐ ๋์์ง๋๋ค. per-channel/per-token scale ๊ด๋ฆฌ, ์ด์์น ๋ถ๋ฆฌ, fused dequant kernel, backend๋ณ dtype ์ ์ฝ์ ํจ๊ป ๊ณ ๋ คํด์ผ ํฉ๋๋ค.
-
ํ๋ ์์ํฌ์ ์ํฌ๋ก๋์ ๋ฐ๋ผ ์ฑ๋ฅ ์ด๋์ด ๋ฌ๋ผ์ง๋๋ค. ์ฉ๋ ์ ๊ฐ์ ๋น๊ต์ ์์ ์ ์ด์ง๋ง, ์ง์ฐ ์๊ฐ๊ณผ ์ฒ๋ฆฌ๋ ๊ฐ์ ์ kernel fusion, attention backend, batch ํฌ๊ธฐ์ ํฌ๊ฒ ์ข์ฐ๋ฉ๋๋ค.
7. ํธ๋ ์ด๋์คํ์ ๋ฉ๋ชจ๋ฆฌ ํฐ์ด๋ง ์ฐ๊ฒฐ
๊ทธ๋ฆผ 5. ์ ํ๋-๋นํธ ํธ๋ ์ด๋์คํ, ๊ทธ๋ฆฌ๊ณ precision tiering(ํฐ์ด๋ณ ์ ๋ฐ๋ ์ฐจ๋ฑ)
ํต์ฌ ํธ๋ ์ด๋์คํ
-
์ ํ๋ โ ์์ถ๋ฅ โ ๋นํธ๊ฐ ๋ฎ์์๋ก ๋ฉ๋ชจ๋ฆฌ๋ ์ค์ง๋ง ์ ํ๋๊ฐ ๋จ์ด์ง๋๋ค. 8๋นํธ๋ ๊ฑฐ์ ๊ณต์ง, 4๋นํธ๋ ์ค์ฉ์ , 2๋นํธ ์ดํ๋ ์ ๊ตํ ๊ธฐ๋ฒ์ด ํ์์ ๋๋ค.
-
์ฉ๋ ์ด๋ โ dequant ๋น์ฉ โ ๋์ญํญ ์ ๊ฐ๊ณผ ์ญ์์ํ ์ฐ์ฐ์ด ์์๋ ์ ์์ด, ์ปค๋ ์ตํฉ ์ฌ๋ถ๊ฐ ์ง์ฐ ์ด๋์ ๊ฒฐ์ ํฉ๋๋ค.
-
์ ๋ฐ๋ โ scale ์ ์ฅ โ ์ธ๋ฐํ ์ ๋(per-channel/per-token)์ผ์๋ก ์ ํํ์ง๋ง scale ์ค๋ฒํค๋๊ฐ ๋๊ณ , GQA/MQAยทFlashAttention๊ณผ์ ํธํ์ฑ๋ ๊ณ ๋ คํด์ผ ํฉ๋๋ค.
๋ฉ๋ชจ๋ฆฌ ํฐ์ด๋งยท์คํ๋ก๋ฉ๊ณผ์ ๊ฒฐํฉ (์ง๊ต์ ยท์๋ณด์ )
์์ํ๋ KV๋ฅผ '์ค์ด๊ณ ', ํฐ์ด๋ง์ KV๋ฅผ '์ฎ๊น๋๋ค'. ๋ ๊ธฐ๋ฒ์ ์ง๊ต์ ์ด๋ฉฐ ํจ๊ป ์ฐ๋ฉด ํจ๊ณผ๊ฐ ๋ฐฐ๊ฐ๋ฉ๋๋ค. ์์ํ๋ก KV๋ฅผ ์ค์ด๋ฉด CXL/PCIe๋ก ์ฎ๊ธธ ๋ฐ์ดํฐ์ ๋์ญํญ ๋ถ๋ด์ด ํจ๊ป ๊ฐ์ํ๊ธฐ ๋๋ฌธ์ ๋๋ค.
-
precision tiering โ hot KV๋ GPU์์ ๊ณ ์ ๋ฐ(FP16/FP8), warm KV๋ CPU DRAM์์ INT8/INT4, cold KV๋ CXL/NVMe์์ ๋ ๊ณต๊ฒฉ์ ์ธ INT4/INT2๋ก ๋๋ ์์ ํฐ์ด๋ณ ์ ๋ฐ๋ ์ฐจ๋ฑ์ด ๊ฐ๋ฅํฉ๋๋ค.
-
์ ์ก๋ ์ ๊ฐ โ ์๋ ํฐ์ด๋ก ์ฎ๊ธฐ๋ KV๋ฅผ ๋ ๊ณต๊ฒฉ์ ์ผ๋ก ์์ํํ๋ฉด, ๋๋ฆฐ ๋งํฌ(PCIe/CXL)๋ฅผ ํต๊ณผํ๋ ๋ฐ์ดํฐ๋ ์์ฒด๊ฐ ์ค์ด ์คํ๋ก๋ฉ์ ๋์ญํญ ๋ณ๋ชฉ์ ์ํํฉ๋๋ค.
-
์ ๋ขฐ์ฑ ๊ด์ โ ์ ๋นํธ์ผ์๋ก ๋นํธ ์ค๋ฅ์ ๋ํ ๋ฏผ๊ฐ๋๊ฐ ๋ฌ๋ผ์ง๋ฏ๋ก, ํฐ์ด๋ณ ์ ๋ฐ๋์ ECC(์ค๋ฅ ์ ์ )๋ฅผ ํจ๊ป ์ค๊ณํ๋ ๊ณต๋ ์ค๊ณ ์ฌ์ง๊ฐ ์์ต๋๋ค.
8. ๊ด๋ จ ๊ธฐ์
| ๋ฌธ์/์ฐ๊ตฌ | ์ฐ๊ฒฐ์ |
|---|---|
| PagedAttention Analysis | KV ๋ธ๋ก ๋จ์ ๊ด๋ฆฌ์ block table ๊ธฐ๋ฐ ์๋น |
| KV Cache Offloading Analysis | KV๋ฅผ GPU ๋ฐ์ผ๋ก ์ฎ๊ธฐ๋ ์คํ๋ก๋ฉ๊ณผ์ ๊ฒฐํฉ |
| StreamingLLM Analysis | attention sink์ ์ต๊ทผ ํ ํฐ ๋ณด์กด ์์ด๋์ด |
| KIVI (arXiv 2402.02750) | 2-bit asymmetric KV quantization |
| KVQuant (arXiv 2401.18079) | pre-RoPE, NUQ, dense-and-sparse, 3-bit KV quantization |
| MiKV / KVmix | mixed precision๊ณผ ์ค์ ํ ํฐ ๋ณด์กด |
9. ํต์ฌ ์ ๋ฆฌ
KV ์บ์ ์์ํ๋ Key/Value์ ์ ๋ฐ๋๋ฅผ ๋ฎ์ถฐ ์ฉ๋๊ณผ ๋์ญํญ์ ๋์์ ์ค์ด๋ ๊ฐ์ฅ ์ง์ ์ ์ธ KV ์ ๊ฐ ์๋จ์ด๋ค. ํต์ฌ์ Key์ Value์ ๋ถํฌ๊ฐ ๋ค๋ฅด๋ค๋ ์ ์ผ๋ก, Key๋ ์ฑ๋ ์ด์์น ๋๋ฌธ์ per-channel, Value๋ per-token์ผ๋ก ์์ํํด์ผ ํ๋ฉฐ(KIVI), RoPE ์ ์์ํยท๋น๊ท ๋ฑ ์์ํยท์ด์์น ๋ถ๋ฆฌ(KVQuant)์ ์ต๊ทผ/์ค์ ํ ํฐ์ ๊ณ ์ ๋ฐ ๋ณด์กด์ด ์ ๋นํธ ์ ํ๋์ ์ด์ ๋ค. 8๋นํธ(FP8/INT8)๋ ๊ฑฐ์ ๋ฌด์์ค๋ก ํ๋ ์์ํฌ๊ฐ ๊ธฐ๋ณธ ์ง์ํ๊ณ , 3๋นํธ๋ KVQuant, 2๋นํธ๋ KIVI์ ๋น๋์นญ+์์ฐจ, 2๋นํธ ๋ฏธ๋ง์ ํผํฉ ์ ๋ฐ๋๊ฐ ํ์ํ๋ค. ์์ํ๋ ๋ฉ๋ชจ๋ฆฌ ํฐ์ด๋งยท์คํ๋ก๋ฉ๊ณผ ์ง๊ต์ ์ด์ด์, hot KV๋ ๊ณ ์ ๋ฐยทcold KV๋ ์ ๋นํธ๋ก ๋๋ precision tiering์ผ๋ก ๊ฒฐํฉํ๋ฉด CXL/PCIe ์ ์ก๋๊น์ง ํจ๊ป ์ค์ผ ์ ์๋ค.
์ฃผ์ โ ๋ณธ๋ฌธ ์์น๋ ๊ฐ ๋ ผ๋ฌธ(KIVI arXiv 2402.02750, KVQuant arXiv 2401.18079, MiKV arXiv 2402.18096, KVmix arXiv 2506.08018)๊ณผ ๊ด๋ จ ํ๋ ์์ํฌ์ ๋ณด๊ณ ๊ฐ์ผ๋ก, ๋ชจ๋ธยท๋ฐ์ดํฐ์ ยทํ๋์จ์ดยท๋ฐฐ์น ์กฐ๊ฑด์ ๋ฐ๋ผ ๋ฌ๋ผ์ง๋ค. KIVI์ 2.6๋ฐฐ๋ ๊ฐ์ค์น๋ฅผ ํฌํจํ peak memory ๊ธฐ์ค์ด๊ณ , KVQuant์ 1M/10M ๋ฌธ๋งฅ ์์น๋ ํน์ ๋ชจ๋ธ๊ณผ GPU ๊ตฌ์ฑ์ ์ ์ ๋ก ํ๋ค.
๋ฉ๋ชจ๋ฆฌ ํฌ๊ธฐ ๊ฐ๊ฐ
KV_bytes โ 2 ร L ร n_kv_heads ร head_dim ร seq_len ร batch ร bytes
bytes: FP16=2, FP8/INT8=1, INT4=0.5, INT2=0.25
๋นํธํญ๋ณ ์์ฝ
| ๋นํธํญ | ์ ํ๋ | ํ์ ๊ธฐ๋ฒ |
|---|---|---|
| 8๋นํธ(FP8/INT8) | ๊ฑฐ์ ๋ฌด์์ค | per-tensor scale๋ก ์ถฉ๋ถ, ํ๋ ์์ํฌ ๊ธฐ๋ณธ |
| 4๋นํธ | ๋์ฒด๋ก ์ํธ | ๊ทธ๋ฃน ์์ํ(group-wise) |
| 3๋นํธ | ์ ๊ตํ๋ฉด ๊ฑฐ์ ๋ฌด์์ค | KVQuant(pre-RoPEยทNUQยทdense-and-sparse) |
| 2๋นํธ | ์ฃผ์ ํ์ | ๋น๋์นญ(์ฑ๋/ํ ํฐ)+full-precision ์์ฐจ(KIVI) |
| 2๋นํธ ๋ฏธ๋ง | ์ด๋ ค์ | ํผํฉ ์ ๋ฐ๋(์ค์ ํ ํฐ ๊ณ ์ ๋ฐ) |