我使用PyTorch的DDP在同一台机器上用两块RTX 3090 GPU训练我的模型,我发现时间成本比在单个GPU上不使用DDP训练还要差:使用DDP训练大约需要133秒,而不使用DDP只需要105秒。数据已正确加载到两块GPU的显存中。看起来在使用PyTorch的DistributedDataParallel(DDP)进行训练时,两块GPU并没有并行计算;相反,同一个任务被重复计算了两次。
我的主要代码如下:
import argparse
import os
import pickle
from torch.utils.data.distributed import DistributedSampler
import ruamel.yaml as yaml
import time
from pathlib import Path
import torch
import torch.nn as nn
from dataset.dataset_DPT import VQAFeatureDataset
from torch.utils.data import DataLoader
from m3ae.modules import M3AETransformerSS
import torch.distributed as dist
import numpy as np
from dataset.tools import Logger, create_dir, set_schedule
import torch.nn.functional as F
@torch.no_grad()
def evaluation(model, data_loader, device):
    """Evaluate VQA accuracy over ``data_loader``.

    Parameters
    ----------
    model : callable nn.Module
        Takes a batch dict and returns per-sample answer logits of shape
        (batch, num_answers); num_answers must be >= 8 for the top-k dump.
    data_loader : iterable
        Yields batch dicts with keys 'image' (list of tensors), 'text_ids',
        'text_labels', 'text_ids_mlm', 'text_labels_mlm', 'text_masks',
        'vqa_labels' (list of per-sample label vectors), 'answer_types'
        (list of 'OPEN'/'CLOSED' strings).
    device : torch.device or str
        Device to move batch tensors to.

    Returns
    -------
    (overall_acc, open_acc, closed_acc) : tuple of float
        Each in [0, 1]; 0.0 when the corresponding denominator is zero.
    """
    open_ended = 0
    closed_ended = 0
    open_correct = 0
    close_correct = 0
    # Per-sample top-8 (indices, values) pairs, kept for an optional
    # stage-1 dump (the pickle write below is commented out upstream).
    stage1_dict = []
    model.eval()
    # NOTE: @torch.no_grad() on the function already disables autograd;
    # a nested `with torch.no_grad():` (as in the original) is redundant.
    for index, batch in enumerate(data_loader):
        batch['image'][0] = batch['image'][0].to(device)
        for key in ('text_ids', 'text_labels', 'text_ids_mlm',
                    'text_labels_mlm', 'text_masks'):
            batch[key] = batch[key].to(device)
        # Stack the per-sample one-hot/soft label vectors into (batch, num_answers).
        targets = torch.cat([lbl.unsqueeze(0) for lbl in batch['vqa_labels']],
                            dim=0).to(device)
        target = torch.argmax(targets, dim=1)
        ans_type = batch['answer_types']
        logits = model(batch)
        values, indices = torch.topk(logits, k=8, dim=-1)
        values = values.detach().cpu().numpy()
        indices = indices.detach().cpu().numpy()
        for i in range(values.shape[0]):
            stage1_dict.append((indices[i].astype(np.int16),
                                values[i].astype(np.float16)))
        pred = torch.argmax(logits, dim=1)
        for i in range(len(ans_type)):
            if ans_type[i] == 'OPEN':
                open_ended += 1
                if target[i] == pred[i]:
                    open_correct += 1
            elif ans_type[i] == 'CLOSED':
                closed_ended += 1
                if target[i] == pred[i]:
                    close_correct += 1
    # with open('data/vqa/data_RAD/stage1_train.pkl', 'wb') as fp:
    #     pickle.dump(stage1_dict, fp)
    # BUG FIX: the original did `score += open_score + close_score` *inside*
    # the batch loop, re-adding the cumulative counters every batch and
    # inflating the overall accuracy whenever there is more than one batch.
    # Compute all ratios once, after the loop, guarding against empty splits.
    total = open_ended + closed_ended
    score = (open_correct + close_correct) / total if total else 0.0
    open_score = open_correct / open_ended if open_ended else 0.0
    close_score = close_correct / closed_ended if closed_ended else 0.0
    return score, open_score, close_score
def main(args, config):
    """DDP training entry point for the M3AE-based VQA model.

    Expects to be launched once per GPU (torchrun / torch.distributed.launch
    with --nproc_per_node); each process gets its rank via ``args.local_rank``.
    Rank 0 additionally runs evaluation, logging, and checkpointing.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``local_rank`` (-1 means non-distributed).
    config : dict
        Training config; reads 'batch_size', 'max_epoch', writes 'load_path'.
    """
    create_dir('outputs')
    logger = Logger('outputs/log.txt')
    logger.write(args.__repr__())
    if args.local_rank != -1:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
    trainset = VQAFeatureDataset('train')
    testset = VQAFeatureDataset('test')
    #### Creating Model ####
    print("Creating model")
    config['load_path'] = '/home/liyong/PythonWorkspace/M3AE-vqa/checkpoints/m3ae.ckpt'
    model = M3AETransformerSS(config)
    model = model.to(device)
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        print('use {} gpus!'.format(num_gpus))
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank, find_unused_parameters=True)
    world_size = dist.get_world_size()
    # DistributedSampler shards the train set so each rank sees 1/world_size
    # of the data per epoch; do NOT also pass shuffle= to the DataLoader.
    train_sampler = DistributedSampler(trainset, num_replicas=world_size,
                                       rank=args.local_rank)
    train_loader = DataLoader(trainset, batch_size=config['batch_size'],
                              num_workers=4, sampler=train_sampler,
                              collate_fn=trainset.collote)
    test_loader = DataLoader(testset, batch_size=config['batch_size'],
                             num_workers=4, collate_fn=testset.collote,
                             shuffle=False)
    optimizer, scheduler = set_schedule(model, config, len(trainset.entries))
    # BUG FIX: evaluation (and saving) must use the *unwrapped* module.
    # Calling the DDP wrapper's forward on rank 0 only triggers collective
    # communication that the other ranks never join, which can stall or
    # serialize the whole job. This also fixes the single-GPU crash on
    # `model.module.state_dict()` when no DDP wrapper exists.
    eval_model = model.module if isinstance(
        model, nn.parallel.DistributedDataParallel) else model
    best_score = 0
    best_epoch = 0
    for epoch in range(config['max_epoch']):
        # Required so each epoch reshuffles the per-rank shards differently.
        train_sampler.set_epoch(epoch)
        start_time = time.time()
        print(f"Start running basic DDP example on rank {args.local_rank}.")
        total_loss = 0.0
        for index, batch in enumerate(train_loader):
            batch['image'][0] = batch['image'][0].to(device)
            for key in ('text_ids', 'text_labels', 'text_ids_mlm',
                        'text_labels_mlm', 'text_masks'):
                batch[key] = batch[key].to(device)
            targets = torch.cat([lbl.unsqueeze(0) for lbl in batch['vqa_labels']],
                                dim=0).to(device)
            logits = model(batch)
            loss = F.binary_cross_entropy_with_logits(logits.float(), targets)
            # BUG FIX: accumulating the raw loss *tensor* (original:
            # `total_loss += loss`) keeps every iteration's autograd graph
            # alive, steadily growing GPU memory. `.item()` detaches to a
            # plain Python float.
            total_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
        end_time = time.time()
        if args.local_rank == 0:
            score, open_score, close_score = evaluation(eval_model, test_loader, device)
            logger.write('epoch ========== %d' % (epoch))
            logger.write('overall: %.4f, open: %.4f, close: %.4f, loss: %.4f, lr: %.6f, time: %.4f'
                         % (score, open_score, close_score, total_loss,
                            optimizer.state_dict()['param_groups'][0]['lr'],
                            end_time - start_time))
            if score > best_score:
                torch.save(eval_model.state_dict(),
                           '/home/liyong/PythonWorkspace/M3AE-vqa/model_mlm_ddp.pth')
                best_epoch = epoch
                best_score = score
    # Only rank 0 tracked the score, so only rank 0 logs the summary.
    if args.local_rank == 0:
        logger.write('best_score: %.4f, best_epoch: %d' % (best_score, best_epoch))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='/home/liyong/PythonWorkspace/M3AE-vqa/configs/RAD_M3AE.yaml')
    parser.add_argument('--checkpoint', default=None)
    parser.add_argument('--output_dir', default='/home/liyong/PythonWorkspace/M3AE-vqa/output/rad')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--text_encoder', default='bert-base-uncased')
    parser.add_argument('--text_decoder', default='bert-base-uncased')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    # BUG FIX: the original used `default=False, type=bool`; argparse's
    # type=bool makes ANY non-empty string truthy (including "--distributed
    # False"). A store_true flag keeps the same False default and gives the
    # expected on/off semantics.
    parser.add_argument('--distributed', action='store_true')
    # torch.distributed.launch / torchrun injects LOCAL_RANK into the env;
    # -1 signals a non-distributed run.
    parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', -1), type=int)
    args = parser.parse_args()
    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
    config['image_size'] = 224
    config['tokenizer'] = 'roberta-base'
    args.result_dir = os.path.join(args.output_dir, 'result')
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.result_dir).mkdir(parents=True, exist_ok=True)
    print("config: ", config)
    print("args: ", args)
    main(args, config)
我使用以下命令来执行我的主要代码:
python -m torch.distributed.launch --nproc_per_node=2 DPT_MLM_ddp2.py
(注:`torch.distributed.launch` 在较新版本的 PyTorch 中已被弃用,官方推荐改用 `torchrun --nproc_per_node=2 DPT_MLM_ddp2.py`。)