AttributeError when creating a PyTorch model


Here is my code:

from torch import nn
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss 
class segmentationmodels():
  def __init__(self):
    super(segmentationmodels,self).__init__()
    self.arc=smp.Unet(
        encoder_name=ENCODER,
        encoder_weights=weights,
        in_channels=3,
        classes=1,
        activation=None
    )
  def forward(self,images,mask=None):
    logits=self.arc(images)
    if mask != None:
      loss1=DiceLoss(mode='binary')(logits,mask)
      loss2=nn.BCEWithLogitsLoss()(logits,mask)
      return logits,loss1+loss2
    return logits
model=segmentationmodels()
model.to(DEVICE)

Here is the error:

AttributeError                            Traceback (most recent call last)
<ipython-input-123-b61ad7737aab> in <module>
      1 model=segmentationmodels()
----> 2 model.to(DEVICE)

AttributeError: 'segmentationmodels' object has no attribute 'to'
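The same error can be reproduced without PyTorch. Python looks up attributes only on the class and its base classes. A class whose only base is object therefore has no to() method. Calling super().__init__() in such a class still succeeds, because it just runs object.__init__. The following is a minimal sketch of that mechanism. ModuleLike, PlainModel and FixedModel are hypothetical names, and ModuleLike is only a stand-in for torch.nn.Module:

```python
class ModuleLike:
    """Hypothetical stand-in for torch.nn.Module: it only defines a .to() method."""

    def to(self, device):
        self.device = device  # record the target device, for illustration only
        return self  # like nn.Module.to(), return the object itself


class PlainModel:  # like the class in the question: no base class besides object
    def __init__(self):
        super().__init__()  # succeeds, but only calls object.__init__


class FixedModel(ModuleLike):  # inheriting from the base makes .to() available
    def __init__(self):
        super().__init__()


print(hasattr(PlainModel(), "to"))  # False
model = FixedModel().to("cpu")
print(model.device)  # cpu

try:
    PlainModel().to("cpu")
except AttributeError as err:
    print(err)  # 'PlainModel' object has no attribute 'to'
```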
1 Answer

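When you defined your class, you did not make it inherit from torch.nn.Module, and that is why you get this error. The to() method is defined on nn.Module. A plain Python class that inherits only from object does not have it. Your super().__init__() call still runs, but it only initializes object, not an nn.Module. Your code should look like this: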

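import torch  # needed so that torch.nn.Module below can be resolved
from torch import nn
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss

# ENCODER, weights and DEVICE are assumed to be defined earlier in your notebook,
# e.g. ENCODER = "resnet34", weights = "imagenet", DEVICE = "cuda".

class segmentationmodels(torch.nn.Module):  # inherit from nn.Module to get .to(), .parameters(), etc.
    def __init__(self):
        super(segmentationmodels, self).__init__()  # now initializes nn.Module's internals
        self.arc = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=weights,
            in_channels=3,
            classes=1,
            activation=None,  # return raw logits; both losses below expect logits
        )

    def forward(self, images, mask=None):
        logits = self.arc(images)
        if mask is not None:  # prefer "is not None" over "!= None" when mask may be a tensor
            loss1 = DiceLoss(mode='binary')(logits, mask)
            loss2 = nn.BCEWithLogitsLoss()(logits, mask)
            return logits, loss1 + loss2
        return logits

model = segmentationmodels()
model.to(DEVICE)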
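You can check the fix without downloading pretrained encoder weights. The minimal sketch below makes two assumptions to stay self-contained. First, smp.Unet is replaced by a single 1x1 nn.Conv2d. Second, only nn.BCEWithLogitsLoss is kept, because DiceLoss comes from segmentation_models_pytorch. TinySegModel is a hypothetical name. Because the class subclasses nn.Module, .to() works and the conv layer's parameters are registered automatically:

```python
import torch
from torch import nn


class TinySegModel(nn.Module):
    """Hypothetical stand-in for segmentationmodels: a 1x1 conv replaces smp.Unet."""

    def __init__(self):
        super().__init__()  # required so nn.Module can track submodules and parameters
        self.arc = nn.Conv2d(in_channels=3, out_channels=1, kernel_size=1)
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, images, mask=None):
        logits = self.arc(images)
        if mask is not None:
            return logits, self.bce(logits, mask)
        return logits


device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinySegModel().to(device)  # .to() is inherited from nn.Module

images = torch.randn(2, 3, 8, 8, device=device)
mask = torch.randint(0, 2, (2, 1, 8, 8), device=device).float()

logits, loss = model(images, mask)
print(logits.shape)  # torch.Size([2, 1, 8, 8])
print(sum(p.numel() for p in model.parameters()))  # 4 (3 conv weights + 1 bias)
```

Without the nn.Module base class, model.parameters() would not exist either. An optimizer could not then be built from the model.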
© www.soinside.com 2019 - 2024. All rights reserved.