我有以下二维数组:
seq_length = 5
x = np.array([[0, 2, 0, 4], [5,6,7,8]])
x_repeated = np.repeat(x, seq_length, axis=1)
[[0 0 0 0 0 2 2 2 2 2 0 0 0 0 0 4 4 4 4 4]
[5 5 5 5 5 6 6 6 6 6 7 7 7 7 7 8 8 8 8 8]]
我想根据x_repeated
改组seq_length
,以便将seq的所有项目一起改组。
例如,可能的随机播放:
[[0 0 0 0 0 6 6 6 6 6 0 0 0 0 0 8 8 8 8 8]
[5 5 5 5 5 2 2 2 2 2 7 7 7 7 7 4 4 4 4 4]]
谢谢
您可以执行以下操作:
import numpy as np
seq_length = 5
x = np.array([[0, 2, 0, 4], [5,6,7,8]])
swaps = np.random.choice([False, True], size=4)
for swap_index, swap in enumerate(swaps):
if swap:
x[0][swap_index], x[1][swap_index] = x[1][swap_index], x[0][swap_index]
x_repeated = np.repeat(x, seq_length, axis=1)