Train SOTA? Not so fast.
A lot of people want to train their own multimodal large language models (MLLMs). But even assuming you got your data from Sweatshop AI and your GPUs from Jensen, you still can't really get started, for two main reasons:
- The parameters of these models no longer fit in the memory of even the largest GPUs (80 GB A100 cards still aren't cutting it).
- Even if you could fit the model on a single GPU by swapping parameters between host and device memory, the sheer number of compute operations required would result in extremely long training times (e.g., training GPT-3 with 175 billion parameters would take about 288 years on a single V100).
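Both bullets can be sanity-checked with back-of-the-envelope arithmetic. The sketch below is a rough estimate under stated assumptions, not the source of the 288-year figure: it assumes mixed-precision Adam (about 16 bytes of model state per parameter, ignoring activations), the common ~6·N·D rule of thumb for training FLOPs, roughly 300B training tokens for GPT-3, a V100 fp16 tensor-core peak of 125 TFLOP/s, and an illustrative 30% sustained utilization. The helper names `training_memory_gb` and `training_years` are hypothetical.

```python
import math


def training_memory_gb(n_params: float, bytes_per_param: float = 16) -> float:
    """Model-state memory in GB (1e9 bytes); ignores activations and buffers.

    The 16-byte default assumes mixed-precision Adam: fp16 weights (2) + fp16
    gradients (2) + fp32 master weights, momentum and variance (4 + 4 + 4).
    """
    return n_params * bytes_per_param / 1e9


def training_years(n_params: float, n_tokens: float, peak_flops: float,
                   utilization: float, flops_per_param_token: float = 6) -> float:
    """Single-device wall-clock years from the ~6 * N * D training-FLOPs rule of thumb."""
    total_flops = flops_per_param_token * n_params * n_tokens
    seconds = total_flops / (peak_flops * utilization)
    return seconds / (365.25 * 24 * 3600)


gpt3_params = 175e9
weights_gb = training_memory_gb(gpt3_params, bytes_per_param=2)  # fp16 weights only
state_gb = training_memory_gb(gpt3_params)                       # weights + grads + Adam
print(f"fp16 weights alone: {weights_gb:.0f} GB; full training state: {state_gb:.0f} GB")
print(f"=> at least {math.ceil(state_gb / 80)} x 80 GB GPUs just to hold model state")

# Illustrative assumptions: ~300B tokens, V100 fp16 tensor-core peak of 125 TFLOP/s,
# 30% sustained utilization.
years = training_years(gpt3_params, 300e9, 125e12, 0.30)
print(f"~{years:.0f} years on a single V100")
```

Under these assumptions the weights alone need 350 GB, the full training state needs 2800 GB (at least 35 A100s' worth of memory before counting activations), and training comes out at roughly 266 years, in the same few-hundred-year range as the figure quoted above. The exact number is very sensitive to the assumed utilization and FLOPs-per-token factor.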
Naturally, we'll try to throw parallelism and multi-node clusters at this problem. Data-parallel scale-out tends to work well for models with a high compute-to-parameter ratio (i.e., many FLOPs per forward pass relative to the number of parameters), because gradient traffic grows with the parameter count while useful work grows with the FLOPs. Unfortunately, this technique has its limits. In its basic form, every GPU still holds a full replica of the model, and as the number of GPUs increases, communication overhead kills performance as more machines sit idle waiting for gradient updates from Ring All-Reduce or parameter servers.
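To see why communication does not shrink as data parallelism scales out, here is a minimal pure-Python simulation of ring all-reduce: a reduce-scatter phase followed by an all-gather phase, each taking p − 1 steps around the ring. It is a toy model of the algorithm, not NCCL's implementation; `ring_all_reduce` is a hypothetical helper, and the sketch assumes the gradient length divides evenly by the number of workers.

```python
from typing import List, Tuple


def ring_all_reduce(grads: List[List[float]]) -> Tuple[List[List[float]], int]:
    """Simulate a sum all-reduce over len(grads) workers arranged in a ring.

    Returns every worker's final vector and the number of elements that each
    worker sent (the same for all workers).
    """
    p = len(grads)
    n = len(grads[0])
    assert n % p == 0, "sketch assumes the vector length divides evenly by p"
    c = n // p
    # chunks[w][k]: worker w's current copy of chunk k
    chunks = [[g[k * c:(k + 1) * c] for k in range(p)] for g in grads]
    sent = 0

    # Reduce-scatter: at step s, worker w sends chunk (w - s) % p to worker w + 1,
    # which adds it to its own copy. Messages are snapshotted before delivery.
    for s in range(p - 1):
        msgs = [(w, (w - s) % p, list(chunks[w][(w - s) % p])) for w in range(p)]
        for w, k, data in msgs:
            dst = (w + 1) % p
            chunks[dst][k] = [x + y for x, y in zip(chunks[dst][k], data)]
        sent += c
    # Now worker w holds the fully reduced chunk (w + 1) % p.

    # All-gather: at step s, worker w forwards chunk (w + 1 - s) % p, and the
    # receiver overwrites its stale copy with the reduced one.
    for s in range(p - 1):
        msgs = [(w, (w + 1 - s) % p, list(chunks[w][(w + 1 - s) % p])) for w in range(p)]
        for w, k, data in msgs:
            chunks[(w + 1) % p][k] = data
        sent += c

    return [[x for k in range(p) for x in chunks[w][k]] for w in range(p)], sent


# Usage: 4 workers, 8-element "gradients"
grads = [[float(w + i) for i in range(8)] for w in range(4)]
out, sent = ring_all_reduce(grads)
expected = [sum(g[i] for g in grads) for i in range(8)]
assert all(row == expected for row in out)
print(f"each worker sent {sent} elements")
```

Each worker sends 2(p − 1)/p · N elements, which approaches 2N as p grows. Per-GPU traffic therefore stays roughly constant while, for a fixed global batch, per-GPU compute shrinks as 1/p, so communication becomes a growing share of every step.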
In recent years, model-parallel techniques have been proposed to work around these limits, such as tensor (intra-layer) model parallelism, where the matrix multiplications within each transformer layer are split across multiple GPUs (see Megatron). Although this works well for models of up to 20 billion parameters on NVIDIA DGX A100 servers (8 × 80 GB A100 GPUs each), it breaks down for larger models. So let's do the logical thing and split the model across multiple multi-GPU servers. This also has drawbacks:
- The All-Reduce communication required for tensor parallelism must then go through inter-server links, which are much slower than the high-bandwidth NVLink available within a multi-GPU server. Some techniques try to reduce how much workers need to communicate over such slow links (see DiLoCo).
- A high degree of model parallelism can create small matrix multiplications (GEMMs), decreasing GPU utilization (see tile efficiency). If a large model's GEMMs fully utilize one GPU, splitting the model across 4 GPUs shrinks each GPU's share of every GEMM, leaving the GPUs underutilized. Further, if the model sharding (partitioning) produces uneven splits, some GPUs finish their computations earlier than others and sit idle waiting for the rest.
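To make the tensor-parallel split and the shrinking-GEMM problem concrete, here is a minimal pure-Python sketch. It assumes a Megatron-style MLP, gelu(x·a)·b, with `a` split by columns and `b` by rows across t simulated GPUs; summing the partial outputs stands in for the all-reduce that, across servers, would travel over the slower inter-server links. The `wave_efficiency` helper is a deliberately crude, hypothetical model of wave quantization (one 128×128 output tile per streaming multiprocessor per wave, 108 SMs as on an A100), not a performance predictor. All function names are illustrative.

```python
import math
import random
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(x: Matrix, w: Matrix) -> Matrix:
    """Naive GEMM: (m x k) @ (k x n) -> (m x n)."""
    k = len(w)
    return [[sum(row[p] * w[p][j] for p in range(k)) for j in range(len(w[0]))]
            for row in x]


def gelu(m: Matrix) -> Matrix:
    """Elementwise tanh-approximate GeLU; elementwise, so it commutes with a column split."""
    c = math.sqrt(2.0 / math.pi)
    return [[0.5 * v * (1.0 + math.tanh(c * (v + 0.044715 * v ** 3))) for v in row]
            for row in m]


def tensor_parallel_mlp(x: Matrix, a: Matrix, b: Matrix, t: int) -> Tuple[Matrix, list]:
    """gelu(x @ a) @ b split over t simulated GPUs (assumes an even split)."""
    inner = len(b)
    assert inner % t == 0, "sketch assumes the inner dimension divides evenly by t"
    c = inner // t
    partials, shard_shapes = [], []
    for i in range(t):
        a_i = [row[i * c:(i + 1) * c] for row in a]  # column shard of a
        b_i = b[i * c:(i + 1) * c]                   # matching row shard of b
        shard_shapes.append((len(a_i), len(a_i[0])))
        partials.append(matmul(gelu(matmul(x, a_i)), b_i))  # local, no communication
    # Stand-in for the all-reduce: elementwise sum of the t partial outputs
    out = [[sum(p[r][col] for p in partials) for col in range(len(partials[0][0]))]
           for r in range(len(x))]
    return out, shard_shapes


def wave_efficiency(m: int, n: int, tile: int = 128, num_sms: int = 108) -> float:
    """Crude wave-quantization model: fraction of SM slots doing useful tiles."""
    tiles = math.ceil(m / tile) * math.ceil(n / tile)
    waves = math.ceil(tiles / num_sms)
    return tiles / (waves * num_sms)


def rand_matrix(rng: random.Random, rows: int, cols: int) -> Matrix:
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
h, tokens = 8, 4
x, a, b = rand_matrix(rng, tokens, h), rand_matrix(rng, h, 4 * h), rand_matrix(rng, 4 * h, h)
reference = matmul(gelu(matmul(x, a)), b)
for t in (1, 2, 4):
    y, shapes = tensor_parallel_mlp(x, a, b, t)
    assert all(math.isclose(u, v, abs_tol=1e-9)
               for ru, rv in zip(y, reference) for u, v in zip(ru, rv))
    print(f"t={t}: per-shard weight of first GEMM {shapes[0]}")

# Hypothetical sizes: 2048 tokens, hidden size 4096, first MLP GEMM width 4h / t
for t in (1, 2, 4, 8, 16):
    print(f"t={t}: wave efficiency {wave_efficiency(2048, 4 * 4096 // t):.3f}")
```

The sharded result matches the unsharded one, while each GPU's GEMM gets t times narrower. Under the toy wave model, a GEMM that keeps nearly every SM busy at t = 1 tends to leave more SMs idle in its last wave as t grows. Real kernels choose tile sizes dynamically, so treat these numbers as illustrative only.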