In Python, how do I type-hint a class's methods that use attributes from both a Protocol and the class itself?


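I am implementing a class that should provide some generic behavior for setting up train/validation/test dataloaders with PyTorch Lightning's LightningDataModule. I want to provide some of the functionality in this generic class, but leave the initialization of certain attributes to the classes that inherit from it. My attempt at the problem should illustrate the idea: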

from typing import Protocol
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader


class HasTrainValTestDatasets(Protocol):
    @property
    def train_ds(self) -> Dataset: ...

    @property
    def val_ds(self) -> Dataset: ...

    @property
    def test_ds(self) -> Dataset: ...


class GenericTrainValTestDataModule(pl.LightningDataModule):
    def __init__(
        self,
        batch_size_train: int,
        batch_size_eval: int,
        num_workers: int = 0,
    ):
        self._batch_size_train = batch_size_train
        self._batch_size_eval = batch_size_eval
        self._num_workers = num_workers

    def train_dataloader(self: HasTrainValTestDatasets) -> DataLoader:
        return DataLoader(self.train_ds, batch_size=self._batch_size_train, num_workers=self._num_workers)

    def val_dataloader(self: HasTrainValTestDatasets) -> DataLoader:
        return DataLoader(self.val_ds, batch_size=self._batch_size_eval, num_workers=self._num_workers)

    def test_dataloader(self: HasTrainValTestDatasets) -> DataLoader:
        return DataLoader(self.test_ds, batch_size=self._batch_size_eval, num_workers=self._num_workers)

I was inspired by this Stack Overflow answer: https://stackoverflow.com/a/59128961/8543212

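I want to generalize the idea that each DataLoader is created from train/val/test_ds using batch_size_(train/eval) and num_workers, where the datasets are supplied by the classes that inherit from this generic class.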

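Unfortunately, I cannot get the type hints right. My goal is to use Protocol to force users of the generic class to provide train/val/test_ds. However, I cannot make mypy happy with the example above, because: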

d.py:30: error: "HasTrainValTestDatasets" has no attribute "_batch_size_train"  [attr-defined]
d.py:30: error: "HasTrainValTestDatasets" has no attribute "_num_workers"  [attr-defined]
d.py:35: error: "HasTrainValTestDatasets" has no attribute "_batch_size_eval"  [attr-defined]
d.py:35: error: "HasTrainValTestDatasets" has no attribute "_num_workers"  [attr-defined]
d.py:40: error: "HasTrainValTestDatasets" has no attribute "_batch_size_eval"  [attr-defined]
d.py:40: error: "HasTrainValTestDatasets" has no attribute "_num_workers"  [attr-defined]

Is there a way to tell mypy that self is both a HasTrainValTestDatasets protocol and a GenericTrainValTestDataModule?

For those wondering why I insist on this design: I cannot think of a better way to generalize https://lightning.ai/docs/pytorch/stable/data/datamodule.html for my purposes (but I may be wrong).

Steps to reproduce (assuming my demo is stored in d.py):

virtualenv venv
pip install torch pytorch-lightning
mypy --install-types d.py
Tags: python, mypy, torch, python-typing, pytorch-lightning

1 Answer

I think I have (roughly) found the solution I was looking for:

from typing import Protocol
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor


class HasTrainValTestDatasets(Protocol):
    train_ds: Dataset
    val_ds: Dataset
    test_ds: Dataset


class GenericTrainValTestDataModule(pl.LightningDataModule, HasTrainValTestDatasets):
    def __init__(self, batch_size_train: int, batch_size_eval: int, num_workers: int = 0):
        self._batch_size_train = batch_size_train
        self._batch_size_eval = batch_size_eval
        self._num_workers = num_workers

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_ds, batch_size=self._batch_size_train, num_workers=self._num_workers)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_ds, batch_size=self._batch_size_eval, num_workers=self._num_workers)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.test_ds, batch_size=self._batch_size_eval, num_workers=self._num_workers)


class ConcreteDataModuleOk(GenericTrainValTestDataModule):
    def __init__(self):
        super().__init__(8, 8, 0)
        self.train_ds = FakeData(transform=ToTensor())
        self.val_ds = FakeData(transform=ToTensor())
        self.test_ds = FakeData(transform=ToTensor())


class ConcreteDataModuleBad(GenericTrainValTestDataModule):
    def __init__(self):
        super().__init__(8, 8, 0)

cdm_ok = ConcreteDataModuleOk()

print(cdm_ok.train_ds[0][0].shape)
print(next(iter(cdm_ok.train_dataloader()))[0].shape)

cdm_bad = ConcreteDataModuleBad()  # <-- mypy will complain here

Additionally, install torchvision:

pip install torchvision

When I run:

mypy --ignore-missing-imports d1.py

I get:

d.py:49: error: Cannot instantiate abstract class "ConcreteDataModuleBad" with abstract attributes "test_ds", "train_ds" and "val_ds"  [abstract]
Found 1 error in 1 file (checked 1 source file)

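This is exactly what I wanted to achieve. If I remove Protocol, mypy no longer detects the problem, so I guess the protocol is doing its job. One caveat: for some reason it does not detect when train/val/test_ds has the wrong type (for example, a string instead of a Dataset), but I can live with that.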
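One possible explanation for the caveat above, offered as an assumption rather than a diagnosis: when mypy cannot resolve torch (for example because missing imports are being ignored), Dataset is treated as Any, and any value can be assigned to an Any-typed attribute. The sketch below swaps Dataset for a hypothetical, fully typed stand-in class ToyDataset so that it runs without torch. All class names here are made up for illustration. With a concrete type like this, mypy would be expected to flag the string assignment as incompatible, while Python itself accepts it at runtime.

```python
from typing import List, Protocol


class ToyDataset:
    """Hypothetical stand-in for torch.utils.data.Dataset (not the real class)."""

    def __init__(self, items: List[int]) -> None:
        self.items = items

    def __len__(self) -> int:
        return len(self.items)


class HasTrainDataset(Protocol):
    # Declared but not assigned: subclasses are expected to set it.
    train_ds: ToyDataset


class ToyGenericModule(HasTrainDataset):
    def __init__(self, batch_size: int) -> None:
        self._batch_size = batch_size

    def num_train_batches(self) -> int:
        # Ceiling division: number of batches needed to cover the dataset.
        return -(-len(self.train_ds) // self._batch_size)


class GoodModule(ToyGenericModule):
    def __init__(self) -> None:
        super().__init__(batch_size=2)
        self.train_ds = ToyDataset([1, 2, 3, 4, 5])


class WrongTypeModule(ToyGenericModule):
    def __init__(self) -> None:
        super().__init__(batch_size=2)
        # A type checker should reject this line: str is not a ToyDataset.
        # Annotations are not enforced at runtime, so it still executes.
        self.train_ds = "not a dataset"


print(GoodModule().num_train_batches())
print(type(WrongTypeModule().train_ds).__name__)
```

If this reasoning applies, running mypy in the same environment where torch is installed, so that Dataset resolves to a real type, may be enough to surface the wrong-type assignments.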

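Also, as @Reinderien said, this could probably be done with an abstract class as well (perhaps with the same result?). I will stick with my solution because it looks neat to me.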
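For comparison, here is a minimal sketch of what the abstract-class variant could look like. It is not taken from the question or from @Reinderien's comment. The class names are hypothetical, and a plain list stands in for Dataset so that the example runs without torch or Lightning. The idea is the same: the base class uses train_ds but leaves its definition to subclasses.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence


class AbstractGenericModule(ABC):
    """Hypothetical, Lightning-free stand-in for the generic data module."""

    def __init__(self, batch_size_train: int) -> None:
        self._batch_size_train = batch_size_train

    @property
    @abstractmethod
    def train_ds(self) -> Sequence[int]:
        """Subclasses must provide the training dataset."""

    def train_batches(self) -> List[Sequence[int]]:
        # Slice the dataset into consecutive chunks of batch_size_train items.
        size = self._batch_size_train
        return [self.train_ds[i:i + size] for i in range(0, len(self.train_ds), size)]


class ConcreteOk(AbstractGenericModule):
    def __init__(self) -> None:
        super().__init__(batch_size_train=2)
        self._train_ds = [0, 1, 2, 3, 4]

    @property
    def train_ds(self) -> Sequence[int]:
        return self._train_ds


class ConcreteBad(AbstractGenericModule):
    # Does not implement train_ds, so the class stays abstract.
    def __init__(self) -> None:
        super().__init__(batch_size_train=2)


print(ConcreteOk().train_batches())

try:
    ConcreteBad()
except TypeError as exc:
    print(f"refused at runtime: {exc}")
```

One difference from the Protocol version is that the missing attribute is rejected at runtime as well: instantiating ConcreteBad raises TypeError, whereas annotation-only Protocol members are checked only by the type checker.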
