重塑 LSTM 二元分类模型的表格时间序列数据

问题描述 投票:0回答:2

我想为 LSTM 二元分类模型准备数据。我想将我的数据重塑为 (num_samples,time_steps,num_features) 形状。我的训练数据集的形状为 (2487576, 21)。这是我的玩具数据代码。

import pandas as pd
import numpy as np
url="https://gist.githubusercontent.com/JishanAhmed2019/7381979ecafb7efd456421c324d7963a/raw/a50a653119471cd4fe323d7680fe82a161727169/test.csv"
df=pd.read_csv(url,sep="\t")
def generate_train_data(X, y, sequence_length=2, step = 1):
    X_local = []
    y_local = []
    for start in range(0, len(df) - sequence_length, step):
        end = start + sequence_length
        X_local.append(X[start:end])
        y_local.append(y[end-1])
    return np.array(X_local), np.array(y_local)

train_X_sequence, train_y = generate_train_data(df.loc[:, "V1":"V2"].values, df.Class)

输出:

train_X_sequence

       array([
        [[ 30, 100],
        [ 40, 200]],

       [[ 40, 200],
        [ 50, 300]],

       [[ 50, 300],
        [ 60, 400]],

       [[ 60, 400],
        [ 70, 500]],

       [[ 70, 500],
        [ 80, 600]],

       [[ 80, 600],
        [ 90, 700]]])
train_y

array([0, 1, 0, 0, 0, 1])

我看到最后一行没有出现在重塑的数据中。我在这里缺少什么吗?我正在使用 tensorflow 框架中的 LSTM。

pandas numpy tensorflow lstm tensorflow2.0
2个回答
0
投票

Image 1

由于 for 循环在 generate_train_data() 函数中定义的方式,似乎最后一行没有出现在重塑数据中。 Python 中的 range() 函数排除了结束值,因此当开始到达 len(df) - sequence_length 时,结束变得等于 len(df) - 1,这是你的 DataFrame 的最后一行。但是,由于 end 不包含在范围内,因此循环不会迭代长度为 sequence_length 的最后一个序列,因此最后一行不包含在重塑数据中。

要解决此问题,您可以通过将循环条件从 range(0, len(df) - sequence_length, step) 更改为 range(0, len(df) - sequence_length + 1,步骤):(图像 1) 通过此修改,重塑的 train_X_sequence 应包括最后一行:(图 2) 正确调整数据形状后,您可以继续将其提供给 TensorFlow 中的 LSTM 模型


0
投票

您需要在

len(df) - sequence_length + 1
循环中使用条件
for

for start in range(0, len(df) - sequence_length + 1, step):

证明它的简单步骤:

  1. 你想要
    y[end-1]
    可以访问索引为
    df['Class']
    len(df)-1
    的最后一行,所以
    max(end)
    应该是
    len(df)
  2. end
    等于
    start + sequence_length
    ,所以这意味着
    max(start)
    应该是
    len(df) - sequence_length
  3. 因此,
    start
    循环中的
    for
    应该是
    [0, len(df) - sequence_length + 1)
    ,其中左括号表示包含值,右括号表示不包含值。
© www.soinside.com 2019 - 2024. All rights reserved.