什么是x_train.reshape(),它的作用是什么?

问题描述 投票:-2回答:1

使用MNIST数据集

import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist

# MNIST dataset parameters
num_classes = 10 # total classes (0-9 digits)
num_features = 784 # data features (img shape: 28*28)

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Convert to float32
x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)

# Flatten images to 1-D vector of 784 features (28*28)
x_train, x_test = x_train.reshape([-1, num_features]), x_test.reshape([-1, num_features])

# Normalize images value from [0, 255] to [0, 1]
x_train, x_test = x_train / 255., x_test / 255.

15号线 这些守则的内容是:

x_train, x_test = x_train.reshape([-1, num_features]), x_test.reshape([-1, num_features]). 我不明白这些重塑在我们的Dataset中到底做了什么?请解释一下。

python machine-learning keras tensorflow2.0 mnist
1个回答
1
投票

正如第14行所评论的那样,它将图像平坦化为784个特征的1-D向量(28*28),也就是说,它将尺寸为28*28的2-D NumPy数组覆盖为一个长度为784的1-D数组

© www.soinside.com 2019 - 2024. All rights reserved.