我正在观看Udacity中关于深度学习的在线课程,这个概念是对not_Mnist数据集的一个简单分类。所有内容都解释得太好但我对给出的部分代码有点困惑。如果你有时间我会很感激给我一个手!例如,我们有一个'notMNIST_large.tar.gz' file
。所以首先我们删除.tar.gz
,根是root = notMNIST_large
。之后我们检查是否已经有一个具有此名称的目录。如果不是我们从'notMNIST_large.tar.gz' file
中提取子文件夹,这就是我有点困惑的地方......
num_classes = 10
np.random.seed(133)
def maybe_extract(filename, force=False):
root = os.path.splitext(os.path.splitext(filename)[0])[0] # remove .tar.gz
if os.path.isdir(root) and not force:
# You may override by setting force=True.
print('%s already present - Skipping extraction of %s.' % (root, filename))
else:
print('Extracting data for %s. This may take a while. Please wait.' % root)
tar = tarfile.open(filename)
sys.stdout.flush()
tar.extractall(data_root)
tar.close()
data_folders = [
os.path.join(root, d) for d in sorted(os.listdir(root))
if os.path.isdir(os.path.join(root, d))]
if len(data_folders) != num_classes:
raise Exception(
'Expected %d folders, one per class. Found %d instead.' % (
num_classes, len(data_folders)))
print(data_folders)
return data_folders
train_folders = maybe_extract(train_filename)
test_folders = maybe_extract(test_filename)
所以我想尽可能解释这部分
data_folders = [
os.path.join(root, d) for d in sorted(os.listdir(root))
if os.path.isdir(os.path.join(root, d))]
if len(data_folders) != num_classes:
raise Exception(
'Expected %d folders, one per class. Found %d instead.' % (
num_classes, len(data_folders)))
它收集子目录列表并检查是否存在预期的编号。
data_folders = [thing(d) for d in something() if predicate(d)]
是一个列表理解,循环something()
的结果,并收集predicate
是True
的项目。它将thing()
应用于这些条目,并在data_folders
中收集结果列表。
这里,something
是当前目录中文件的列表,predicate
检查该项是否是目录(而不是例如常规文件); thing
是os.path.join(root,d)
,即我们在提取的条目前面添加root
目录。
因此,在这种情况下,代码检查子目录的数量是否与类的数量相同(可能每个子目录包含一个类)。