Streetfight Transformers Notes
My Takeaway
Example: Llama 3.1 8B
Layers (N): 32, Model Dim (D): 4096, FFN Dim: 14336, Attention Heads (NH): 32, K/V Heads: 8, Vocab Size (V): 128,000, Context Length (T): 128,000
Estimated Params ≈ V*D (embedding) + N*4*D*D (attention) + N*2*D*FFN (FFN) = 128,000 * 4096 + 32 * 4 * 4096 * 4096 + 32 * (4096 * 14336 * 2) ≈ 6.4B
Missed params that I can think of: the SwiGLU gate matrix and the untied LM head (not WPE: Llama 3 uses RoPE, so there are no learned positional embeddings), partly offset by GQA shrinking the K/V projections
Overhead ≈ 25%, I guess: 6.4B * 1.25 ≈ the actual 8B
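Sanity-checking the estimate with a quick Python sketch (the per-term comments are my reading of the formula above):

```python
# Back-of-the-envelope parameter count for Llama 3.1 8B, mirroring the estimate above.
V, D, N, FFN = 128_000, 4096, 32, 14336

embed = V * D              # token embedding (~0.52B)
attn  = N * 4 * D * D      # Q, K, V, O projections per layer, ignoring GQA (~2.15B)
ffn   = N * (D * FFN * 2)  # up + down projections, ignoring the SwiGLU gate (~3.76B)

total = embed + attn + ffn
print(f"estimated params ≈ {total / 1e9:.1f}B")  # ≈ 6.4B, vs. the real 8B
```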
Memory-wise:
To run inference, you need:
B = 1
Params + KV cache = Params + N * T * B * 2 * D_kv = 8B + 32 * 128000 * 1 * 2 * 1024 ≈ 8B weights + 8.4B KV cache at the full 128K context (D_kv = 1024 because of the 8 K/V heads; for short contexts the weights dominate, so the numbers below are weights-only)
Inference (weights only):
INT8 = 1 byte/param = 8 GB
BF16 = 2 bytes/param = 16 GB
TF32 = 19 bits ≈ 2.375 bytes/param = 19 GB (though TF32 is really a compute format stored in 4 bytes, so treat this as a lower bound)
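A quick sketch of those numbers plus the KV-cache term; the GQA K/V dim of 1024 (8 K/V heads * 128 head dim) is my assumption from the spec above, and the 2.375-byte "TF32" follows the same convention as the note:

```python
# Weights-only inference memory at different precisions, plus the KV cache at full context.
PARAMS = 8e9                 # actual Llama 3.1 8B parameter count
N, D, B, T = 32, 4096, 1, 128_000
D_KV = 8 * 128               # 8 K/V heads x 128 head dim (GQA); would be 4096 for full MHA

for name, bytes_per in [("INT8", 1), ("BF16", 2), ("TF32 (19-bit)", 2.375)]:
    weights_gb = PARAMS * bytes_per / 1e9
    kv_gb = N * T * B * 2 * D_KV * bytes_per / 1e9   # K and V per layer per token
    print(f"{name:14} weights ≈ {weights_gb:.0f} GB, KV cache @ 128K ≈ {kv_gb:.0f} GB")
```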
Train:
3 * N * 12 * D * D (three copies of the params: weights, grads, optimizer state) + N * T * B * 12 * D (activations) = 3 * 32 * 12 * 4096 * 4096 + 32 * 128000 * 1 * 12 * 4096 ≈ 220B
TF32: 220B * 2.375 bytes ≈ 522 GB ≈ 8 H100s (80 GB each) to train at batch size 1
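Same training estimate as a sketch (again with the 2.375-byte "TF32" convention from above):

```python
# Full-training memory at batch size 1: 3 copies of the (12*D^2 per layer) params
# plus activations ~ N*T*B*12*D, as in the formula above.
N, D, B, T = 32, 4096, 1, 128_000
BYTES = 2.375                         # the "TF32" bytes-per-value used in these notes

params_states = 3 * N * 12 * D * D    # weights + grads + optimizer state (~19B values)
activations   = N * T * B * 12 * D    # ~201B values
total = params_states + activations
print(f"≈ {total / 1e9:.0f}B values ≈ {total * BYTES / 1e9:.0f} GB")  # ≈ 221B ≈ 524 GB (the ~220B / ~522 GB above)
```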
Finetune (LoRA):
B = 1
R = 64
N*12*D*D (frozen base) + N*2*D*R (LoRA params) + 2*N*2*D*R (Adam) + B*N*T*(4D + R) (activations) = 32 * 12 * 4096 * 4096 + 32 * 2 * 4096 * 64 + 2 * 32 * 2 * 4096 * 64 + 1 * 32 * 128000 * (4 * 4096 + 64) ≈ 74B
TF32: ≈ 175 GB ≈ 8 RTX 5090s (whether they're 24 or 32 GB each) to finetune at batch size 1
B = 4
32 * 12 * 4096 * 4096 + 32 * 2 * 4096 * 64 + 2 * 32 * 2 * 4096 * 64 + 4 * 32 * 128000 * (4 * 4096 + 64) ≈ 276B ≈ 655 GB
That's just over 8 H100s (640 GB), so maybe with mixed precision you can finetune on 8x H100s.
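One sketch covering both batch sizes (the term labels are my reading of the formula above):

```python
# LoRA finetune memory: frozen base weights + LoRA params + Adam state for the
# LoRA params only + activations (the (4D + R) per-token term above).
N, D, T, R = 32, 4096, 128_000, 64
BYTES = 2.375                            # same "TF32" bytes-per-value as above

def lora_values(B):
    base = N * 12 * D * D                # frozen base weights (~6.4B values)
    lora = N * 2 * D * R                 # LoRA A and B matrices (~17M values)
    adam = 2 * lora                      # Adam m and v for the LoRA params
    acts = B * N * T * (4 * D + R)       # activations
    return base + lora + adam + acts

for B in (1, 4):
    v = lora_values(B)
    print(f"B={B}: ≈ {v / 1e9:.0f}B values ≈ {v * BYTES / 1e9:.0f} GB")
# B=1: ≈ 74B values ≈ 175 GB; B=4: ≈ 276B ≈ 655 GB
```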
Ok time to earn enough money to buy 8xH100s! It’s only $260,500!