在keras中拟合模型时出错:TypeError:int()参数必须是字符串,类似字节的对象或数字,而不是'tuple'

问题描述 投票:0回答:1

我在此代码中不断收到此错误:它说这与steps_per_epoch变量有关,但我不知道如何处理。我尝试将其简单地转换为int,但是得到了相同的结果

train_generator=generator(context_1,final_target_1,batch_size=32)
steps_per_epoch = context_1.shape[0]/32
if context_1.shape[0] % 32:
    steps_per_epoch += 1
model.fit(train_generator, epochs = 100,
      steps_per_epoch=steps_per_epoch)

这是有用的生成器方法:

def generator(questions,answers, batch_size=32):
num_samples = len(questions)
print('l')
while True: # Loop forever so the generator never terminates
    # Get index to start each batch: [0, batch_size, 2*batch_size, ..., max multiple of batch_size <= num_samples]
    for offset in range(0, num_samples, batch_size):
        # Get the samples you'll use in this batch
        #batch_samples = samples[offset:offset+batch_size]
        question_samples = questions[offset:offset+batch_size]
        answer_samples = answers[offset:offset+batch_size]
        # Initialise X_train and y_train arrays for this batch
        ques_train = []
        ans_train = []
        for i in question_samples:
            ques_train.append(i)
        for i in question_samples:
            ans_train.append(i)
        # Make sure they're numpy arrays (as opposed to lists)
        ques_train = np.array(ques_train)
        ans_train = np.array(ans_train)
        ans_train = sequence.pad_sequences(ans_train, maxlen = 20, dtype = 'int32', padding = 'post', truncating = 'post')
        ques_train = sequence.pad_sequences(ques_train, maxlen = 20, dtype = 'int32', padding = 'post', truncating = 'post')
        # The generator-y part: yield the next training batch
        outs = np.zeros((question_samples.shape[0], maxLen, vocab_size))
        for pos,i in enumerate(answer_samples):
            for pos1,j in enumerate(i):
                if pos1 > 0:
                    outs[pos, pos1-1, j] = 1
            if pos%1000 == 0: print ('{} entries completed'.format(pos))
        yield ([ques_train,ans_train],outs)
python keras tensor
1个回答
0
投票

对于有类似问题的任何人,问题都在屈服线上,其中第一个元素是一个列表,它必须是一个元组。

yield (ques_train,ans_train),outs
热门问题
推荐问题
最新问题