I want to spread a matrix multiplication such as `torch.mm(a, b)` across multiple GPUs to reduce memory usage on any single GPU.
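Splitting by rows is what makes this possible: each GPU needs only its own row block of `a`, but a full copy of `b`. With `a` split into $A_1$ (the first 15000 rows) and $A_2$ (the remaining rows), the product decomposes as

$$
C = AB = \begin{bmatrix} A_1 \\ A_2 \end{bmatrix} B = \begin{bmatrix} A_1 B \\ A_2 B \end{bmatrix}, \qquad A_1, A_2 \in \mathbb{R}^{15000 \times 30000},\; B \in \mathbb{R}^{30000 \times 30000}
$$

so the two row halves of $C$ can be computed independently on different GPUs, at the cost of replicating $B$ on both of them.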
Here is the code running on a single GPU:
```python
import torch
a = torch.randn(30000, 30000).cuda(1)
b = torch.randn(30000, 30000).cuda(1)
c = torch.mm(a, b)
# during this process, the maximum memory usage is 10491 MB.
```
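As a sanity check on that figure: `torch.randn` returns float32 (4 bytes per element) by default, so each 30000x30000 matrix takes about 3433 MiB, and `a`, `b` and `c` together about 10300 MiB. The remaining ~190 MB in the measurement is presumably CUDA context and allocator overhead; that part is an assumption, not something computed here. The helper `mib` below is purely illustrative.

```python
# Back-of-the-envelope memory estimate for the single-GPU run (float32 assumed).
MIB = 2 ** 20      # bytes per MiB
FP32_BYTES = 4     # torch.randn returns float32 by default


def mib(rows: int, cols: int, bytes_per_elem: int = FP32_BYTES) -> float:
    """Size in MiB of a dense rows x cols tensor."""
    return rows * cols * bytes_per_elem / MIB


full = mib(30000, 30000)        # a, b and c each have this size
single_gpu_peak = 3 * full      # a, b and the product c are alive together
print(f"{full:.0f} MiB per matrix, {single_gpu_peak:.0f} MiB for a, b and c")
# prints "3433 MiB per matrix, 10300 MiB for a, b and c"
```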
Here is the code running on two GPUs:
```python
import torch
# assuming `a1` and `a2` are the two row halves of one big matrix `a`
a1 = torch.randn(15000, 30000).cuda(0)
a2 = torch.randn(15000, 30000).cuda(1)
b1 = torch.randn(30000, 30000).cuda(0)
b2 = b1.cuda(1)  # full copy of b on GPU 1
c1 = torch.mm(a1, b1)
c2 = torch.mm(a2, b2).to(0)
# at this point, both results `c1` and `c2` are on GPU 0
# the maximum memory usage on GPU 1 is 7059 MB
# the maximum memory usage on GPU 0 is 8777 MB, higher than on GPU 1 because c2 is also stored there
c = torch.concat([c1, c2], dim=0)
# OOM, because concat is not in-place: it allocates a new output while c1 and c2 are still alive
```
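So, would this work as expected if the concat could be done in place? Or should I first move `c1` and `c2` to CPU memory, concatenate them there, and then move the concatenated result back to the GPU?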
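Rough bookkeeping for the two-GPU version, with the same float32 sizes, shows why the concat is the step that fails: `c1` and `c2` are still alive when `torch.concat` allocates its full-size output, so GPU 0 needs one extra 30000x30000 matrix on top of what it already holds. It also suggests that, as long as `a1` and `b1` stay on GPU 0, neither an in-place concat nor a round trip through CPU memory can push the peak below the pre-concat level, because both end with `a1`, `b1` and one full `c`, which is the same amount as `a1`, `b1`, `c1` and `c2`. These numbers count live tensors only; the ~190 MB gap to the measured values is again assumed to be CUDA context and allocator overhead.

```python
# Per-GPU memory estimate for the two-GPU run (float32, live tensors only).
MIB = 2 ** 20
FP32_BYTES = 4


def mib(rows: int, cols: int, bytes_per_elem: int = FP32_BYTES) -> float:
    """Size in MiB of a dense rows x cols tensor."""
    return rows * cols * bytes_per_elem / MIB


half = mib(15000, 30000)  # a1, a2, c1, c2
full = mib(30000, 30000)  # b1, b2 and the concatenated c

gpu1_peak = half + full + half            # a2, b2 and the mm result before .to(0)
gpu0_before_cat = half + full + 2 * half  # a1, b1, c1 and c2 (moved to GPU 0)
gpu0_with_cat = gpu0_before_cat + full    # torch.concat allocates a new full-size c
gpu0_c_in_place = half + full + full      # a1, b1 and a single full-size c

for name, value in [("GPU 1 peak", gpu1_peak),
                    ("GPU 0 before concat", gpu0_before_cat),
                    ("GPU 0 during concat", gpu0_with_cat),
                    ("GPU 0 with c written in place", gpu0_c_in_place)]:
    print(f"{name}: {value:.0f} MiB")
# GPU 1 peak: 6866 MiB
# GPU 0 before concat: 8583 MiB
# GPU 0 during concat: 12016 MiB
# GPU 0 with c written in place: 8583 MiB
```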
I also tried the tensor parallelism provided in PyTorch 2.2:
```python
import torch
import torch.distributed as distributed
import os
from torch.distributed._tensor import init_device_mesh, Shard, distribute_tensor
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
from visualize_sharding import visualize_sharding
mesh = init_device_mesh("cuda", (2,))
rank = distributed.get_rank()
big_tensor_1 = torch.randn(3, 2)
big_tensor_2 = torch.randn(2, 6)
print("big_tensor_1", big_tensor_1)
my_dtensor_1 = distribute_tensor(big_tensor_1, mesh, [Shard(dim=0)])
my_dtensor_2 = distribute_tensor(big_tensor_2, mesh, [Shard(dim=1)])
# visualize_sharding(my_dtensor_1, header="my_dtensor_1")
c = torch.mm(my_dtensor_1, my_dtensor_2)
print("c: ", c)
```
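But everything runs twice, because the script is launched with `python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 tmp.py`, so two different `big_tensor_1` tensors are randomly generated. How can I modify the code so that it runs only once across the two processes?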
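Every process started by the launcher executes the whole of `tmp.py`, so each rank draws its own random numbers. A minimal sketch of one workaround, assuming it is acceptable to build the full input tensors on every rank, is to seed a private generator with the same value on all ranks so they produce bit-identical inputs before `distribute_tensor` shards them. The helper name `make_global_inputs` is hypothetical. An alternative is to create the tensor on rank 0 only and send it to the other ranks with `torch.distributed.broadcast(tensor, src=0)`.

```python
import torch


def make_global_inputs(seed: int = 0):
    """Hypothetical helper: build the full (unsharded) inputs identically on every rank.

    All processes call this with the same seed, so every rank gets the same
    big_tensor_1 / big_tensor_2 before they are sharded.
    """
    # A private CPU generator, so the result does not depend on the global RNG state.
    g = torch.Generator().manual_seed(seed)
    big_tensor_1 = torch.randn(3, 2, generator=g)
    big_tensor_2 = torch.randn(2, 6, generator=g)
    return big_tensor_1, big_tensor_2


# Two calls stand in for two ranks: both see exactly the same data.
rank0_inputs = make_global_inputs(seed=0)
rank1_inputs = make_global_inputs(seed=0)
print(all(torch.equal(x, y) for x, y in zip(rank0_inputs, rank1_inputs)))  # True
```

In the script itself, the two `torch.randn` lines would become `big_tensor_1, big_tensor_2 = make_global_inputs(seed=0)`. The `print` calls still execute once per process; wrapping them in `if rank == 0:` shows the output only once.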
Everything I have tried is listed in the question details.
I tried the following approach, which partially solves the first problem from the question details, but not completely:
```python
import torch
a1 = torch.randn(15000, 30000).cuda(0)
a2 = torch.randn(15000, 30000).cuda(1)
b1 = torch.randn(30000, 30000).cuda(0)
b2 = b1.cuda(1)
# create an empty tensor first,
# then write the computation results directly into it,
# but the maximum memory usage on a single GPU is still high
c = torch.empty(30000, 30000).cuda(0)
c[:15000] = torch.mm(a1,b1)
c[15000:] = torch.mm(a2,b2).to(0)
```
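One likely reason the peak stays high: `c[:15000] = torch.mm(a1, b1)` first materialises a separate 15000x30000 result on GPU 0 and only then copies it into `c`, and `torch.mm(a2, b2).to(0)` likewise allocates a new 15000x30000 tensor on GPU 0 before the copy. A minimal sketch that avoids both temporaries passes `out=` to `torch.mm` for the shard that lives on the output's device, and uses `Tensor.copy_` to move the other shard's product straight into its slice. The helper name `mm_row_shards_into` is hypothetical, and the demo runs on CPU with tiny sizes only so it can execute anywhere.

```python
import torch


def mm_row_shards_into(a_shards, b_copies, out):
    """Hypothetical helper: write a_i @ b_i for each row shard directly into rows of `out`.

    Avoids both torch.cat and the full-size temporaries of `c[...] = torch.mm(...)`.
    """
    row = 0
    for a_i, b_i in zip(a_shards, b_copies):
        n = a_i.shape[0]
        dst = out[row:row + n]  # contiguous view into `out`, no new allocation
        if a_i.device == out.device:
            # Same device: mm writes its result directly into the view.
            torch.mm(a_i, b_i, out=dst)
        else:
            # Other device: the product is created on a_i's device,
            # then copy_ transfers it into the view on out.device.
            dst.copy_(torch.mm(a_i, b_i))
        row += n
    return out


# CPU stand-in for the two-GPU case (small sizes, both shards on one device).
a = torch.randn(6, 4)
b = torch.randn(4, 5)
c = torch.empty(6, 5)
mm_row_shards_into([a[:3], a[3:]], [b, b], c)
print(torch.allclose(c, a @ b, atol=1e-5))  # True
```

On the GPUs this would be `c = torch.empty(30000, 30000, device=0)` followed by `mm_row_shards_into([a1, a2], [b1, b2], c)`; `device=0` also allocates `c` on GPU 0 directly instead of building it in CPU memory and copying it over. With `a1` and `b1` still alive, GPU 0 should then hold roughly `a1`, `b1` and `c`, with no extra 15000x30000 buffer.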