TypeError: `generator` 产生了一个形状元素 (32, 224, 224, 3) 而预期的形状元素 (224, 224, 3)

问题描述 投票:0回答:0

我的生成器代码将数据框(csv 文件)和图像作为输入并生成带有标签的图像。 我的生成器代码是:

class ImageSequence:
    def __init__(self, df, mode,img_size=(224, 224), num_channels=3):
        self.df = df
        self.indices = np.arange(len(df))
        self.batch_size = 32
        self.img_dir = 'dataset'
        self.img_size = tuple(img_size)
        self.num_channels = num_channels
        self.mode = mode
        
        
     def __getitem__(self, idx):
         sample_indices = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
         imgs = []
         genders = []
         ages = []
         for _, row in self.df.iloc[sample_indices].iterrows():
             img = cv2.imread(str(os.path.join(self.img_dir, row["img_paths"])))
             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
             img = cv2.resize(img,  self.img_size)
             img = img.astype(np.float32) / 255.0
            
             imgs.append(img)
             genders.append(row["genders"])
             ages.append(row["ages"])

         return imgs, genders
      
    def __len__(self):
        return len(self.df)
    
    def __call__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(i)
            
            if i == self.__len__()-1:
                self.on_epoch_end()

    def on_epoch_end(self):
        np.random.shuffle(self.indices)

并使用下面的代码调用生成器来训练模型

epochs = 20
batch_size = 32

csv_path = 'asian_dataset.csv'
df = pd.read_csv(str(csv_path))
train, val = train_test_split(df, random_state=42, test_size=0.1)

train_gen = ImageSequence(train, "train")
val_gen = ImageSequence(val, "val")

print(train_gen)

ot = (tf.float32, tf.int32)
os = ((224, 224, 3), ())
train_data = tf.data.Dataset.from_generator(train_gen,output_types=ot,output_shapes=os)
val_data = tf.data.Dataset.from_generator(val_gen,output_types=ot,output_shapes=os)
print(train_data)

train_data = train_data.batch(batch_size)
val_data = val_data.batch(batch_size)

print(train_data)

以上代码执行时出错

TypeError:

generator
产生了一个形状元素 (32, 224, 224, 3) 其中预期形状元素 (224, 224, 3) .

之前没用过

tf.data.Dataset.from_generator
,但由于系统内存不足,我不得不使用它。

python tensorflow generator tensorflow-datasets
© www.soinside.com 2019 - 2024. All rights reserved.