PointNet++模型;我在修改分割模型的类别时遇到断言问题

问题描述 投票:0回答:1

我之前这个模型有13个类别,现在我把它改为17个类别进行分割。我更新了“num_classes”和权重以对应 17 个类别。我还修改了“pointnet_sem_seg”以容纳 17 个类别。但是,我仍然遇到断言错误。

BN momentum updated to: 0.100000
  0%|                                          | 5/3121 [00:30<05:34,  9.32it/s]/opt/conda/conda-bld/pytorch_1614378098133/work/aten/src/THCUNN/ClassNLLCriterion.cu:108: cunn_ClassNLLCriterion_updateOutput_kernel: block: [0,0,0], thread: [10,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1614378098133/work/aten/src/THCUNN/ClassNLLCriterion.cu:108: cunn_ClassNLLCriterion_updateOutput_kernel: block: [0,0,0], thread: [26,0,0] Assertion `t >= 0 && t < n_classes` failed.
  0%|                                        | 5/3121 [00:30<5:19:21,  6.15s/it]
Traceback (most recent call last):
  File "/home/dl/Pointnet_Pointnet2_pytorch/train_semseg.py", line 303, in <module>
    main(args)
  File "/home/dl/Pointnet_Pointnet2_pytorch/train_semseg.py", line 202, in main
    loss = criterion(seg_pred, target, trans_feat, weights)
  File "/home/dl/anaconda3/envs/python3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/dl/Pointnet_Pointnet2_pytorch/models/pointnet_sem_seg.py", line 42, in forward
    mat_diff_loss = feature_transform_reguliarzer(trans_feat)
  File "/home/dl/Pointnet_Pointnet2_pytorch/models/pointnet_utils.py", line 140, in feature_transform_reguliarzer
    I = I.cuda()
RuntimeError: CUDA error: device-side assert triggered
pytorch point
1个回答
0
投票

您遇到的错误(运行时错误:CUDA 错误:设备端断言已触发)是 CUDA 代码中断言失败的结果。断言消息为 Assertion 't >= 0 && t < n_classes' failed. This assertion typically checks that the target labels fall within the expected range, which is determined by the number of classes (n_classes).

由于您已将模型更新为具有 17 个类别,因此请确保您已在使用的所有地方正确更新了 n_classes 变量。这里有几个地方需要检查和更新:

模型架构(pointnet_sem_seg.py):

将 PointNet++ 模型架构中的 num_classes 参数更新为 17。 损失函数(train_semseg.py):

确保用于分割的标准已更新以处理 17 个类别。 标签映射(class2label 和 seg_classes):

验证您的标签映射(class2label 和 seg_classes)是否反映了 17 个类别的新类别索引。 这是更新 seg_classes 的简单示例:

蟒蛇 复制代码 seg_classes = {1: 'class1', 2: 'class2', ..., 17: 'class17'} 权重张量大小:

检查损失函数中是否使用了权重张量,并确保其大小适合 17 个类别。 您的数据集中的类索引:

如果您使用自定义数据集类,请确保数据集中标签的类索引范围为 1 到 17。 进行这些更改后,再次运行代码并查看问题是否仍然存在。如果问题仍然存在,您可能需要进一步调查 CUDA 错误的详细信息,以确定代码中断言失败的确切位置。这可能涉及检查中间变量、打印调试信息或使用调试器。

© www.soinside.com 2019 - 2024. All rights reserved.