Compute pairwise distances between all rows of a NumPy array


I have an array:

test_arr = np.array([ [1.2, 2.1, 2.3, 4.5],
                      [2.6, 6.4, 5.2, 6.2],
                      [7.2, 6.2, 2.5, 1.7],
                      [8.2, 7.6, 4.2, 7.3] ])

Is it possible to get a pandas DataFrame of this form:

row_id  | row1  | row2          | row3          | row4
row1      0.0     d(row1,row2)    d(row1,row3)    d(row1,row4)
row2      ...     0.0             ...             ...
row3      ...     ...             0.0             ...
row4      ...     ...             ...             0.0

where d(row1, row2) is the Euclidean distance between row1 and row2.
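For reference, the Euclidean distance between two rows x and y with n entries each (here n = 4) is the square root of the summed squared differences. A worked example for row1 and row2 of test_arr:

```latex
d(x, y) = \sqrt{\sum_{i=1}^{n} (x_i - y_i)^2}

d(\mathrm{row1}, \mathrm{row2})
  = \sqrt{(1.2-2.6)^2 + (2.1-6.4)^2 + (2.3-5.2)^2 + (4.5-6.2)^2}
  = \sqrt{1.96 + 18.49 + 8.41 + 2.89}
  = \sqrt{31.75} \approx 5.6347
```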

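My current approach is to first generate a list of all pairs of rows, then compute each distance and assign it to the corresponding cell of the DataFrame. Is there a better, faster way?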
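A rough sketch of the pair-by-pair approach described above. The asker's actual code isn't shown, so this reconstruction (using itertools.combinations and np.linalg.norm) is an assumption. It runs a Python-level loop over all n(n-1)/2 pairs and writes cells one at a time, which is the overhead the vectorized answers below avoid.

```python
import itertools

import numpy as np
import pandas as pd

test_arr = np.array([[1.2, 2.1, 2.3, 4.5],
                     [2.6, 6.4, 5.2, 6.2],
                     [7.2, 6.2, 2.5, 1.7],
                     [8.2, 7.6, 4.2, 7.3]])

names = [f"row{i}" for i in range(1, len(test_arr) + 1)]

# Start from an all-zero frame: the diagonal d(row_i, row_i) is 0.
df_loop = pd.DataFrame(0.0, index=names, columns=names)

# Visit each unordered pair (i < j) once and fill both symmetric cells.
for i, j in itertools.combinations(range(len(test_arr)), 2):
    d = np.linalg.norm(test_arr[i] - test_arr[j])  # Euclidean distance
    df_loop.iloc[i, j] = d
    df_loop.iloc[j, i] = d

print(df_loop)
```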

python pandas numpy distance
Answer (2 votes)
from scipy import spatial
import numpy as np

test_arr = np.array([ [1.2, 2.1, 2.3, 4.5],
                      [2.6, 6.4, 5.2, 6.2],
                      [7.2, 6.2, 2.5, 1.7],
                      [8.2, 7.6, 4.2, 7.3] ])

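# pdist computes each pairwise distance once (condensed form)
dist = spatial.distance.pdist(test_arr)
# squareform expands it into the full symmetric 4 x 4 matrix
spatial.distance.squareform(dist)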

Result:

array([[0.        , 5.63471383, 7.79037868, 9.52365476],
       [5.63471383, 0.        , 6.98140387, 5.91692488],
       [7.79037868, 6.98140387, 0.        , 6.1       ],
       [9.52365476, 5.91692488, 6.1       , 0.        ]])
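A short sketch of what pdist and squareform actually return, reusing the same test_arr. To get the labelled DataFrame the question asks for, the square matrix can then be passed to pd.DataFrame with index and column names, as in the other answers.

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform

test_arr = np.array([[1.2, 2.1, 2.3, 4.5],
                     [2.6, 6.4, 5.2, 6.2],
                     [7.2, 6.2, 2.5, 1.7],
                     [8.2, 7.6, 4.2, 7.3]])

# pdist returns the "condensed" form: one entry per unordered pair (i < j),
# ordered (0,1), (0,2), (0,3), (1,2), (1,3), (2,3).
condensed = pdist(test_arr)        # default metric is 'euclidean'
print(condensed.shape)             # (6,), i.e. n * (n - 1) / 2 for n = 4

# squareform expands it into the symmetric n x n matrix with zeros on the diagonal.
square = squareform(condensed)
print(square.shape)                # (4, 4)

# squareform also works in reverse: square distance matrix -> condensed vector.
print(np.allclose(squareform(square), condensed))  # True
```

Because the matrix is symmetric, the condensed form stores only half of it, so pdist does each distance computation once.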

Answer (2 votes)
import pandas as pd
from sklearn.metrics.pairwise import euclidean_distances

# test_arr as defined in the question
pd.DataFrame(euclidean_distances(test_arr, test_arr))

          0         1         2         3
0  0.000000  5.634714  7.790379  9.523655
1  5.634714  0.000000  6.981404  5.916925
2  7.790379  6.981404  0.000000  6.100000
3  9.523655  5.916925  6.100000  0.000000
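If neither SciPy nor scikit-learn is available, plain NumPy broadcasting gives the same matrix. This technique is swapped in here, not taken from any of the answers. It materializes an intermediate array of n * n * d differences, so it is best suited to small or moderate numbers of rows.

```python
import numpy as np
import pandas as pd

test_arr = np.array([[1.2, 2.1, 2.3, 4.5],
                     [2.6, 6.4, 5.2, 6.2],
                     [7.2, 6.2, 2.5, 1.7],
                     [8.2, 7.6, 4.2, 7.3]])

# Broadcast (n, 1, d) against (1, n, d): diff[i, j] = test_arr[i] - test_arr[j],
# giving an array of shape (n, n, d).
diff = test_arr[:, None, :] - test_arr[None, :, :]

# Euclidean norm along the last axis gives the (n, n) distance matrix.
dist = np.sqrt((diff ** 2).sum(axis=-1))

names = [f"row{i}" for i in range(1, len(test_arr) + 1)]
df = pd.DataFrame(dist, index=names, columns=names)
print(df.round(6))
```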

Answer (0 votes)

Use cdist to compute the pairwise distances, then place the resulting 2D array into a Pandas DataFrame.

import numpy as np
from scipy.spatial.distance import cdist
import pandas as pd

test_arr = np.array([ [1.2, 2.1, 2.3, 4.5],
                      [2.6, 6.4, 5.2, 6.2],
                      [7.2, 6.2, 2.5, 1.7],
                      [8.2, 7.6, 4.2, 7.3] ])

# Use cdist to compute pairwise distances
dist = cdist(test_arr, test_arr)

# Place into a Pandas DataFrame, using
# row1..row4 as both the index and the column names
names = ['row' + str(i) for i in range(1, dist.shape[0]+1)]
df = pd.DataFrame(dist, columns = names, index = names)

print(df)

Output (a Pandas DataFrame):

        row1      row2      row3      row4
row1  0.000000  5.634714  7.790379  9.523655
row2  5.634714  0.000000  6.981404  5.916925
row3  7.790379  6.981404  0.000000  6.100000
row4  9.523655  5.916925  6.100000  0.000000
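Where cdist differs from pdist is that it takes two inputs, XA and XB, and returns a len(XA) x len(XB) matrix, so it also covers distances from one set of points to another. A minimal sketch; the queries array below is hypothetical data made up for illustration.

```python
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist

test_arr = np.array([[1.2, 2.1, 2.3, 4.5],
                     [2.6, 6.4, 5.2, 6.2],
                     [7.2, 6.2, 2.5, 1.7],
                     [8.2, 7.6, 4.2, 7.3]])

# Hypothetical second set of points (illustration only); it must have
# the same number of columns as test_arr.
queries = np.array([[1.0, 2.0, 2.0, 4.0],
                    [8.0, 8.0, 4.0, 7.0]])

# Result has shape (len(queries), len(test_arr)); metric defaults to 'euclidean',
# and other metrics can be chosen, e.g. cdist(queries, test_arr, metric="cityblock").
dist = cdist(queries, test_arr)

df = pd.DataFrame(dist,
                  index=[f"query{i}" for i in range(1, len(queries) + 1)],
                  columns=[f"row{i}" for i in range(1, len(test_arr) + 1)])
print(df.shape)  # (2, 4)
```

For the single-array case in the question, cdist(test_arr, test_arr) computes all n * n entries, including both triangles and the diagonal, while pdist followed by squareform computes each of the n(n-1)/2 distinct distances only once.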