我目前正在使用一个数据集类别加载数据。在数据集中,我分别拆分了训练,测试和验证数据。例如:
class Data():
def __init__(self):
self.load()
def load(self):
with open(file=file_name, mode='r') as f:
self.data = f.readlines()
self.train = self.data[:checkpoint]
self.valid = self.data[checkpoint:halfway]
self.test = self.data[halfway:]
出于可读性考虑,省略了许多细节。基本上,我读取了一个大数据集并手动进行拆分。
我的问题是我的火车,有效数据和测试数据的长度都不同时,如何覆盖__len__
方法?
之所以这样做,是因为我想将拆分数据保留在一个类中,并且我还想为每个类创建单独的Dataloader,所以类似:
def __len__(self):
return len(self.train)
不适用于self.test
和self.valid
。
也许我从根本上误解了Dataloader,但是我应该如何解决这个问题?预先感谢。
我认为获取每个拆分的长度的最合适方法是简单地使用:
# Number of training points
len(self.train)
# Number of testing points
len(self.test)
# Number of validation points
len(self.valid)
或者,如果要引用对象定义之外的长度:
data = Data()
print(len(data.train))
print(len(data.test))
print(len(data.valid))
__len__
允许您以希望len()
表现的方式来实现它们,以计算对象的元素。因此,我将按以下方式实现它,并使用上述调用来获取每个拆分的计数:
def __len__(self):
return len(self.train) + len(self.test) + len(self.valid)