I am trying to recommend the top "articles" to users, based on embeddings of the "interests" they hold.
Each "user" will have 5-10 embeddings associated with their profile, represented as arrays of double precision.
Each "article" will also have 5-10 embeddings associated with it (each embedding representing a distinct topic).
I want to write a PostgreSQL query that returns the top 20 "articles" best matching the user's interests. Since each user has 5-10 embeddings representing their interests, and each article has 5-10 embeddings representing the content it covers, I cannot simply apply an extension like pgvector to solve this.
I wrote an algorithm in SQL that computes the pairwise similarities between the user embeddings and the article embeddings, takes the maximum along each row, and then averages those values. It helps to picture a UxT matrix (where U is the number of user embeddings and T the number of article embeddings), with each entry filled by the cosine similarity between the corresponding user embedding and article embedding.
I wrote helper functions: one to compute the product of an array's elements, another to compute cosine similarity, and a third, "vectors_similarity", which computes the similarity between a set of user vectors and a set of article vectors.
The query itself applies a few joins to fetch the required information, filters down to articles from the past ten days that the user has not already "read", and uses this method to return the top 20 most similar articles.
Searching across 1,000 articles takes over 30 seconds. I am no SQL expert and am struggling to debug it. Below I have posted my SQL query and the output of EXPLAIN ANALYZE.
Is this simply computationally intractable, or am I missing some obvious optimization opportunity?
CREATE OR REPLACE FUNCTION array_product(arr double precision[])
RETURNS double precision AS
$$
DECLARE
    result double precision := 1;
    i integer;
BEGIN
    FOR i IN array_lower(arr, 1) .. array_upper(arr, 1)
    LOOP
        result := result * arr[i];
    END LOOP;
    RETURN result;
END;
$$
LANGUAGE plpgsql;
CREATE OR REPLACE FUNCTION cosine_similarity(a double precision[], b double precision[])
RETURNS double precision AS $$
DECLARE
    dot_product double precision;
    norm_a double precision;
    norm_b double precision;
    a_length int;
    b_length int;
BEGIN
    a_length := array_length(a, 1);
    b_length := array_length(b, 1);
    dot_product := 0;
    norm_a := 0;
    norm_b := 0;
    FOR i IN 1..a_length LOOP
        dot_product := dot_product + a[i] * b[i];
        norm_a := norm_a + a[i] * a[i];
        norm_b := norm_b + b[i] * b[i];
    END LOOP;
    norm_a := sqrt(norm_a);
    norm_b := sqrt(norm_b);
    IF norm_a = 0 OR norm_b = 0 THEN
        RETURN 0;
    ELSE
        RETURN dot_product / (norm_a * norm_b);
    END IF;
END;
$$ LANGUAGE plpgsql;
CREATE OR REPLACE FUNCTION vectors_similarity(
    user_vectors FLOAT[][],
    article_vectors FLOAT[][]
) RETURNS FLOAT AS $$
DECLARE
    num_user_vectors INT;
    num_article_vectors INT;
    scores FLOAT[][];
    row_weights FLOAT[];
    row_values FLOAT[];
    col_weights FLOAT[];
    similarity FLOAT;
    article_vector FLOAT[][];
    user_vector FLOAT[][];
    i int;
    j int;
BEGIN
    num_user_vectors := array_length(user_vectors, 1);
    num_article_vectors := array_length(article_vectors, 1);
    scores := ARRAY(SELECT ARRAY(SELECT 0.0 FROM generate_series(1, num_article_vectors)) FROM generate_series(1, num_user_vectors));
    i := 1;
    FOREACH user_vector SLICE 1 IN ARRAY user_vectors
    LOOP
        j := 1;
        FOREACH article_vector SLICE 1 IN ARRAY article_vectors
        LOOP
            scores[i][j] := cosine_similarity(user_vector, article_vector);
            scores[i][j] := exp(scores[i][j] * 7);
            j := j + 1;
        END LOOP;
        i := i + 1;
    END LOOP;
    SELECT
        AVG(
            (SELECT MAX(row_val) FROM unnest(row_array) AS row_val)
        ) INTO similarity
    FROM
    (
        SELECT scores[row_index][:] AS row_array
        FROM generate_series(1, array_length(scores, 1)) AS row_index
    ) AS subquery;
    RETURN similarity;
END;
$$ LANGUAGE plpgsql;
EXPLAIN ANALYZE
SELECT
    ART.*,
    vectors_similarity(array_agg(TOPIC.vector), ARRAY[ARRAY[ -0.0026961329858750105,0.004657252691686153, -0.011298391036689281, ...], ARRAY[...]]) AS similatory_score
FROM
    article ART
JOIN
    article_topic ART_TOP ON ART.id = ART_TOP.article_id
JOIN
    topic TOPIC ON ART_TOP.topic_id = TOPIC.id
WHERE
    ART.date_published > CURRENT_DATE - INTERVAL '5' DAY
    AND NOT EXISTS (
        SELECT 1
        FROM user_article_read USR_ART_READ
        WHERE USR_ART_READ.article_id = ART.id
        AND USR_ART_READ.profile_id = 1 -- :user_id to be inputted by the user with the actual user_id
    )
GROUP BY
    ART.id
ORDER BY
    similatory_score DESC, ART.date_published DESC, ART.id DESC
LIMIT 20;
The EXPLAIN ANALYZE output is as follows:
"Limit (cost=945.53..945.55 rows=5 width=518) (actual time=27873.197..27873.227 rows=5 loops=1)"
" Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type, (scaled_geometric_similarity_vectors(array_agg(topic.vector), '{{-0.0026961329858750105,...}, {...}, ...}'::double precision[])"
" Group Key: art.id"
" Batches: 25 Memory Usage: 8524kB Disk Usage: 3400kB"
" Buffers: shared hit=14535 read=19, temp read=401 written=750"
" -> Hash Join (cost=395.19..687.79 rows=4491 width=528) (actual time=6.746..20.875 rows=4638 loops=1)"
" Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type, topic.vector"
" Inner Unique: true"
" Hash Cond: (art_top.topic_id = topic.id)"
" Buffers: shared hit=289"
" -> Hash Anti Join (cost=202.53..483.33 rows=4491 width=518) (actual time=3.190..15.589 rows=4638 loops=1)"
" Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type, art_top.topic_id"
" Hash Cond: (art.id = usr_art_read.article_id)"
" Buffers: shared hit=229"
" -> Hash Join (cost=188.09..412.13 rows=4506 width=518) (actual time=3.106..14.853 rows=4638 loops=1)"
" Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type, art_top.topic_id"
" Inner Unique: true"
" Hash Cond: (art_top.article_id = art.id)"
" Buffers: shared hit=224"
" -> Seq Scan on public.article_topic art_top (cost=0.00..194.67 rows=11167 width=16) (actual time=0.018..7.589 rows=11178 loops=1)"
" Output: art_top.id, art_top.created_timestamp, art_top.article_id, art_top.topic_id"
" Buffers: shared hit=83"
" -> Hash (cost=177.56..177.56 rows=843 width=510) (actual time=3.005..3.011 rows=818 loops=1)"
" Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type"
" Buckets: 1024 Batches: 1 Memory Usage: 433kB"
" Buffers: shared hit=141"
" -> Seq Scan on public.article art (cost=0.00..177.56 rows=843 width=510) (actual time=0.082..1.585 rows=818 loops=1)"
" Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type"
" Filter: (art.date_published > (CURRENT_DATE - '5 days'::interval day))"
" Rows Removed by Filter: 1191"
" Buffers: shared hit=141"
" -> Hash (cost=14.35..14.35 rows=7 width=8) (actual time=0.052..0.052 rows=0 loops=1)"
" Output: usr_art_read.article_id"
" Buckets: 1024 Batches: 1 Memory Usage: 8kB"
" Buffers: shared hit=5"
" -> Bitmap Heap Scan on public.user_article_read usr_art_read (cost=4.21..14.35 rows=7 width=8) (actual time=0.051..0.052 rows=0 loops=1)"
" Output: usr_art_read.article_id"
" Recheck Cond: (usr_art_read.profile_id = 1)"
" Buffers: shared hit=5"
" -> Bitmap Index Scan on user_article_read_profile_id_d4edd4f6 (cost=0.00..4.21 rows=7 width=0) (actual time=0.050..0.050 rows=0 loops=1)"
" Index Cond: (usr_art_read.profile_id = 1)"
" Buffers: shared hit=5"
" -> Hash (cost=118.96..118.96 rows=5896 width=26) (actual time=3.436..3.440 rows=5918 loops=1)"
" Output: topic.vector, topic.id"
" Buckets: 8192 Batches: 1 Memory Usage: 434kB"
" Buffers: shared hit=60"
" -> Seq Scan on public.topic (cost=0.00..118.96 rows=5896 width=26) (actual time=0.009..2.100 rows=5918 loops=1)"
" Output: topic.vector, topic.id"
" Buffers: shared hit=60"
"Planning:"
" Buffers: shared hit=406 read=7"
"Planning Time: 52.507 ms"
"Execution Time: 27875.522 ms"
Currently, almost all of the cost accrues in the outer SELECT, and it is not due to the sort. It is your very expensive function vectors_similarity(), which calls the nested function cosine_similarity() many times, and that nested function is just as inefficient as the outer one.
(You also show the function array_product(), but it is not used in the query, so it is just a distraction. Inefficient as well, by the way.)
work_mem:
Memory Usage: 8524kB Disk Usage: 3400kB
Indeed, your server seems to be running with default settings, or else EXPLAIN (ANALYZE, VERBOSE, BUFFERS, SETTINGS) (which you claim to have used) would report customized settings. That is a no-go for any non-trivial workload.
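If you want to experiment before touching postgresql.conf, you can raise work_mem for your session only and re-run the plan. The 64MB below is purely an illustrative figure, not a recommendation; size it to your RAM and expected concurrency:

SET work_mem = '64MB';  -- session-only experiment, not a tuning recommendation

EXPLAIN (ANALYZE, VERBOSE, BUFFERS, SETTINGS)
SELECT ...;             -- your query from above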
I said "currently" because this is bound to get worse. You filter on the most recent 5 days of your data, but there is no index on article.date_published. Currently, almost half of your ~2,000 articles qualify, but that ratio is bound to change dramatically. In other words, you need an index on article (date_published).
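For example (the index name is arbitrary):

CREATE INDEX article_date_published_idx ON article (date_published);

Once the share of qualifying rows drops, such an index lets the date filter use an index scan instead of the sequential scan on article shown in your plan.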
So your course of action is:
1.) Optimize the two "similarity" functions, most importantly cosine_similarity(). (A sketch follows after this list.)
2.) Reduce the number of rows for which the expensive computation is done, perhaps by pre-filtering rows with cheaper filters.
3.) Optimize the server configuration.
4.) Optimize indexes.
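As a starting point for 1.), here is a minimal sketch of cosine_similarity rewritten as a set-based SQL function. The name cosine_similarity_sql is mine, and it assumes both arrays are non-null and of equal length; a plain SQL function avoids the per-element PL/pgSQL interpretation overhead and can be declared IMMUTABLE and PARALLEL SAFE:

CREATE OR REPLACE FUNCTION cosine_similarity_sql(a double precision[], b double precision[])
RETURNS double precision
LANGUAGE sql IMMUTABLE PARALLEL SAFE AS
$$
    -- Pair the two arrays element by element and aggregate in a single pass.
    SELECT CASE
               WHEN sqrt(sum(x * x)) * sqrt(sum(y * y)) = 0 THEN 0
               ELSE sum(x * y) / (sqrt(sum(x * x)) * sqrt(sum(y * y)))
           END
    FROM unnest(a, b) AS t(x, y);
$$;

vectors_similarity can be expressed the same way, without building and re-reading a scores matrix. This sketch (again, the name is mine and it is untested against your data) reproduces your max-then-average logic, including the exp(... * 7) scaling, over one-row slices of the two 2-D arrays:

CREATE OR REPLACE FUNCTION vectors_similarity_sql(user_vectors double precision[][], article_vectors double precision[][])
RETURNS double precision
LANGUAGE sql IMMUTABLE PARALLEL SAFE AS
$$
    -- For each user vector, keep its best-matching article vector, then average those maxima.
    SELECT avg(best)
    FROM (
        SELECT max(exp(7 * cosine_similarity_sql(user_vectors[u:u][:],
                                                 article_vectors[a:a][:]))) AS best
        FROM generate_subscripts(user_vectors, 1)    AS u,
             generate_subscripts(article_vectors, 1) AS a
        GROUP BY u
    ) AS per_user_vector;
$$;

Measure with EXPLAIN ANALYZE before and after; the function rewrite addresses 1.), while 2.) through 4.) still apply on top of it.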