Ryotta's Basic


FlashAttention Analysis

FlashAttention In-Depth Analysis

IO-aware Exact Attention · Tiling · Online Softmax · Recomputation · FA-2/FA-3 · Flash-Decoding

Overview

FlashAttention is an IO-aware algorithm that makes attention faster and more memory-efficient without approximating it (it is exact). Its key insight is that attention is bottlenecked by memory access (HBM reads and writes), not by compute (FLOPs).

Standard attention builds a huge N×N matrix in HBM and reads and writes it several times. FlashAttention splits Q, K, and V into blocks that fit in SRAM (tiling) and computes the softmax incrementally (online softmax), so the N×N matrix is never materialized. For the backward pass it stores only the softmax statistics and recomputes the rest (recomputation). The result is exact attention with O(N) memory and a reported 2-4× speedup. This document analyzes it in the following order: the problem (memory-bound attention) → the GPU memory hierarchy → solution ① (tiling + online softmax) → solution ② (recomputation) and results → evolution (FA-2/FA-3, Flash-Decoding) and where FlashAttention fits.

Key Concepts

  • IO-aware: a design that raises overall speed by reducing HBM reads and writes.
  • Tiling: Q, K, and V are split into blocks that fit in SRAM and processed block by block.
  • Online softmax: block results are merged incrementally using a running max and a running sum.
  • Recomputation: for the backward pass, only statistics are stored and the needed values are computed again.
  • FA-2/FA-3: extend FlashAttention with better parallelism and with Hopper asynchrony and FP8.

Comparison / Analysis

Item | Standard attention | FlashAttention
Intermediate matrices | S and P are materialized in HBM | N×N is never built; results are accumulated block by block
Memory complexity | O(N²) | O(N)
Main bottleneck | HBM reads/writes | Block computation and recomputation
Long sequences | Memory usage grows rapidly | The advantage grows with sequence length

How It Works

FlashAttention works by first identifying where standard attention makes repeated round trips to HBM, and then keeping data in SRAM for as long as possible while preserving the same mathematical result. Read in the order below, it becomes clear why tiling, online softmax, and recomputation are all needed together.

1. Problem: standard attention is memory-bound (IO bottleneck)

Standard attention computes S = QK^T (N×N) → P = softmax(S) (N×N) → O = PV, where N is the sequence length. The huge N×N matrices S and P are materialized in HBM and then read and written over several passes, so both time and memory scale with N².
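To make the N×N materialization concrete, here is a minimal NumPy sketch of the three steps above. The function name `standard_attention` and the sizes are illustrative assumptions, and the 1/sqrt(d) scaling is left out to match the formula as written in the text.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Reference attention that materializes the full N x N matrices S and P."""
    S = Q @ K.T                                   # (N, N) scores
    P = np.exp(S - S.max(axis=1, keepdims=True))  # numerically stable softmax
    P = P / P.sum(axis=1, keepdims=True)          # (N, N) probabilities
    return P @ V, S.nbytes + P.nbytes             # output, bytes held by S and P

rng = np.random.default_rng(0)
N, d = 1024, 64                                   # illustrative sizes
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O, nxn_bytes = standard_attention(Q, K, V)
qkv_bytes = Q.nbytes + K.nbytes + V.nbytes
print(O.shape, f"S and P take {nxn_bytes / qkv_bytes:.1f}x the memory of Q, K, V")
```

The ratio grows linearly with N (it equals 2N/(3d) here), which is why the intermediate matrices, not the inputs, dominate memory traffic for long sequences.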


Figure 1. N×N materialization in standard attention and its memory-bound behavior

Key point: attention is bottlenecked by HBM access, not compute

  • The real bottleneck in attention is HBM access (data movement), not FLOPs. An A100 delivers 312 TFLOPS of FP16 compute, while its HBM bandwidth is only about 2 TB/s, so the compute units sit waiting for data and the workload is memory-bound.

  • This is why approximate attention methods (Linformer, Performer, sparse attention) reduce FLOPs yet often fail to deliver a matching wall-clock speedup: they do not account for IO (memory reads and writes). This is FlashAttention's central observation.

  • The key to the fix is being IO-aware: avoid building the N×N matrix in HBM and minimize memory traffic.
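A rough back-of-the-envelope check of the memory-bound claim, using the A100 figures quoted above. The per-element FLOP count for a softmax-like pass is an assumption chosen only for illustration; real kernels differ.

```python
# Figures quoted in the text for an A100 (approximate).
peak_flops = 312e12        # FP16 throughput, FLOP/s
hbm_bandwidth = 2e12       # HBM bandwidth, bytes/s

# FLOPs the GPU can perform in the time it takes to move one byte from HBM.
machine_balance = peak_flops / hbm_bandwidth

def elementwise_intensity(flops_per_element, bytes_read, bytes_written):
    """Arithmetic intensity (FLOP per byte) of an elementwise pass over S or P."""
    return flops_per_element / (bytes_read + bytes_written)

# Assumption for illustration: a softmax-like pass does about 5 FLOPs per
# element and reads and writes one FP16 value (2 bytes each way).
softmax_pass = elementwise_intensity(flops_per_element=5, bytes_read=2, bytes_written=2)

print(f"machine balance: {machine_balance:.0f} FLOP/byte")
print(f"softmax pass:    {softmax_pass:.2f} FLOP/byte")
print("memory-bound" if softmax_pass < machine_balance else "compute-bound")
```

Any pass whose intensity falls far below the machine balance spends most of its time waiting on HBM, which is the situation the elementwise steps of standard attention are in.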

2. GPU memory hierarchy: the stage for IO-awareness

Understanding FlashAttention requires looking at the GPU's steep memory hierarchy. The gap between the small, very fast on-chip SRAM and the large, comparatively slow HBM is the starting point for everything that follows.


Figure 2. The SRAM↔HBM memory hierarchy and IO complexity

Memory hierarchy and IO complexity

  • SRAM is only tens of MB in total, but it sits right next to the compute units and has roughly 10× the bandwidth of HBM. HBM is large (tens of GB) but comparatively slow.

  • Standard attention is inefficient because it builds the N×N matrix in HBM and then reads and writes it again at every step: scaling, masking, softmax, and dropout. The goal of an IO-aware design is to load data into SRAM, finish the computation there, and minimize round trips to HBM.

  • FlashAttention needs O(N²d²/M) HBM accesses (M = SRAM size, d = head dimension), compared with Θ(Nd + N²) for standard attention. Because d² is typically much smaller than M, this is far fewer accesses, and a larger SRAM helps further. The paper also proves a lower bound: no exact attention algorithm can asymptotically improve on this number of HBM accesses across all SRAM sizes. In other words, the point is not to reduce computation but to reduce data movement.
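The two access counts can be compared numerically. This is only a sketch of the asymptotic expressions with constant factors dropped, and the SRAM capacity `M` below is an assumed value (about 100 KB of FP16 elements), not a measured one.

```python
def hbm_accesses_standard(N, d):
    """Theta(N*d + N^2) element accesses, constants dropped."""
    return N * d + N * N

def hbm_accesses_flash(N, d, M):
    """O(N^2 * d^2 / M) element accesses, constants dropped."""
    return N * N * d * d / M

d = 64
M = 50_000   # assumed SRAM capacity in elements (about 100 KB of FP16 values)
for N in (1024, 4096, 16384):
    std = hbm_accesses_standard(N, d)
    fa = hbm_accesses_flash(N, d, M)
    print(f"N={N:6d}  standard={std:.3e}  flash={fa:.3e}  ratio={std / fa:.1f}")
```

For large N the ratio approaches M/d², which makes the roles of the two parameters visible: a bigger SRAM helps, a bigger head dimension hurts.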

3. Solution ①: Tiling + Online Softmax

FlashAttention splits Q, K, and V into blocks that fit in SRAM (tiling), computes block by block, and evaluates the softmax incrementally (online), so the N×N matrix is never built at all.


Figure 3. Tiling and online softmax (running max/sum rescaling)

Online softmax accumulation (concept)

Essentials of tiling + online softmax

  • Tiling: Q, K, and V are split into blocks, pairs of blocks are loaded into SRAM and processed, and only the output is accumulated. The huge N×N matrices S and P are never created in HBM (materialization is avoided).

  • The difficulty and its resolution: softmax needs the max and the sum over an entire row, which makes it hard to split into blocks (it is not associative as written). Online softmax updates a running max (m) and a running sum (l) for each block, and whenever a new block raises the max, it rescales the previously accumulated output O by an exponential factor to correct it. Making softmax composable across blocks is the core math of FlashAttention.

  • Everything is fused into a single GPU kernel: scaling, masking, softmax, dropout, and the PV product are handled in one pass. The result is mathematically identical to standard attention (exact up to floating-point rounding), not an approximation. Only the IO pattern changes.
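The rescaling step described in the bullets above rests on one identity, written out here as a sketch. The symbols m, l, O, S_ij, and V_j follow the text; m_1 and m_2 (the per-block row maxima) are introduced only for this derivation, and `\operatorname` assumes the amsmath package.

```latex
% One row of scores x split into two blocks with maxima m_1 and m_2:
\[
m = \max(m_1, m_2), \qquad
\sum_k e^{x_k - m}
  = e^{m_1 - m} \sum_{k \in \text{block 1}} e^{x_k - m_1}
  + e^{m_2 - m} \sum_{k \in \text{block 2}} e^{x_k - m_2}
\]
% Applied incrementally when block j arrives:
\[
m^{\text{new}} = \max\bigl(m,\ \operatorname{rowmax}(S_{ij})\bigr), \qquad
l \leftarrow e^{m - m^{\text{new}}}\, l
  + \operatorname{rowsum}\bigl(e^{S_{ij} - m^{\text{new}}}\bigr), \qquad
O \leftarrow e^{m - m^{\text{new}}}\, O + e^{S_{ij} - m^{\text{new}}}\, V_j
\]
```

After the last block, dividing O by l gives the same row-wise softmax-weighted sum that the unblocked computation would produce, because every earlier partial sum has been rescaled to the final maximum.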

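import numpy as np

def online_softmax_block(Q_i, K_blocks, V_blocks):
    # Fix one block of Q and iterate over the K/V blocks, accumulating the output.
    # (Conceptual sketch: the 1/sqrt(d) scaling, masking, and dropout are omitted.)
    rows = Q_i.shape[0]
    m = np.full((rows, 1), -np.inf)               # running row max
    l = np.zeros((rows, 1))                       # running row sum
    O = np.zeros((rows, V_blocks[0].shape[1]))    # unnormalized output
    for K_j, V_j in zip(K_blocks, V_blocks):
        # In the real kernel, K_j and V_j are loaded into SRAM at this point.
        S_ij = Q_i @ K_j.T
        m_new = np.maximum(m, S_ij.max(axis=1, keepdims=True))
        P_ij = np.exp(S_ij - m_new)
        scale = np.exp(m - m_new)                 # rescales earlier partial sums
        l = scale * l + P_ij.sum(axis=1, keepdims=True)
        O = scale * O + P_ij @ V_j
        m = m_new
    return O / l                                  # normalize once at the end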

4. Solution ②: Recomputation, and the results


Figure 4. Recomputation (store only statistics, recompute in the backward pass) and results

Recomputation

  • The backward pass normally needs the N×N attention matrix from the forward pass, and storing it costs O(N²) memory. FlashAttention does not store the N×N matrix; it keeps only the softmax normalization statistics (the running max and sum).

  • During the backward pass, S and P are recomputed on the fly in SRAM from Q, K, and those statistics, which makes memory linear in N. The key insight is that recomputation increases FLOPs, but because the workload is memory-bound, the reduction in HBM accesses outweighs it and the net speed actually improves.
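The idea in the two bullets above can be sketched in NumPy. This is an illustration under assumptions, not the actual kernel: the function names and block size are hypothetical, and the running max and sum are folded into a single log-sum-exp value per row as a simplification.

```python
import numpy as np

def forward_stats(Q, K, V, block=128):
    """Tiled forward pass that returns O plus one log-sum-exp value per row."""
    N = Q.shape[0]
    m = np.full((N, 1), -np.inf)       # running row max
    l = np.zeros((N, 1))               # running row sum
    O = np.zeros((N, V.shape[1]))
    for s in range(0, K.shape[0], block):
        S = Q @ K[s:s + block].T
        m_new = np.maximum(m, S.max(axis=1, keepdims=True))
        P = np.exp(S - m_new)
        scale = np.exp(m - m_new)
        l = scale * l + P.sum(axis=1, keepdims=True)
        O = scale * O + P @ V[s:s + block]
        m = m_new
    return O / l, m + np.log(l)        # O(N) statistics instead of an N x N matrix

def recompute_P_block(Q, K_block, lse):
    """Rebuild one block of the softmax matrix from Q, K and the saved statistic."""
    return np.exp(Q @ K_block.T - lse)

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 32)) for _ in range(3))
O, lse = forward_stats(Q, K, V)
# A backward pass would consume each block as it is rebuilt; the blocks are
# concatenated here only so the result can be compared with a full softmax.
P = np.concatenate(
    [recompute_P_block(Q, K[s:s + 128], lse) for s in range(0, 256, 128)], axis=1
)
print(P.shape, lse.shape)
```

Only `lse` (one number per row) survives between the forward and backward passes, yet every block of P can be rebuilt exactly when it is needed.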

Results (FlashAttention-1)

The paper also proves an IO-complexity lower bound showing that exact attention cannot asymptotically reduce HBM accesses any further (so the algorithm is optimal in that sense). Extended to block-sparse attention, combining IO-awareness with a sparsity pattern yields an approximate variant that is faster than existing approximate attention methods.

Pros and Cons

Pros | Cons
Preserves exactness | Kernels and math are complex, so implementation is difficult
Memory usage is O(N) | Gains vary with hardware and sequence length
Especially beneficial for long contexts | Does not shrink the KV cache itself
Can be extended to block-sparse attention | Improvement can be small for short sequences

Related Techniques


Figure 5. Evolution by generation, Flash-Decoding, and where FlashAttention fits

Evolution by generation

FlashAttention-2 improves parallelism and work partitioning, reaching about 70% of the theoretical peak FLOPS on an A100. FlashAttention-3 exploits the asynchronous features of Hopper (H100), using warp specialization to overlap compute with data movement together with TMA, and adds FP8 low precision (with incoherent processing). It reaches 740 TFLOPS in FP16 (75% utilization) and about 1.2 PFLOPS in FP8, whereas FA-2 reaches only about 35% utilization on an H100.

Flash-Decoding: a variant for inference (decoding)

In the decoding phase of inference there is a single query token but thousands of keys and values, so the step is memory-bound. FlashAttention alone, which parallelizes over the query length (besides batch and heads), does not expose enough parallelism at small batch sizes. Flash-Decoding splits the work along the KV-length axis, computes partial attention results in parallel, and merges them with a log-sum-exp reduction, which keeps the GPU full even at small batch sizes. It is integrated into xFormers, vLLM, and others, and FlashDecoding++ reports up to a 2.02× improvement in decoding.
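The split-and-merge step can be sketched for a single query vector as follows. The function names and `num_splits` are hypothetical, the chunks are processed sequentially here (on a GPU each would run in parallel), and the 1/sqrt(d) scaling is omitted for brevity.

```python
import numpy as np

def partial_attention(q, K_chunk, V_chunk):
    """Attention over one KV chunk: normalized partial output and its log-sum-exp."""
    s = K_chunk @ q                    # scores for this chunk, shape (n,)
    m = s.max()
    p = np.exp(s - m)
    return p @ V_chunk / p.sum(), m + np.log(p.sum())

def flash_decoding(q, K, V, num_splits=4):
    """Split along the KV length, then merge the partial results via log-sum-exp."""
    parts = [
        partial_attention(q, K_c, V_c)
        for K_c, V_c in zip(np.array_split(K, num_splits), np.array_split(V, num_splits))
    ]
    outs = np.stack([o for o, _ in parts])       # (num_splits, d_v)
    lses = np.array([lse for _, lse in parts])   # (num_splits,)
    w = np.exp(lses - lses.max())                # each chunk's share of the softmax mass
    return (w / w.sum()) @ outs

rng = np.random.default_rng(0)
q = rng.standard_normal(64)
K = rng.standard_normal((4096, 64))
V = rng.standard_normal((4096, 64))
out = flash_decoding(q, K, V, num_splits=8)
print(out.shape)
```

Each chunk carries only its output and one scalar, so the final reduction is cheap compared with the attention work it merges.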

Sources

Source | Key point
Dao et al., 2022, FlashAttention | Basic principles of IO-aware exact attention and the IO-complexity lower bound
Dao, 2023, FlashAttention-2 | Follow-up generation with improved parallelism and work partitioning
Shah et al., 2024, FlashAttention-3 | Hopper asynchronous execution and FP8 optimization
FlashDecoding++ materials | Parallelization along the KV-length axis during decoding, and the resulting speedups

Performance Summary

Item | Details
Exact | Exact attention, not an approximation: no loss in quality
Memory O(N) | Linear memory instead of N×N → roughly 10-20× savings
Speed 2-4× | Sharp drop in HBM accesses → up to 7.6× on the GPT-2 attention computation, 15% faster on BERT-large, 2.4× on LRA
Long context | Relaxed memory limits allow training on longer sequences (which also improves quality)

Comparison by Generation

Generation | Published | Key points
FlashAttention-1 | NeurIPS 2022 (arXiv 2205.14135) | Tiling + online softmax + recomputation; IO-aware exact attention
FlashAttention-2 | 2023 (arXiv 2307.08691) | Improved parallelism and work partitioning; ~70% utilization on A100; about 2× over FA-1
FlashAttention-3 | 2024 (arXiv 2407.08608) | Hopper asynchrony (warp specialization, TMA) + FP8; 1.5-2× over FA-2

Key Takeaways

FlashAttention is IO-aware exact attention: it never builds the N×N matrix in HBM (tiling + online softmax) and recomputes from statistics in the backward pass (recompute). Starting from the insight that attention is memory-bound, it reduces HBM accesses rather than computation, and so achieves exact results with O(N) memory and a 2-4× speedup.

It has evolved with the hardware, from FA-1 (the principle) to FA-2 (parallelism and work partitioning) to FA-3 (Hopper asynchrony and FP8), while Flash-Decoding, for inference-time decoding, parallelizes by splitting along the KV-length axis. An important point is that FlashAttention does not shrink the KV cache; it only optimizes the IO pattern of the attention computation.

It is therefore orthogonal to KV-reduction techniques (quantization, MLA, offloading) and is used alongside them, and PagedAttention/vLLM also use Flash kernels internally. From a memory-systems perspective, it is a good case study of the general IO trade-off in the SRAM↔HBM hierarchy: where to keep data and when to recompute it.

Note: the figures in this article are values reported in the original papers and technical materials (FlashAttention arXiv 2205.14135, NeurIPS 2022; FA-2 arXiv 2307.08691; FA-3 arXiv 2407.08608; Flash-Decoding/FlashDecoding++ materials). Numbers such as "7.6× on GPT-2", "15% on BERT", "10-20× memory", "A100 312 TFLOPS, ~2 TB/s", and "740 TFLOPS, 1.2 PFLOPS" were reported under specific model, hardware, sequence-length, and precision conditions, so take care when generalizing them. SRAM and HBM capacity and bandwidth figures vary by GPU generation, and the speedup tends to grow with sequence length (that is, the more memory-bound the workload).