FlashAttention Analysis
FlashAttention ์ฌ์ธต ๋ถ์
IO-aware Exact Attention ยท Tiling ยท Online Softmax ยท Recomputation ยท FA-2/FA-3 ยท Flash-Decoding
๊ฐ์
FlashAttention์ ์ดํ ์ ์ ๊ทผ์ฌํ์ง ์์ผ๋ฉด์(exact) ๋ ๋น ๋ฅด๊ณ ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ ์ผ๋ก ๋ง๋๋ IO-aware ์๊ณ ๋ฆฌ์ฆ์ ๋๋ค. ํต์ฌ ํต์ฐฐ์ ์ดํ ์ ์ด ์ฐ์ฐ(FLOPs)์ด ์๋๋ผ ๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ(HBM ์ฝ๊ธฐ/์ฐ๊ธฐ)์ ๋ณ๋ชฉ์ด ์๋ค๋ ์ ์ ๋๋ค.
ํ์ค ์ดํ ์ ์ ๊ฑฐ๋ํ NรN ํ๋ ฌ์ HBM์ ๋ง๋ค๊ณ ์ฌ๋ฌ ๋ฒ ์ฝ๊ณ ์๋๋ค. FlashAttention์ QยทKยทV๋ฅผ SRAM์ ๋ง๋ ๋ธ๋ก์ผ๋ก ์ชผ๊ฐ(tiling) online softmax๋ก ์ ์ง์ ์ผ๋ก ๊ณ์ฐํด NรN์ ์์ ๋ง๋ค์ง ์๊ณ , ์ญ์ ํ์์๋ ํต๊ณ๋ง ์ ์ฅํด ์ฌ๊ณ์ฐ(recomputation)ํฉ๋๋ค. ๊ฒฐ๊ณผ๋ ์ ํํ๋ฉด์ ๋ฉ๋ชจ๋ฆฌ O(N)ยท์๋ 2~4๋ฐฐ์ ๋๋ค. ๋ณธ ๋ฌธ์๋ ๋ฌธ์ (๋ฉ๋ชจ๋ฆฌ ๋ฐ์ด๋) โ GPU ๋ฉ๋ชจ๋ฆฌ ๊ณ์ธต โ ํด๋ฒโ (tiling+online softmax) โ ํด๋ฒโก(recomputation)์ ๊ฒฐ๊ณผ โ ์งํ(FA-2/FA-3ยทFlash-Decoding)์ ์์น์ ์์๋ก ๋ถ์ํฉ๋๋ค.
ํต์ฌ ๊ฐ๋
- IO-aware: HBM ์ฝ๊ธฐ/์ฐ๊ธฐ๋ฅผ ์ค์ฌ ์ ์ฒด ์๋๋ฅผ ๋์ด์ฌ๋ฆฌ๋ ์ค๊ณ์ ๋๋ค.
- Tiling: QยทKยทV๋ฅผ SRAM์ ๋ง๋ ๋ธ๋ก์ผ๋ก ๋๋์ด ๋ธ๋ก ๋จ์๋ก ๊ณ์ฐํฉ๋๋ค.
- Online softmax: running max์ running sum์ผ๋ก ๋ธ๋ก ๊ฒฐ๊ณผ๋ฅผ ์ ์ง์ ์ผ๋ก ํฉ์นฉ๋๋ค.
- Recomputation: ์ญ์ ํ์์ ํ์ํ ๊ฐ์ ํต๊ณ๋ง ์ ์ฅํ๊ณ ๋ค์ ๊ณ์ฐํฉ๋๋ค.
- FA-2/FA-3: ๋ณ๋ ฌํ์ Hopper ๋น๋๊ธฐยทFP8์ผ๋ก FlashAttention์ ํ์ฅํฉ๋๋ค.
๋น๊ต/๋ถ์
| ํญ๋ชฉ | ํ์ค ์ดํ ์ | FlashAttention |
|---|---|---|
| ์ค๊ฐ ํ๋ ฌ | SยทP๋ฅผ HBM์ ์ค์ฒดํ | NรN์ ๋ง๋ค์ง ์๊ณ ๋ธ๋ก ๋จ์๋ก ๋์ |
| ๋ฉ๋ชจ๋ฆฌ ๋ณต์ก๋ | O(Nยฒ) | O(N) |
| ์ฃผ์ ๋ณ๋ชฉ | HBM ์ฝ๊ธฐ/์ฐ๊ธฐ | ๋ธ๋ก ๊ณ์ฐ๊ณผ ์ฌ๊ณ์ฐ |
| ๊ธด ์ํ์ค | ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ๊ธ๊ฒฉํ ์ฆ๊ฐ | ๊ธธ์ด๊ฐ ๊ธธ์๋ก ์ด์ ์ด ์ปค์ง |
๋์ ์๋ฆฌ
FlashAttention์ ๋์ ์๋ฆฌ๋ ํ์ค ์ดํ
์
์ด ์ด๋ค ์ง์ ์์ HBM ์๋ณต์ ๋ฐ๋ณตํ๋์ง ํ์
ํ ๋ค, ๊ฐ์ ์ํ์ ๊ฒฐ๊ณผ๋ฅผ ์ ์งํ ์ฑ ๋ฐ์ดํฐ๋ฅผ SRAM ์์์ ์ต๋ํ ์ค๋ ๋จธ๋ฌด๋ฅด๊ฒ ๋ง๋๋ ๋ฐ ์์ต๋๋ค. ์๋ ์์๋๋ก ๋ณด๋ฉด ์ tiling, online softmax, recomputation์ด ํจ๊ป ํ์ํด์ง๋์ง ์์ฐ์ค๋ฝ๊ฒ ์ฐ๊ฒฐ๋ฉ๋๋ค.
1. ๋ฌธ์ โ ํ์ค ์ดํ ์ ์ ๋ฉ๋ชจ๋ฆฌ ๋ฐ์ด๋(IO ๋ณ๋ชฉ)
ํ์ค ์ดํ ์ ์ S=QK^T(NรN) โ P=softmax(S)(NรN) โ O=PV๋ก ๊ณ์ฐํฉ๋๋ค. ์ฌ๊ธฐ์ N์ ์ํ์ค ๊ธธ์ด์ด๋ฉฐ, ๊ฑฐ๋ํ NรN ํ๋ ฌ SยทP๋ฅผ HBM์ ์ค์ฒดํ(materialize)ํ๊ณ ์ฌ๋ฌ ํจ์ค๋ก ์ฝ๊ณ ์๋๋ค. ์๊ฐยท๋ฉ๋ชจ๋ฆฌ๊ฐ Nยฒ์ ๋น๋กํฉ๋๋ค.
๊ทธ๋ฆผ 1. ํ์ค ์ดํ ์ ์ NรN ์ค์ฒดํ์ ๋ฉ๋ชจ๋ฆฌ ๋ฐ์ด๋ ํน์ฑ
ํต์ฌ โ ์ดํ ์ ์ ์ฐ์ฐ์ด ์๋๋ผ HBM ์ ๊ทผ์ด ๋ณ๋ชฉ
-
์ดํ ์ ์ ์ง์ง ๋ณ๋ชฉ์ FLOPs๊ฐ ์๋๋ผ HBM ์ ๊ทผ(๋ฐ์ดํฐ ์ด๋)์ ๋๋ค. A100์ ์ฐ์ฐ ์ฒ๋ฆฌ๋์ 312 TFLOPS(FP16)๋ก ๋งค์ฐ ๋น ๋ฅธ ๋ฐ๋ฉด HBM ๋์ญํญ์ ~2TB/s์ ๊ทธ์ณ, ์ฐ์ฐ ์ ๋์ด ๋ฐ์ดํฐ๋ฅผ ๊ธฐ๋ค๋ฆฌ๋ memory-bound ์ํ๊ฐ ๋ฉ๋๋ค.
-
๊ทธ๋์ ๊ทผ์ฌ ์ดํ ์ (LinformerยทPerformerยทsparse)์ด FLOPs๋ฅผ ์ค์ฌ๋ ์ค์ wall-clock ์๋๋ ์ ๋นจ๋ผ์ง์ง ์์ต๋๋ค โ ๊ทธ๋ค์ IO(๋ฉ๋ชจ๋ฆฌ ์ฝ๊ณ ์ฐ๊ธฐ)๋ฅผ ๊ณ ๋ คํ์ง ์์๊ธฐ ๋๋ฌธ์ ๋๋ค. ์ด๊ฒ์ด FlashAttention์ ํต์ฌ ์ง์ ์ ๋๋ค.
-
ํด๊ฒฐ์ ์ด์ ๋ 'IO-aware' โ NรN์ HBM์ ๋ง๋ค์ง ์๊ณ ๋ฉ๋ชจ๋ฆฌ ์ด๋์ ์ต์ํํ๋ ๊ฒ์ ๋๋ค.
2. GPU ๋ฉ๋ชจ๋ฆฌ ๊ณ์ธต โ IO-awareness์ ๋ฌด๋
FlashAttention์ ์ดํดํ๋ ค๋ฉด GPU์ ๊ฐํ๋ฅธ ๋ฉ๋ชจ๋ฆฌ ๊ณ์ธต์ ๋ด์ผ ํฉ๋๋ค. ์๊ณ ๋งค์ฐ ๋น ๋ฅธ on-chip SRAM๊ณผ ํฌ๊ณ ์๋์ ์ผ๋ก ๋๋ฆฐ HBM์ ์ฐจ์ด๊ฐ ๋ชจ๋ ๊ฒ์ ์ถ๋ฐ์ ์ ๋๋ค.
๊ทธ๋ฆผ 2. SRAMโHBM ๋ฉ๋ชจ๋ฆฌ ๊ณ์ธต๊ณผ IO ๋ณต์ก๋
๋ฉ๋ชจ๋ฆฌ ๊ณ์ธต๊ณผ IO ๋ณต์ก๋
-
SRAM์ ์์ญ MB์ ๋ถ๊ณผํ์ง๋ง ๋์ญํญ์ด HBM์ ์ฝ 10๋ฐฐ์ ๋๋ค(์ฐ์ฐ ์ ๋ ๋ฐ๋ก ์). HBM์ ์์ญ GB๋ก ํฌ์ง๋ง ์๋์ ์ผ๋ก ๋๋ฆฝ๋๋ค.
-
ํ์ค ์ดํ ์ ์ ๋นํจ์จ์ NรN ํ๋ ฌ์ HBM์ ๋ง๋ค๊ณ scalingยทmaskingยทsoftmaxยทdropout ๋จ๊ณ๋ง๋ค ๋ค์ ์ฝ๊ณ ์ฐ๋ ๋ฐ ์์ต๋๋ค. IO-aware ์ค๊ณ์ ๋ชฉํ๋ ๋ฐ์ดํฐ๋ฅผ SRAM์ ์ฌ๋ ค ๊ณ์ฐ์ ๋๋ด๊ณ HBM ์๋ณต์ ์ต์ํํ๋ ๊ฒ์ ๋๋ค.
-
FlashAttention์ IO ๋ณต์ก๋๋ O(Nยฒdยฒ/M) HBM ์ ๊ทผ(M=SRAM ํฌ๊ธฐ, d=ํค๋ ์ฐจ์)์ผ๋ก, ํ์ค ์ดํ ์ ์ ฮฉ(Nd+Nยฒ)๋ณด๋ค ํจ์ฌ ์ ์ต๋๋ค. SRAM์ด ํด์๋ก ์ ๋ฆฌํ๋ฉฐ, ์ ํ ์ดํ ์ ์์ ์ด๋ณด๋ค ๋ ์ค์ผ ์ ์๋ค๋ ์ด๋ก ์ ํํ๋ ์ฆ๋ช ํ์ต๋๋ค. ์ฆ '์ฐ์ฐ์ ์ค์ด๋ ๊ฒ'์ด ์๋๋ผ '๋ฐ์ดํฐ ์ด๋์ ์ค์ด๋ ๊ฒ'์ด ํต์ฌ์ ๋๋ค.
3. ํด๋ฒ โ โ Tiling + Online Softmax
FlashAttention์ QยทKยทV๋ฅผ SRAM์ ๋ง๋ ๋ธ๋ก์ผ๋ก ์ชผ๊ฐ(tiling) ๋ธ๋ก ๋จ์๋ก ๊ณ์ฐํ๊ณ , softmax๋ฅผ ์ ์ง์ ์ผ๋ก(online) ๊ณ์ฐํด NรN ํ๋ ฌ์ ์์ ๋ง๋ค์ง ์์ต๋๋ค.
๊ทธ๋ฆผ 3. Tiling๊ณผ online softmax(running max/sum rescaling)
Online Softmax ๋์ (๊ฐ๋ )
Tiling + Online Softmax์ ํต์ฌ
-
Tiling โ QยทKยทV๋ฅผ ๋ธ๋ก์ผ๋ก ์ชผ๊ฐ ๋ธ๋ก ์์ SRAM์ ์ฌ๋ ค ๊ณ์ฐํ๊ณ ์ถ๋ ฅ๋ง ๋์ ํฉ๋๋ค. ๊ฑฐ๋ํ NรN ํ๋ ฌ SยทP๋ฅผ HBM์ ์ ํ ๋ง๋ค์ง ์์ต๋๋ค(materialization ํํผ).
-
๋์ ๊ณผ ํด๊ฒฐ โ softmax๋ ํ ์ ์ฒด์ maxยทํฉ์ด ํ์ํด ๋ธ๋ก ๋จ์๋ก ์ชผ๊ฐ๊ธฐ ์ด๋ ต์ต๋๋ค(๋น๊ฒฐํฉ์ ). Online softmax๋ ๋ธ๋ก๋ง๋ค running max(m)์ running sum(l)์ ๊ฐฑ์ ํ๊ณ , ์ ๋ธ๋ก์ max๊ฐ ๋ ํฌ๋ฉด ์ด์ ๊น์ง ๋์ ํ ์ถ๋ ฅ O๋ฅผ ์ง์ ์ธ์๋ก rescaleํด ๋ณด์ ํฉ๋๋ค. softmax๋ฅผ '๊ฒฐํฉ ๊ฐ๋ฅ'ํ๊ฒ ๋ง๋ ๊ฒ์ด FlashAttention์ ํต์ฌ ์ํ์ ๋๋ค.
-
์ ๋ถ ํ GPU ์ปค๋๋ก ์ตํฉ(fuse)ํฉ๋๋ค โ scalingยทmaskingยทsoftmaxยทdropoutยทPV๋ฅผ ํ ํจ์ค์ ์ฒ๋ฆฌํฉ๋๋ค. ๊ฒฐ๊ณผ๋ ํ์ค ์ดํ ์ ๊ณผ ์ ํํ ๋์ผ(exact)ํ๋ฉฐ ๊ทผ์ฌ๊ฐ ์๋๋๋ค. ๋จ์ง IO ํจํด๋ง ๋ฐ๊ฟ๋๋ค.
# Q๋ฅผ ๋ธ๋ก ๋จ์๋ก ๊ณ ์ ํ๊ณ , KยทV ๋ธ๋ก์ ์ํํ๋ฉฐ ์ถ๋ ฅ์ ๋์
m = -inf
l = 0
O = 0
for j in range(num_kv_blocks):
# K_j, V_j ๋ฅผ SRAM์ ๋ก๋
S_ij = Q_i @ K_j.T
m_new = max(m, rowmax(S_ij))
P_ij = exp(S_ij - m_new)
l = exp(m - m_new) * l + rowsum(P_ij)
O = exp(m - m_new) * O + P_ij @ V_j
m = m_new
O = O / l
4. ํด๋ฒ โก โ Recomputation, ๊ทธ๋ฆฌ๊ณ ๊ฒฐ๊ณผ
๊ทธ๋ฆผ 4. Recomputation(ํต๊ณ๋ง ์ ์ฅํด ์ญ์ ํ์์ ์ฌ๊ณ์ฐ)๊ณผ ๊ฒฐ๊ณผ
Recomputation
-
์ญ์ ํ์๋ ๋ณดํต ์์ ํ์ NรN ์ดํ ์ ํ๋ ฌ์ด ํ์ํฉ๋๋ค โ ์ ์ฅํ๋ฉด O(Nยฒ) ๋ฉ๋ชจ๋ฆฌ์ ๋๋ค. FlashAttention์ NรN์ ์ ์ฅํ์ง ์๊ณ softmax ์ ๊ทํ ํต๊ณ(running maxยทsum)๋ง ์ ์ฅํฉ๋๋ค.
-
์ญ์ ํ ๋ ๊ทธ ํต๊ณ๋ก SยทP๋ฅผ SRAM์์ ์ฆ์ ์ฌ๊ณ์ฐ(recompute)ํ์ฌ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ ํ์ผ๋ก ๋ง๋ญ๋๋ค. ํต์ฌ ํต์ฐฐ์ recompute๋ก FLOPs๋ ๋์ง๋ง, memory-bound๋ผ HBM ์ ๊ทผ ๊ฐ์๊ฐ ๋ ์ปค์ '์์๋'๊ฐ ์คํ๋ ค ๋นจ๋ผ์ง๋ค๋ ๊ฒ์ ๋๋ค.
๊ฒฐ๊ณผ (FlashAttention-1)
๋ํ ์ ํ ์ดํ ์ ์์ HBM ์ ๊ทผ์ ์ ๊ทผ์ ์ผ๋ก ๋ ์ค์ผ ์ ์๋ค๋ IO ๋ณต์ก๋ ํํ์ ์ฆ๋ช ํ๊ณ (์ต์ ), block-sparse๋ก ํ์ฅํ๋ฉด IO ์ธ์ + ํฌ์ ํจํด์ผ๋ก ๊ธฐ์กด ๊ทผ์ฌ ์ดํ ์ ๋ณด๋ค๋ ๋น ๋ฅธ ๊ทผ์ฌ ๋ฒ์ ์ด ๋ฉ๋๋ค.
์ฅ๋จ์
| ์ฅ์ | ๋จ์ |
|---|---|
| exact ์ ํ์ฑ์ ์ ์งํ๋ค | ์ปค๋๊ณผ ์์์ด ๋ณต์กํด ๊ตฌํ ๋๋๊ฐ ๋๋ค |
| ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด O(N)์ด๋ค | ํ๋์จ์ด์ ์ํ์ค ๊ธธ์ด์ ๋ฐ๋ผ ์ด๋์ด ๋ฌ๋ผ์ง๋ค |
| ๊ธด ๋ฌธ๋งฅ์์ ํนํ ์ ๋ฆฌํ๋ค | KV ์บ์ ์์ฒด๋ฅผ ์ค์ด์ง๋ ์๋๋ค |
| block-sparse๋ก๋ ํ์ฅํ ์ ์๋ค | ์งง์ ์ํ์ค์์๋ ๊ฐ์ ํญ์ด ์์ ์ ์๋ค |
๊ด๋ จ ๊ธฐ์
๊ทธ๋ฆผ 5. ์ธ๋๋ณ ์งํ, Flash-Decoding, ๊ทธ๋ฆฌ๊ณ FlashAttention์ ์์น
์ธ๋๋ณ ์งํ
FlashAttention-2๋ ๋ณ๋ ฌํ์ work ๋ถํ ์ ๊ฐ์ ํด A100์์ ์ฝ 70%์ ์ด๋ก ์ ์ต๋ FLOPS๋ฅผ ๋ฌ์ฑํฉ๋๋ค. FlashAttention-3๋ Hopper(H100)์ ๋น๋๊ธฐ ๊ธฐ๋ฅ(warp-specialization์ผ๋ก ์ฐ์ฐยท๋ฐ์ดํฐ ์ด๋ ์ค์ฒฉ, TMA)๊ณผ FP8 ์ ์ ๋ฐ(incoherent processing)์ ํ์ฉํด, FP16์์ 740 TFLOPS(75% ํ์ฉ)ยทFP8์์ ์ฝ 1.2 PFLOPS์ ์ด๋ฆ ๋๋ค(FA-2๋ H100์์ 35%๋ง ํ์ฉ).
Flash-Decoding โ ์ถ๋ก (๋์ฝ๋ฉ)์ฉ ๋ณํ
์ถ๋ก ์ ๋์ฝ๋ฉ ๋จ๊ณ๋ ์ง์๊ฐ 1๊ฐ์ธ๋ฐ KV๋ ์์ฒ ๊ฐ๋ผ memory-bound์ด๊ณ , ์ง์ ๊ธธ์ด๋ก ๋ณ๋ ฌํํ๋ FlashAttention๋ง์ผ๋ก๋ ์์ ๋ฐฐ์น์์ ๋ณ๋ ฌ์ฑ์ด ๋ถ์กฑํฉ๋๋ค. Flash-Decoding์ KV ๊ธธ์ด ์ถ์ผ๋ก ๋ถํ ํด ๋ถ๋ถ ์ดํ ์ ์ ๋ณ๋ ฌ ๊ณ์ฐํ๊ณ , log-sum-exp๋ก ํฉ์ณ(reduction) ์์ ๋ฐฐ์น์์๋ GPU๋ฅผ ๊ฐ๋ ์ฑ์๋๋ค. xFormersยทvLLM ๋ฑ์ ํตํฉ๋์ด ์์ผ๋ฉฐ, FlashDecoding++๋ ๋์ฝ๋ฉ์์ ์ต๋ 2.02๋ฐฐ ํฅ์์ ๋ณด๊ณ ํฉ๋๋ค.
๋ด๋ถ ๋ฌธ์
- PagedAttention Analysis
- KV Cache Quantization Analysis
- KV Cache Offloading Analysis
- Continuous Batching Analysis
- LLM Inference Scheduler Analysis
- Memory Centric LLM Serving Survey
์๋ฌธ
| ์๋ฃ | ํต์ฌ |
|---|---|
| Dao et al., 2022, FlashAttention | IO-aware exact attention์ ๊ธฐ๋ณธ ์๋ฆฌ์ IO ๋ณต์ก๋ ํํ |
| Dao et al., 2023, FlashAttention-2 | ๋ณ๋ ฌํ์ work partition์ ๊ฐ์ ํ ํ์ ์ธ๋ |
| Shah et al., 2024, FlashAttention-3 | Hopper ๋น๋๊ธฐ ์คํ๊ณผ FP8 ์ต์ ํ |
| FlashDecoding++ ์๋ฃ | ๋์ฝ๋ฉ ๋จ๊ณ์์ KV ๊ธธ์ด ์ถ ๋ณ๋ ฌํ์ ์ฑ๋ฅ ํฅ์ |
์ฑ๋ฅ ์์ฝ
| ํญ๋ชฉ | ๋ด์ฉ |
|---|---|
| exact (์ ํ) | ๊ทผ์ฌ๊ฐ ์๋ ์ ํํ ์ดํ ์ โ ํ์ง ์์ค ์์ |
| ๋ฉ๋ชจ๋ฆฌ O(N) | NรN ๋์ ์ ํ ๋ฉ๋ชจ๋ฆฌ โ ์ฝ 10~20๋ฐฐ ์ ๊ฐ |
| ์๋ 2~4๋ฐฐ | HBM ์ ๊ทผ ๊ธ๊ฐ โ GPT-2 ์ต๋ 7.6๋ฐฐ, BERT-large 15%โ, LRA 2.4๋ฐฐ |
| ๊ธด ๋ฌธ๋งฅ | ๋ฉ๋ชจ๋ฆฌ ํ๊ณ ์ํ๋ก ๋ ๊ธด ์ํ์ค ํ์ต ๊ฐ๋ฅ(ํ์ง๋ ํฅ์) |
์ธ๋๋ณ ๋น๊ต
| ์ธ๋ | ๋ฐํ | ํต์ฌ |
|---|---|---|
| FlashAttention-1 | NeurIPS 2022 (2205.14135) | tiling + online softmax + recomputation. IO-aware ์ ํ ์ดํ ์ |
| FlashAttention-2 | 2023 (2307.08691) | ๋ณ๋ ฌํยทwork ๋ถํ ๊ฐ์ , A100 ~70% ํ์ฉ, FA-1 ๋๋น ์ฝ 2๋ฐฐ |
| FlashAttention-3 | 2024 (2407.08608) | Hopper ๋น๋๊ธฐ(warp-specializationยทTMA)+FP8, FA-2 ๋๋น 1.5~2๋ฐฐ |
ํต์ฌ ์ ๋ฆฌ
FlashAttention์ IO-aware exact attention์ผ๋ก, NรN์ HBM์ ๋ง๋ค์ง ์๊ณ (tiling + online softmax) ์ญ์ ํ๋ ํต๊ณ๋ก ์ฌ๊ณ์ฐ(recompute)ํ๋ค. ์ดํ ์ ์ด memory-bound๋ผ๋ ํต์ฐฐ์์ ์ถ๋ฐํด, ์ฐ์ฐ์ ์ค์ด๋ ๋์ HBM ์ ๊ทผ์ ์ค์ฌ ์ ํํ๋ฉด์ ๋ฉ๋ชจ๋ฆฌ O(N)ยท์๋ 2~4๋ฐฐ๋ฅผ ๋ฌ์ฑํ๋ค.
FA-1(์๋ฆฌ)โFA-2(๋ณ๋ ฌํยทwork ๋ถํ )โFA-3(Hopper ๋น๋๊ธฐยทFP8)๋ก ํ๋์จ์ด์ ๋ง์ถฐ ์งํํ๊ณ , ์ถ๋ก ๋์ฝ๋ฉ์ฉ Flash-Decoding์ KV ๊ธธ์ด ์ถ์ผ๋ก ๋ถํ ๋ณ๋ ฌํํ๋ค. ์ค์ํ ์ ์ FlashAttention์ด KV ์บ์๋ฅผ '์ค์ด์ง' ์๋๋ค๋ ๊ฒ โ ์ดํ ์ ๊ณ์ฐ์ IO ํจํด์ ์ต์ ํํ ๋ฟ์ด๋ค.
๋ฐ๋ผ์ KV ์ ๊ฐ(์์ํยทMLAยท์คํ๋ก๋ฉ)๊ณผ ์ง๊ต์ ์ผ๋ก ํจ๊ป ์ฐ์ด๋ฉฐ, PagedAttention/vLLM๋ ๋ด๋ถ์ ์ผ๋ก Flash ์ปค๋์ ์ฌ์ฉํ๋ค. ๋ฉ๋ชจ๋ฆฌ ์์คํ ๊ด์ ์์๋ SRAMโHBM ๊ณ์ธต์์ ๋ฐ์ดํฐ๋ฅผ ์ด๋์ ๋๊ณ ์ธ์ ๋ค์ ๊ณ์ฐํ ์ง์ ๋ํ ์ผ๋ฐ์ ์ธ IO ํธ๋ ์ด๋์คํ๋ฅผ ์ ๋ณด์ฌ์ฃผ๋ ์ฌ๋ก๋ค.
์ฃผ์ โ ๋ณธ๋ฌธ ์์น๋ ์๋ ผ๋ฌธยท๊ธฐ์ ์๋ฃ(FlashAttention arXiv 2205.14135ยทNeurIPS 2022, FA-2 arXiv 2307.08691, FA-3 arXiv 2407.08608, Flash-Decoding/FlashDecoding++ ์๋ฃ)์ ๋ณด๊ณ ๊ฐ์ด๋ค. 'GPT-2 7.6๋ฐฐ', 'BERT 15%', '๋ฉ๋ชจ๋ฆฌ 10~20๋ฐฐ', 'A100 312 TFLOPSยท~2TB/s', '740 TFLOPSยท1.2 PFLOPS' ๊ฐ์ ๊ฐ์ ํน์ ๋ชจ๋ธยทํ๋์จ์ดยท์ํ์ค ๊ธธ์ดยท์ ๋ฐ๋ ์กฐ๊ฑด์ ๋ณด๊ณ ๊ฐ์ผ๋ก ์ผ๋ฐํ์ ์ฃผ์๊ฐ ํ์ํ๋ค. SRAMยทHBM ์ฉ๋/๋์ญํญ ์์น๋ GPU ์ธ๋์ ๋ฐ๋ผ ๋ค๋ฅด๋ฉฐ, ์๋ ํฅ์์ ์ํ์ค๊ฐ ๊ธธ์๋ก(๋ memory-bound์ผ์๋ก) ์ปค์ง๋ ๊ฒฝํฅ์ด ์๋ค.