Efficient many-to-many embedding comparison


I am trying to recommend the top "articles" to a user, based on embeddings of the "interests" they have.

Each "user" will have 5-10 embeddings associated with their profile, represented as arrays of doubles.

Each "article" will also have 5-10 embeddings associated with it (each embedding representing a distinct topic).

I want to write a PostgreSQL query that returns the top 20 "articles" that best match the user's interests. Since each user has 5-10 embeddings representing their interests, and each article has 5-10 embeddings representing the content it covers, I cannot simply apply an extension like pgvector to solve this.

I wrote an algorithm in SQL that computes the pairwise similarities between the user embeddings and the article embeddings, takes the maximum along each row, and then averages those maxima. It helps to picture a U x T matrix (where U is the number of user embeddings and T the number of article embeddings), with each entry filled by the cosine similarity between the corresponding user embedding and article embedding.
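Written out (including the exp(7·x) scaling that the vectors_similarity function below applies to each pairwise score), the score of an article is:

    score = (1/U) * Σ_{i=1..U} max_{j=1..T} exp(7 · cos(u_i, t_j))

where u_i are the user embeddings and t_j are the article's topic embeddings.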

I wrote helper functions: one to compute the product of the elements of an array, another to compute the cosine similarity of two arrays, and a third, vectors_similarity, which computes the similarity between a set of user vectors and a set of article vectors.

The query itself applies a few joins to gather the required information, filters to articles from the past ten days, excludes articles the user has already "read", and returns the top 20 most similar articles using this method.

Searching across 1,000 articles takes more than 30 seconds. I am not a SQL expert and am struggling to debug this. Below I have posted my SQL query and the output of EXPLAIN ANALYZE.

Is this simply computationally intractable, or am I missing some obvious optimization opportunities?

CREATE OR REPLACE FUNCTION array_product(arr double precision[])
RETURNS double precision AS
$$
DECLARE
    result double precision := 1;
    i integer;
BEGIN
    FOR i IN array_lower(arr, 1) .. array_upper(arr, 1)
    LOOP
        result := result * arr[i];
    END LOOP;
    RETURN result;
END;
$$
LANGUAGE plpgsql;

CREATE OR REPLACE FUNCTION cosine_similarity(a double precision[], b double precision[])
RETURNS double precision AS $$
DECLARE
    dot_product double precision;
    norm_a double precision;
    norm_b double precision;
    a_length int;
    b_length int;
BEGIN
    a_length := array_length(a, 1);
    b_length := array_length(b, 1);

    dot_product := 0;
    norm_a := 0;
    norm_b := 0;
    
    FOR i IN 1..a_length LOOP
        dot_product := dot_product + a[i] * b[i];
        norm_a := norm_a + a[i] * a[i];
        norm_b := norm_b + b[i] * b[i];
    END LOOP;

    norm_a := sqrt(norm_a);
    norm_b := sqrt(norm_b);

    IF norm_a = 0 OR norm_b = 0 THEN
        RETURN 0;
    ELSE
        RETURN dot_product / (norm_a * norm_b);
    END IF;
END;
$$ LANGUAGE plpgsql;

CREATE OR REPLACE FUNCTION vectors_similarity(
    user_vectors FLOAT[][],
    article_vectors FLOAT[][]
) RETURNS FLOAT AS $$
DECLARE
    num_user_vectors INT;
    num_article_vectors INT;
    scores FLOAT[][];
    row_weights FLOAT[];
    row_values FLOAT[];
    col_weights FLOAT[];
    similarity FLOAT;
    article_vector FLOAT[][];
    user_vector FLOAT[][];
    i int;
    j int;
BEGIN
    num_user_vectors := array_length(user_vectors, 1);
    num_article_vectors := array_length(article_vectors, 1);

    scores := ARRAY(SELECT ARRAY(SELECT 0.0 FROM generate_series(1, num_article_vectors)) FROM generate_series(1, num_user_vectors));
    
    i := 1;
    FOREACH user_vector SLICE 1 IN ARRAY user_vectors
    LOOP
        j := 1;
        FOREACH article_vector SLICE 1 IN ARRAY article_vectors
        LOOP
            scores[i][j] := cosine_similarity(user_vector, article_vector);
            scores[i][j] := exp(scores[i][j] * 7);
        j := j+1;
        END LOOP;
    i := i + 1;
    END LOOP;
        
    SELECT 
      AVG(
        (SELECT MAX(row_val) FROM unnest(row_array) AS row_val)
      ) INTO similarity
    FROM 
      (
        SELECT scores[row_index:row_index][:] AS row_array
        FROM generate_series(1, array_length(scores, 1)) AS row_index
      ) AS subquery;
    
    RETURN similarity;
END;
$$ LANGUAGE plpgsql;

EXPLAIN ANALYZE
SELECT
        ART.*,
        vectors_similarity(array_agg(TOPIC.vector), ARRAY[ARRAY[ -0.0026961329858750105,0.004657252691686153, -0.011298391036689281, ...], ARRAY[...]]) AS similatory_score
    FROM
        article ART     
    JOIN
        article_topic ART_TOP ON ART.id = ART_TOP.article_id
    JOIN
        topic TOPIC ON ART_TOP.topic_id = TOPIC.id
    WHERE
        ART.date_published > CURRENT_DATE - INTERVAL '5' DAY
        AND NOT EXISTS (
        SELECT 1
        FROM user_article_read USR_ART_READ
        WHERE USR_ART_READ.article_id = ART.id
        AND USR_ART_READ.profile_id = 1 -- :user_id to be inputted by the user with the actual user_id
        )
    GROUP BY
        ART.id
    ORDER BY
        similatory_score DESC, ART.date_published DESC, ART.id DESC
    LIMIT 20;

The EXPLAIN ANALYZE output is as follows:

"Limit  (cost=945.53..945.55 rows=5 width=518) (actual time=27873.197..27873.227 rows=5 loops=1)"
"  Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type, (scaled_geometric_similarity_vectors(array_agg(topic.vector), '{{-0.0026961329858750105,...}, {...}, ...}'::double precision[])"
"              Group Key: art.id"
"              Batches: 25  Memory Usage: 8524kB  Disk Usage: 3400kB"
"              Buffers: shared hit=14535 read=19, temp read=401 written=750"
"              ->  Hash Join  (cost=395.19..687.79 rows=4491 width=528) (actual time=6.746..20.875 rows=4638 loops=1)"
"                    Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type, topic.vector"
"                    Inner Unique: true"
"                    Hash Cond: (art_top.topic_id = topic.id)"
"                    Buffers: shared hit=289"
"                    ->  Hash Anti Join  (cost=202.53..483.33 rows=4491 width=518) (actual time=3.190..15.589 rows=4638 loops=1)"
"                          Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type, art_top.topic_id"
"                          Hash Cond: (art.id = usr_art_read.article_id)"
"                          Buffers: shared hit=229"
"                          ->  Hash Join  (cost=188.09..412.13 rows=4506 width=518) (actual time=3.106..14.853 rows=4638 loops=1)"
"                                Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type, art_top.topic_id"
"                                Inner Unique: true"
"                                Hash Cond: (art_top.article_id = art.id)"
"                                Buffers: shared hit=224"
"                                ->  Seq Scan on public.article_topic art_top  (cost=0.00..194.67 rows=11167 width=16) (actual time=0.018..7.589 rows=11178 loops=1)"
"                                      Output: art_top.id, art_top.created_timestamp, art_top.article_id, art_top.topic_id"
"                                      Buffers: shared hit=83"
"                                ->  Hash  (cost=177.56..177.56 rows=843 width=510) (actual time=3.005..3.011 rows=818 loops=1)"
"                                      Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type"
"                                      Buckets: 1024  Batches: 1  Memory Usage: 433kB"
"                                      Buffers: shared hit=141"
"                                      ->  Seq Scan on public.article art  (cost=0.00..177.56 rows=843 width=510) (actual time=0.082..1.585 rows=818 loops=1)"
"                                            Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type"
"                                            Filter: (art.date_published > (CURRENT_DATE - '5 days'::interval day))"
"                                            Rows Removed by Filter: 1191"
"                                            Buffers: shared hit=141"
"                          ->  Hash  (cost=14.35..14.35 rows=7 width=8) (actual time=0.052..0.052 rows=0 loops=1)"
"                                Output: usr_art_read.article_id"
"                                Buckets: 1024  Batches: 1  Memory Usage: 8kB"
"                                Buffers: shared hit=5"
"                                ->  Bitmap Heap Scan on public.user_article_read usr_art_read  (cost=4.21..14.35 rows=7 width=8) (actual time=0.051..0.052 rows=0 loops=1)"
"                                      Output: usr_art_read.article_id"
"                                      Recheck Cond: (usr_art_read.profile_id = 1)"
"                                      Buffers: shared hit=5"
"                                      ->  Bitmap Index Scan on user_article_read_profile_id_d4edd4f6  (cost=0.00..4.21 rows=7 width=0) (actual time=0.050..0.050 rows=0 loops=1)"
"                                            Index Cond: (usr_art_read.profile_id = 1)"
"                                            Buffers: shared hit=5"
"                    ->  Hash  (cost=118.96..118.96 rows=5896 width=26) (actual time=3.436..3.440 rows=5918 loops=1)"
"                          Output: topic.vector, topic.id"
"                          Buckets: 8192  Batches: 1  Memory Usage: 434kB"
"                          Buffers: shared hit=60"
"                          ->  Seq Scan on public.topic  (cost=0.00..118.96 rows=5896 width=26) (actual time=0.009..2.100 rows=5918 loops=1)"
"                                Output: topic.vector, topic.id"
"                                Buffers: shared hit=60"
"Planning:"
"  Buffers: shared hit=406 read=7"
"Planning Time: 52.507 ms"
"Execution Time: 27875.522 ms"
1 Answer

Currently, almost all of the cost is incurred in the outer SELECT, and it is not due to the sort. It is the very expensive function vectors_similarity(), which makes many calls to the nested function cosine_similarity(), and that nested function is just as inefficient as the first.

(You also show the function array_product(), but it is not used in the query, so it is just a distraction. It is inefficient as well, by the way.)

This part of the query plan indicates that you need more work_mem:

Memory Usage: 8524kB  Disk Usage: 3400kB

Indeed, your server appears to be running on default settings, otherwise EXPLAIN (ANALYZE, VERBOSE, BUFFERS, SETTINGS) (which you say you used) would report customized settings. That is a no-go for any non-trivial workload.
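As a rough illustration only (the right value depends on available RAM and the number of concurrent sessions, and app_user is a placeholder role name):

SET work_mem = '64MB';                        -- current session only
ALTER ROLE app_user SET work_mem = '64MB';    -- persists for one role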

I said "currently" because things are only going to get worse. You filter on 5 days of your data, but there is no index on article.date_published. Right now almost half of your ~2,000 articles qualify, but that ratio is bound to change drastically. In other words, you need an index on article (date_published).
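For example (the index name is arbitrary):

CREATE INDEX article_date_published_idx ON article (date_published);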

So your course of action is:

1.) Optimize the two "similarity" functions, most importantly cosine_similarity() (a sketch of a set-based rewrite follows after this list).

2.) Reduce the number of rows for which the expensive computation runs, perhaps by pre-filtering rows with cheaper filters.

3.) Optimize the server configuration.

4.) Optimize indexes.
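For point 1.), here is a minimal sketch of a set-based rewrite of cosine_similarity(), assuming the signature stays the same. It replaces the per-element PL/pgSQL loop with a single unnest-and-aggregate pass, and the IMMUTABLE / PARALLEL SAFE labels give the planner more freedom; whether it is actually faster on your data has to be measured:

CREATE OR REPLACE FUNCTION cosine_similarity(a double precision[], b double precision[])
RETURNS double precision AS
$$
    -- Unnest both arrays in parallel and compute the dot product and norms in one pass.
    SELECT CASE
             WHEN sum(x * x) = 0 OR sum(y * y) = 0 THEN 0
             ELSE sum(x * y) / (sqrt(sum(x * x)) * sqrt(sum(y * y)))
           END
    FROM unnest(a, b) AS t(x, y);
$$
LANGUAGE sql IMMUTABLE PARALLEL SAFE;

If installing the pgvector extension turns out to be acceptable after all, storing each topic embedding as a vector and using its cosine-distance operator <=> for the pairwise comparisons is another option worth benchmarking, although it does not remove the many-to-many aggregation itself.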
