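I keep getting this error with the code below. The message says it is related to the steps_per_epoch variable, but I don't know how to deal with it. I tried simply converting it to int, but I got the same result.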
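train_generator = generator(context_1, final_target_1, batch_size=32)
# Floor division keeps steps_per_epoch an int; the check below adds one
# extra step for a final partial batch, i.e. ceil(num_samples / 32).
steps_per_epoch = context_1.shape[0] // 32
if context_1.shape[0] % 32:
    steps_per_epoch += 1
model.fit(train_generator, epochs=100,
          steps_per_epoch=steps_per_epoch)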
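A small standalone sketch of why `/` is a poor fit for the step count: in Python 3 it returns a float, and the `+ 1` for the partial batch then overshoots instead of rounding up. The helper name `steps_for` is hypothetical and only illustrates the arithmetic.

```python
import math


def steps_for(num_samples: int, batch_size: int) -> int:
    """Number of generator batches needed to cover num_samples once.

    Floor division plus one extra step for a partial final batch,
    which equals math.ceil(num_samples / batch_size).
    """
    steps = num_samples // batch_size
    if num_samples % batch_size:
        steps += 1  # one more step for the leftover, smaller batch
    return steps


# With true division the result is a float, and adding 1 overshoots:
num_samples, batch_size = 1000, 32
print(num_samples / batch_size + 1)         # 32.25 (float, not a step count)
print(steps_for(num_samples, batch_size))   # 32
print(math.ceil(num_samples / batch_size))  # 32
```

`math.ceil(num_samples / batch_size)` is an equivalent one-liner; the explicit version just mirrors the structure of the original code.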
Here is the generator function being used:
import numpy as np
from keras.preprocessing import sequence  # assumed import; the post does not show it

# maxLen and vocab_size are globals defined elsewhere in the original script.
def generator(questions, answers, batch_size=32):
    num_samples = len(questions)
    print('l')
    while True:  # Loop forever so the generator never terminates
        # Get index to start each batch: [0, batch_size, 2*batch_size, ..., max multiple of batch_size <= num_samples]
        for offset in range(0, num_samples, batch_size):
            # Get the samples you'll use in this batch
            #batch_samples = samples[offset:offset+batch_size]
            question_samples = questions[offset:offset+batch_size]
            answer_samples = answers[offset:offset+batch_size]
            # Initialise X_train and y_train arrays for this batch
            ques_train = []
            ans_train = []
            for i in question_samples:
                ques_train.append(i)
            for i in answer_samples:  # was question_samples, which copied the questions into ans_train
                ans_train.append(i)
            # Make sure they're numpy arrays (as opposed to lists)
            ques_train = np.array(ques_train)
            ans_train = np.array(ans_train)
            ans_train = sequence.pad_sequences(ans_train, maxlen=20, dtype='int32', padding='post', truncating='post')
            ques_train = sequence.pad_sequences(ques_train, maxlen=20, dtype='int32', padding='post', truncating='post')
            # One-hot decoder targets, shifted left by one step (token at pos1 goes to slot pos1-1)
            outs = np.zeros((question_samples.shape[0], maxLen, vocab_size))
            for pos, i in enumerate(answer_samples):
                for pos1, j in enumerate(i):
                    if pos1 > 0:
                        outs[pos, pos1-1, j] = 1
                if pos % 1000 == 0: print('{} entries completed'.format(pos))
            # The generator-y part: yield the next training batch
            yield ([ques_train, ans_train], outs)
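The nested loop that fills `outs` builds one-hot decoder targets shifted left by one time step: the token at position `pos1 > 0` of each answer is written to slot `pos1 - 1`, so at every step the model is asked to predict the next token. A minimal NumPy-only sketch of that idea follows; the function name `shifted_one_hot`, the toy token ids, and the truncation to `max_len` steps are assumptions added for illustration (the original loop has no truncation and would raise an `IndexError` on answers longer than `maxLen + 1` tokens).

```python
import numpy as np


def shifted_one_hot(answers, max_len, vocab_size):
    """One-hot decoder targets, shifted left by one time step.

    Token j at position pos1 > 0 of answer pos is written to
    outs[pos, pos1 - 1, j]. Tokens past max_len + 1 are dropped
    (a truncation rule assumed for this sketch).
    """
    outs = np.zeros((len(answers), max_len, vocab_size), dtype="float32")
    for pos, seq in enumerate(answers):
        # Skip the first token and start counting positions at 1.
        for pos1, token in enumerate(seq[1 : max_len + 1], start=1):
            outs[pos, pos1 - 1, token] = 1.0
    return outs


# Toy token ids; treating 1 as a start token is an assumption.
answers = [[1, 4, 2], [1, 3]]
targets = shifted_one_hot(answers, max_len=4, vocab_size=5)
print(targets.argmax(-1))
```

Rows that received no token stay all-zero, so their `argmax` is 0; with padding id 0 that happens to coincide with the padding value.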
For anyone with a similar problem: the issue was in the yield line. The first element of the yielded pair was a list, but it must be a tuple:
yield (ques_train,ans_train),outs
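To make the fixed structure concrete, here is a Keras-free sketch of a generator that yields `((question_batch, answer_batch), target_batch)`, with the two model inputs grouped in a tuple as in the fix above. The name `toy_batches` and the toy arrays are hypothetical; the sketch only shows the shape of what gets yielded, not how Keras consumes it internally.

```python
import numpy as np


def toy_batches(questions, answers, targets, batch_size):
    """Yield ((question_batch, answer_batch), target_batch) forever.

    The two model inputs are grouped in a tuple, matching the fix;
    array contents and shapes are toy assumptions.
    """
    n = len(questions)
    while True:  # loop forever, like the original generator
        for offset in range(0, n, batch_size):
            sl = slice(offset, offset + batch_size)
            yield (questions[sl], answers[sl]), targets[sl]


q = np.arange(10).reshape(5, 2)
a = np.arange(10, 20).reshape(5, 2)
y = np.arange(5)
gen = toy_batches(q, a, y, batch_size=2)

# Unpack one batch exactly as it is structured: (inputs tuple, targets).
(x_q, x_a), y_batch = next(gen)
print(x_q.shape, x_a.shape, y_batch.shape)  # (2, 2) (2, 2) (2,)
```

With 5 samples and `batch_size=2`, the third batch holds a single sample and the fourth wraps around to the start, which is why `steps_per_epoch` must round up.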