我试图通过以半精度进行计算来训练使用单精度权重的模型来节省内存。
我尝试使用自动转换,模型以半精度进行预测。 然而,产生的梯度仍然是单精度的。 这会破坏性能和节省内存。 有没有什么方法可以指示 torch 以半精度计算梯度并使用它们来更新原始的单精度权重?
import torch
class KekNet (torch.nn.Module):
def __init__(self):
super(KekNet, self).__init__()
self.layer1 = torch.nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), dtype=torch.float32)
def forward(self, x, features=False):
return self.layer1(x)
device = torch.device("cuda")
# HALF-DATA AUTOCAST
net = KekNet().to(device)
loss_l2 = torch.nn.MSELoss(reduction='none')
g_params = [{'params': net.parameters(), 'weight_decay': 0}]
optimizerG = torch.optim.RMSprop(g_params, lr=3e-5, alpha=0.99, eps=1e-07, weight_decay=0)
schedulerG = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG, T_max=300)
X = torch.randn((40,3,555,555), dtype=torch.float16, device =device)
with torch.autocast(device_type='cuda', dtype=torch.float16):
Y_h=net(X)
Y = torch.randn_like(Y_h)
loss = loss_l2(Y_h, Y).mean()
loss.backward()
print(f"-autocast\r\ndata precision: {X.dtype}\r\npred precision: {Y_h.dtype}\r\ngrad precision: {net.layer1.weight.grad.dtype}\r\n")
optimizerG.step()
schedulerG.step()
结果如下:
data precision: torch.float16
pred precision: torch.float16
grad precision: torch.float32
Autocast 不会转换模型的权重,因此权重梯度将具有与权重相同的 dtype。您可以尝试在模型上手动调用
.half()
来更改此设置。我不确定是否有办法在保持 fp32 权重的同时计算 fp16 的梯度。
import torch
import torch.nn as nn
torch.set_default_device('cuda')
model = nn.Linear(8,1)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
x = torch.randn(12, 8, dtype=torch.float16)
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):
output = model(x)
loss = output.mean()
loss.backward()
print(model.weight.grad.dtype)
# > torch.float32
opt.zero_grad()
model.half()
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):
output = model(x)
loss = output.mean()
loss.backward()
print(model.weight.grad.dtype)
# > torch.float16
此外,某些运算在 fp16 中计算时存在数值稳定性问题。为了避免这种情况,pytorch 自动将某些选项转换为 fp32。您可以在此处找到完整列表。
在您的情况下,MSE 损失(实际上是
pow
函数)自动转换为 fp32。这不会改变上面示例中的权重梯度 dtype,但如果您看到 fp32 出现在其他地方,则值得注意。