使用多 GPU 的 pytorch DDP 训练的时间成本

问题描述 投票:0回答:0

我使用pytorch的DDP在同一台机器上用两个RTX3090 GPU训练我的模型,我发现时间成本比在单个GPU上没有DDP的训练要好。使用 DDP 训练我的模型大约需要 133 秒,而在没有 DDP 的情况下训练它需要 105 秒。数据正确加载到两个 GPU 的内存中。貌似在使用PyTorch的DistributedDataParallel(DDP)进行训练时,两块GPU并没有进行并行计算。相反,同一个任务被计算了两次。

我的主要代码如下:

import argparse
import os
import pickle
from torch.utils.data.distributed import DistributedSampler
import ruamel.yaml as yaml
import time
from pathlib import Path
import torch
import torch.nn as nn
from dataset.dataset_DPT import VQAFeatureDataset
from torch.utils.data import DataLoader
from m3ae.modules import M3AETransformerSS
import torch.distributed as dist
import numpy as np
from dataset.tools import Logger, create_dir, set_schedule
import torch.nn.functional as F


@torch.no_grad()
def evaluation(model, data_loader, device):
    # test
    score = 0
    open_ended = 0
    closed_ended = 0
    open_score = 0
    close_score = 0
    stage1_dict = []
    # model.load_state_dict(torch.load('model_mlm.pth'))
    model.eval()

    header = 'Generate VQA test result:'
    with torch.no_grad():
        for index, batch in enumerate(data_loader):
            batch['image'][0] = batch['image'][0].to(device)
            batch['text_ids'] = batch['text_ids'].to(device)
            batch['text_labels'] = batch['text_labels'].to(device)
            batch['text_ids_mlm'] = batch['text_ids_mlm'].to(device)
            batch['text_labels_mlm'] = batch['text_labels_mlm'].to(device)
            batch['text_masks'] = batch['text_masks'].to(device)

            target_list = []
            for l in range(len(batch['vqa_labels'])):
                target_list.append(batch['vqa_labels'][l].unsqueeze(0))
            targets = torch.cat(target_list, dim=0).to(device)

            target = torch.argmax(targets, dim=1)
            ans_type = batch['answer_types']
            phrase_type = batch['phrase_type']

            logits = model(batch)


            values, indices = torch.topk(logits, k=8, dim=-1)
            values = values.detach().cpu().numpy()
            indices = indices.detach().cpu().numpy()
            for i in range(values.shape[0]):
                stage1_dict.append((indices[i].astype(np.int16), values[i].astype(np.float16)))

            pred_score = torch.argmax(logits, dim=1)

            for i in range(len(ans_type)):
                if ans_type[i] == 'OPEN':
                    open_ended += 1
                    if target[i] == pred_score[i]:
                        open_score += 1
                elif ans_type[i] == 'CLOSED':
                    closed_ended += 1
                    if target[i] == pred_score[i]:
                        close_score += 1

        score += open_score + close_score
    # with open('data/vqa/data_RAD/stage1_train.pkl', 'wb') as fp:
    #     pickle.dump(stage1_dict, fp)
    score = (score / (open_ended + closed_ended))
    open_score = (open_score / open_ended)
    close_score = (close_score / closed_ended)

    return score, open_score, close_score

def main(args, config):



    create_dir('outputs')
    logger = Logger('outputs/log.txt')
    logger.write(args.__repr__())
    # device = torch.device(args.device)

    if args.local_rank != -1:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")


    trainset = VQAFeatureDataset('train')
    testset = VQAFeatureDataset('test')


    # tokenizer = BertTokenizer.from_pretrained(args.text_encoder)

    #### Creating Model ####
    print("Creating model")
    config['load_path'] = '/home/liyong/PythonWorkspace/M3AE-vqa/checkpoints/m3ae.ckpt'
    # config['load_path'] = ''

    model = M3AETransformerSS(config)
    model = model.to(device)

    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        print('use {} gpus!'.format(num_gpus))
        model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                    output_device=args.local_rank, find_unused_parameters=True)
    word_size = dist.get_world_size()
    train_sampler = DistributedSampler(trainset, num_replicas=word_size, rank=args.local_rank)
    train_loader = DataLoader(trainset, batch_size=config['batch_size'], num_workers=4, sampler=train_sampler,
                              collate_fn=trainset.collote)

    test_loader = DataLoader(testset, batch_size=config['batch_size'], num_workers=4, collate_fn=testset.collote,
                             shuffle=False)

    optimizer, scheduler = set_schedule(model, config, len(trainset.entries))

    # print(model)
    # score, open_score, close_score = evaluation(model, test_loader, device)
    best_score = 0
    best_epoch = 0
    for epoch in range(config['max_epoch']):
        train_sampler.set_epoch(epoch)

        strat_time = time.time()
        print(f"Start running basic DDP example on rank {args.local_rank}.")

        total_loss = 0

        for index, batch in enumerate(train_loader):
            batch['image'][0] = batch['image'][0].to(device)
            batch['text_ids'] = batch['text_ids'].to(device)
            batch['text_labels'] = batch['text_labels'].to(device)
            batch['text_ids_mlm'] = batch['text_ids_mlm'].to(device)
            batch['text_labels_mlm'] = batch['text_labels_mlm'].to(device)
            batch['text_masks'] = batch['text_masks'].to(device)

            target_list = []
            for l in range(len(batch['vqa_labels'])):
                target_list.append(batch['vqa_labels'][l].unsqueeze(0))
            targets = torch.cat(target_list, dim=0).to(device)
            logits = model(batch)
            # loss = criterion(logits.float(), targets)
            loss = F.binary_cross_entropy_with_logits(logits.float(), targets)
            total_loss += loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

        end_time = time.time()

        if args.local_rank == 0:
            score, open_score, close_score = evaluation(model, test_loader, device)

            logger.write('epoch ========== %d' % (epoch))
            logger.write('overall: %.4f,  open: %.4f,  close: %.4f,   loss: %.4f,   lr: %.6f, time: %.4f'
                         % (score, open_score, close_score, total_loss, optimizer.state_dict()['param_groups'][0]['lr'], end_time-strat_time))

            if score > best_score:
                torch.save(model.module.state_dict(), '/home/liyong/PythonWorkspace/M3AE-vqa/model_mlm_ddp.pth')
                best_epoch = epoch
                best_score = score

                # print("best_score ======= " + str(round(best_score, 4)) + "  best_epoch ======= " + str(best_epoch))
            logger.write('best_score: %.4f,      best_epoch: %d' % (best_score, best_epoch))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='/home/liyong/PythonWorkspace/M3AE-vqa/configs/RAD_M3AE.yaml')
    # parser.add_argument('--checkpoint', default='./pretrain/2022-09-11/med_pretrain_29.pth')
    parser.add_argument('--checkpoint', default=None)
    parser.add_argument('--output_dir', default='/home/liyong/PythonWorkspace/M3AE-vqa/output/rad')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--text_encoder', default='bert-base-uncased')
    parser.add_argument('--text_decoder', default='bert-base-uncased')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--distributed', default=False, type=bool)
    parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', -1), type=int)
    args = parser.parse_args()

    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)

    config['image_size'] = 224
    config['tokenizer'] = 'roberta-base'

    args.result_dir = os.path.join(args.output_dir, 'result')

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.result_dir).mkdir(parents=True, exist_ok=True)

    # yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))

    print("config: ", config)
    print("args: ", args)
    main(args, config)

我使用以下命令来执行我的主要代码:

python -m torch.distributed.launch --nproc_per_node=2 DPT_MLM_ddp2.py

训练过程中的一些输出日志如下: y 有谁知道如何解决这个问题,谢谢!

python deep-learning pytorch distributed-system
© www.soinside.com 2019 - 2024. All rights reserved.