我正在尝试仅为在任何 2 列或更多列中具有相同值的节点/(数据框中的记录索引)创建带有边的图形。
我正在做的事情 - 我创建一个包含所有可能的列名组合对的列表,并通过它们搜索重复项,为此我提取索引并创建边。
问题是,对于巨大的数据集(数百万条记录) - 该解决方案太慢并且需要太多内存。
我做什么:
df = pd.DataFrame({
'A': [1, 2, 3, 4, 5],
'B': [1, 1, 1, 1, 2],
'C': [1, 1, 2, 3, 3],
'D': [2, 7, 9, 8, 4]})
A | B | C | D | |
---|---|---|---|---|
0 | 1 | 1 | 1 | 2 |
1 | 2 | 1 | 1 | 7 |
2 | 3 | 1 | 2 | 9 |
3 | 4 | 1 | 3 | 8 |
4 | 5 | 2 | 3 | 4 |
这里,第 0 行和第 1 行在 B 列和 C 列中有 2 个相同的值。
因此,对于节点 0、1、2、3、4,我需要创建边 0-1。其他记录之间最多有 1 个相同字段。
graph = nk.Graph(num_nodes, directed=False, weighted=False)
# Get the indices of all unique pairs
indices = np.triu_indices(len(column_names), k=1)
# Get the unique pairs of column names
unique_pairs = np.column_stack((column_names[indices[0]], column_names[indices[1]]))
for col1, col2 in unique_pairs:
# Filter the dataframe directly
duplicated_rows = df[[col1, col2]].dropna()
duplicated_rows = duplicated_rows[duplicated_rows.duplicated(subset=[col1, col2], keep=False)]
for _, group in duplicated_rows.groupby([col1, col2]):
tb_ids = group.index.tolist()
for i in range(len(tb_ids)):
for j in range(i + 1, len(tb_ids)):
graph.addEdge(tb_ids[i], tb_ids[j])
主要问题 - 如何加速/改进这个解决方案?我正在考虑通过列组合进行并行化 - 但在这种情况下无法弄清楚如何正确地在图中创建边。
感谢任何帮助。
内存问题
您的数百万条记录输入会生成如此多的对,它们无法全部保存在内存中。
您将不得不放弃将所有内容存储在内存中。您需要将数据存储在高度优化的数据库中。我建议使用 SQLite。根据需要将输入数据放入内存中,并在找到数据对时将其存储到数据库中。如果你正确优化 SQLite 的使用,那么对性能的影响将是最小的,并且你不会耗尽内存
性能问题
将对存储到数据库会稍微降低性能。
您将需要优化数据库的使用方式。两个最重要的优化是:
交易分组。最初,将这些对保留在内存中。当对计数达到指定数量时,在一次事务中将它们全部写入数据库。
异步写入。将写入移交给数据库引擎后,不要等待确认写入成功 - 只需继续进行配对搜索即可。
您忘记说明您的性能要求!然而,无论您的要求是什么,我都会假设您需要做出重大改进。
我看到你正在使用Python。这是一种解释性语言,因此性能会很迟缓。切换到编译语言将为您带来显着的性能提升。例如,使用良好编码的 C++ 可以提高高达 50 倍。
算法
SET T number of pairs to writ in one DB transaction
LOOP N over all records
IF N has 2 or more identical values
LOOP M over records N+1 to last
LOOP C over columns
LOOP D over cols C+1 to last
IF N[C] == N[D] == M[C] == M[D]
SAVE M,N to memory pair store
IF memory pair store size >= T
WRITE memory pair store to DB
CLEAR memory pair store
WRITE memory pair store to DB
示例:
以下是这些想法在 C++ 中的实现,可以在一台普通笔记本电脑上在 40 秒内从 100,000 条记录中找到约 6,000,000 对。
#include <string>
#include <fstream>
#include <sstream>
#include <iostream>
#include <vector>
#include <algorithm>
#include <time.h>
#include "sqlite3.h"
#include "cRunWatch.h" // https://ravenspoint.wordpress.com/2010/06/16/timing/
std::vector<std::vector<int>> vdata;
class cPairStorage
{
std::vector<std::pair<int, int>> vPair;
sqlite3 *db;
char *dbErrMsg;
int transactionCount;
public:
cPairStorage();
void add(int r1, int r2)
{
vPair.push_back(std::make_pair(r1, r2));
if (vPair.size() > transactionCount)
writeDB();
}
void writeDB();
int count();
std::pair<int, int> get(int index);
};
cPairStorage pairStore;
cPairStorage::cPairStorage()
: transactionCount(500)
{
int ret = sqlite3_open("pair.db", &db);
if (ret)
throw std::runtime_error("failed to open db");
ret = sqlite3_exec(db,
"CREATE TABLE IF NOT EXISTS pair (r1, r2);",
0, 0, &dbErrMsg);
ret = sqlite3_exec(db,
"DELETE FROM pair;",
0, 0, &dbErrMsg);
ret = sqlite3_exec(db,
"PRAGMA schema.synchronous = 0;",
0, 0, &dbErrMsg);
}
void cPairStorage::writeDB()
{
//raven::set::cRunWatch aWatcher("writeDB");
sqlite3_stmt *stmt;
int ret = sqlite3_prepare_v2(
db,
"INSERT INTO pair VALUES ( ?1, ?2 );",
-1, &stmt, 0);
ret = sqlite3_exec(
db,
"BEGIN TRANSACTION;",
0, 0, &dbErrMsg);
for (auto &p : vPair)
{
ret = sqlite3_bind_int(stmt, 1, p.first);
ret = sqlite3_bind_int(stmt, 2, p.second);
ret = sqlite3_step(stmt);
ret = sqlite3_reset(stmt);
}
ret = sqlite3_exec(
db,
"END TRANSACTION;",
0, 0, &dbErrMsg);
//std::cout << "stored " << vPair.size() << "\n";
vPair.clear();
}
int cPairStorage::count()
{
int ret;
sqlite3_stmt *stmt;
ret = sqlite3_prepare_v2(
db,
"SELECT count(*) FROM pair;",
-1, &stmt, 0);
ret = sqlite3_step(stmt);
int count = sqlite3_column_int(stmt, 0);
ret = sqlite3_reset(stmt);
return count;
}
std::pair<int, int> cPairStorage::get(int index)
{
if (0 > index || index >= count())
throw std::runtime_error("bad pair index");
std::pair<int, int> pair;
int ret;
sqlite3_stmt *stmt;
ret = sqlite3_prepare_v2(
db,
"SELECT * FROM pair WHERE rowid = ?1;",
-1, &stmt, 0);
ret = sqlite3_bind_int(stmt, 1, index);
ret = sqlite3_step(stmt);
pair.first = sqlite3_column_int(stmt, 0);
pair.second = sqlite3_column_int(stmt, 1);
ret = sqlite3_reset(stmt);
return pair;
}
void generateRandom(
int colCount,
int rowCount,
int maxValue)
{
srand(time(NULL));
for (int krow = 0; krow < rowCount; krow++)
{
std::vector<int> vrow;
for (int kcol = 0; kcol < colCount; kcol++)
vrow.push_back(rand() % maxValue + 1);
vdata.push_back(vrow);
}
}
bool isPair(int r1, int r2)
{
auto &v1 = vdata[r1];
auto &v2 = vdata[r2];
for (int kc1 = 0; kc1 < v1.size(); kc1++)
{
for (int kc2 = kc1 + 1; kc2 < v1.size(); kc2++)
{
int tv = v1[kc1];
if (tv != v1[kc2])
continue;
if (tv != v2[kc1])
continue;
if (tv != v2[kc2])
continue;
return true;
}
}
return false;
}
void findPairs()
{
raven::set::cRunWatch aWatcher("findPairs");
int colCount = vdata[0].size();
for (int kr1 = 0; kr1 < vdata.size(); kr1++)
{
bool pairPossible = false;
for (int kc1 = 0; kc1 < colCount; kc1++) {
for (int kc2 = kc1 + 1; kc2 < colCount; kc2++) {
if (vdata[kr1][kc1] == vdata[kr1][kc2])
{
// row has two cols with equal values
// so it can be part of a row pair
pairPossible = true;
break;
}
}
if (!pairPossible)
break;
}
if (!pairPossible)
continue;
for (int kr2 = kr1 + 1; kr2 < vdata.size(); kr2++)
if (isPair(kr1, kr2))
pairStore.add(kr1, kr2);
}
pairStore.writeDB();
}
void display()
{
std::cout << "\nFound " << pairStore.count() << " pairs in " << vdata.size() << " records\n\n";
std::cout << "First 2 pairs found:\n\n";
for (int kp = 0; kp < 2; kp++)
{
auto p = pairStore.get(kp+1);
for (int v : vdata[p.first])
std::cout << v << " ";
std::cout << "\n";
for (int v : vdata[p.second])
std::cout << v << " ";
std::cout << "\n\n";
}
raven::set::cRunWatch::Report();
}
main(int ac, char *argc[])
{
int rowCount = 10;
if (ac == 2)
rowCount = atoi(argc[1]);
raven::set::cRunWatch::Start();
generateRandom(
5, // columns
rowCount, // rows
20); // max value
findPairs();
display();
return 0;
}
测试运行的输出
>matcher --rows 100000 --trans 10000 --seed 571
unit tests passed
Found 6238872 pairs in 100000 records
First 2 pairs found:
4 4 13 18 18
4 4 1 10 7
4 4 13 18 18
4 4 11 3 1
raven::set::cRunWatch code timing profile
Calls Mean (secs) Total Scope
1 40.3924 40.3924 findPairs
完整的应用程序以及 github 存储库中的文档https://github.com/JamesBremner/RecordMatcher
多线程
很简单,可以将要搜索的数据分成两部分,并在各自的线程中搜索每一部分。正如多线程应用程序经常发生的那样,最初的性能结果令人失望。然而,通过调整配置参数,我取得了看似值得的改进。
在一台普通笔记本电脑上 30 秒内在 100,000 条记录中查找约 6,000,000 对。
>matcher --rows 100000 --trans 10000 --seed 571 --multi
unit tests passed
Found 6238872 pairs in 100000 records
First 2 pairs found:
4 4 13 18 18
4 4 1 10 7
4 4 13 18 18
4 4 11 3 1
raven::set::cRunWatch code timing profile
Calls Mean (secs) Total Scope
1 29.6909 29.6909 findPairs
通过使用 joblib 并行化及其新功能稍微改进了我的解决方案 -
return_as="generator"
:
def get_matching_pairs(df_grouped: pd.DataFrame) -> List:
thub_ids = df_grouped.index.values
return list(combinations(thub_ids, 2))
graph = nk.Graph(num_nodes, directed=False, weighted=False)
indices = np.triu_indices(len(column_names), k=1)
unique_pairs = np.column_stack((column_names[indices[0]], column_names[indices[1]]))
for col1, col2 in unique_pairs:
duplicated_rows = df[[col1, col2, 'th_tr_id']].dropna().set_index('th_tr_id')
duplicated_rows = duplicated_rows[duplicated_rows.duplicated(subset=[col1, col2], keep=False)]
duplicated_groups = sorted(duplicated_rows.groupby([col1, col2]), key=lambda x: len(x[1]))
for matching_pairs_list in Parallel(n_jobs=-2, verbose=1, return_as="generator")(
delayed(get_matching_pairs)(group) for name, group in duplicated_groups):
for u, v in matching_pairs_list:
graph.addEdge(u, v)
适用于大型数据集,输出图包含 >700 万个节点和大约 100 亿条边。处理时间适中。
但无论如何,@ravenspoint 解决方案似乎是最好的,尽管对我来说有点硬核(用 C++ 实现)。