按组计算每行到所有其他行的平均距离

问题描述 投票:0回答:1

我有一个数据框,其中每一行代表一个观察值,一些列代表不同的值。我想按组计算每行与所有其他行之间的平均余弦相似度。

假设以下数据:

set.seed(123)

df <- data.frame(id = c(1:100),
                 group1 = rep(1:5, 20),
                 group2 = rep(1:10, 10),
                 value1 = runif(100,0,1),
                 value2 = runif(100,0,1),
                 value3 = runif(100,0,1),
                 value4 = runif(100,0,1),
                 value5 = runif(100,0,1),
                 value6 = runif(100,0,1),
                 value7 = runif(100,0,1),
                 value8 = runif(100,0,1),
                 value9 = runif(100,0,1),
                 value10 = runif(100,0,1))


检查数据时,我们得到以下输出:

> round(head(df),3)
  id group1 group2 value1 value2 value3 value4 value5 value6 value7 value8 value9 value10
1  1      1      1  0.288  0.600  0.239  0.785  0.986  0.354  0.237  0.845  0.471   0.924
2  2      2      2  0.788  0.333  0.962  0.009  0.137  0.366  0.686  0.260  0.366   0.543
3  3      3      3  0.409  0.489  0.601  0.779  0.905  0.287  0.226  0.023  0.121   0.852
4  4      4      4  0.883  0.954  0.515  0.729  0.576  0.080  0.318  0.862  0.047   0.584
5  5      5      5  0.940  0.483  0.403  0.630  0.395  0.365  0.174  0.335  0.263   0.668
6  6      1      6  0.046  0.890  0.880  0.481  0.450  0.178  0.801  0.632  0.969   0.511

有了这些数据,我想根据

value1 ... value10
计算每一行的平均距离(使用余弦距离)到属于
group1
group2
中所述的相同组的所有其他行。我希望输出是一个名为
dist
的附加列,其中包含每行到所有其他行的平均距离。

我已经设法实现了一个代码,它使用以下代码计算每个文档与所有文档中所有其他文档的平均欧几里得距离(不是按组并且不使用余弦相似度 [!!!]):

# Function to estimate the average distance
f<-function(x)
{
  return(mean(dist(x)))
}

# Applying function to each row
dist <-apply(as.matrix(df[,3:ncol(df)]),1,f)

# Binding it together with the data
df <- cbind(df, dist)

这给出了这个输出:

> round(head(df),3)
  id group1 group2 value1 value2 value3 value4 value5 value6 value7 value8 value9 value10  dist
1  1      1      1  0.288  0.600  0.239  0.785  0.986  0.354  0.237  0.845  0.471   0.924 0.364
2  2      2      2  0.788  0.333  0.962  0.009  0.137  0.366  0.686  0.260  0.366   0.543 0.572
3  3      3      3  0.409  0.489  0.601  0.779  0.905  0.287  0.226  0.023  0.121   0.852 0.766
4  4      4      4  0.883  0.954  0.515  0.729  0.576  0.080  0.318  0.862  0.047   0.584 0.938
5  5      5      5  0.940  0.483  0.403  0.630  0.395  0.365  0.174  0.335  0.263   0.668 1.035
6  6      1      6  0.046  0.890  0.880  0.481  0.450  0.178  0.801  0.632  0.969   0.511 1.285

但是,如前所述,我还没有找到一种方法来按组执行此操作并使用余弦相似度而不是欧氏距离。我还可以补充一点,实际数据由超过 150 万行组成,这使得计算变得困难。

r distance
1个回答
1
投票

使用 lsa 包,您可以计算行之间的余弦距离(您只需要转置数据)。

你可以这样做:

library(lsa)
library(dplyr)

set.seed(123)
df <- data.frame(id = c(1:100),
                 group1 = rep(1:5, 20),
                 group2 = rep(1:10, 10),
                 value1 = runif(100,0,1),
                 value2 = runif(100,0,1),
                 value3 = runif(100,0,1),
                 value4 = runif(100,0,1),
                 value5 = runif(100,0,1),
                 value6 = runif(100,0,1),
                 value7 = runif(100,0,1),
                 value8 = runif(100,0,1),
                 value9 = runif(100,0,1),
                 value10 = runif(100,0,1))

# cosine distance between rows
cosine_dist <- cosine(t(select(df, starts_with("value"))))

# remove diagonal (you don't want to include the distance with itself in the avarage)
diag(cosine_dist) <- NA

# average distance
df$dist <- colMeans(cosine_dist, na.rm = TRUE)

# expected output
head(df)
#>   id group1 group2    value1    value2    value3      value4    value5
#> 1  1      1      1 0.2875775 0.5999890 0.2387260 0.784575267 0.9860543
#> 2  2      2      2 0.7883051 0.3328235 0.9623589 0.009429905 0.1370675
#> 3  3      3      3 0.4089769 0.4886130 0.6013657 0.779065883 0.9053096
#> 4  4      4      4 0.8830174 0.9544738 0.5150297 0.729390652 0.5763018
#> 5  5      5      5 0.9404673 0.4829024 0.4025733 0.630131853 0.3954489
#> 6  6      1      6 0.0455565 0.8903502 0.8802465 0.480910830 0.4498025
#>       value6    value7     value8     value9   value10      dist
#> 1 0.35360608 0.2372297 0.84493354 0.47068183 0.9236992 0.7746017
#> 2 0.36644144 0.6864904 0.26013247 0.36584547 0.5425984 0.7363783
#> 3 0.28710013 0.2258184 0.02314449 0.12127205 0.8523646 0.7323561
#> 4 0.07997291 0.3184946 0.86239954 0.04699368 0.5835629 0.7587514
#> 5 0.36545427 0.1739838 0.33458796 0.26279630 0.6683236 0.7910447
#> 6 0.17801381 0.8014296 0.63178887 0.96864117 0.5113146 0.7713708

创建于 2023-03-22 与 reprex v2.0.2

© www.soinside.com 2019 - 2024. All rights reserved.