PyTorch 的 DataLoader() 中的 next() 和 iter() 做了什么

问题描述 投票:0回答:1

我有以下代码:

import torch
import numpy as np
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader

# Load dataset
df = pd.read_csv(r'../iris.csv')

# Extract features and target
data = df.drop('target',axis=1).values
labels = df['target'].values

# Create tensor dataset
iris = TensorDataset(torch.FloatTensor(data),torch.LongTensor(labels))

# Create random batches
iris_loader = DataLoader(iris, batch_size=105, shuffle=True)

next(iter(iris_loader))

next()
iter()
在上面的代码中做了什么?我浏览了 PyTorch 的文档,但仍然不太明白
next()
iter()
在这里做什么。谁能帮忙解释一下?非常感谢。

python iterator pytorch next dataloader
1个回答
29
投票

这些是 python 的内置函数,它们用于处理可迭代对象。

基本上

iter()
调用
__iter__()
上的
iris_loader
方法,它返回一个迭代器。
next()
然后调用该迭代器上的
__next__()
方法以获得第一次迭代。再次运行
next()
会得到迭代器的第二项,以此类推

这种逻辑经常发生在“幕后”,例如在运行

for
循环时。它在可迭代对象上调用
__iter__()
方法,然后在返回的迭代器上调用
__next__()
直到它到达迭代器的末尾。然后它会引发一个
stopIteration
并且循环停止。

有关详细信息和一些细微差别,请参阅文档:https://docs.python.org/3/library/functions.html#iter

© www.soinside.com 2019 - 2024. All rights reserved.