为多维数据集找到正确的插值方法

问题描述 投票:0回答:1

我尝试过 scipy.interpolate 中的“LinearNDInterpolator”。我唯一需要验证的数据是 Matlab 中使用的另一种方法(interpn)的另一组插值点,但目标不是与 Matlab 方法的结果匹配。我也没有超过 4 个维度的数据集示例可供测试。 目标是找到一个受以下约束的插值器:

  • python 3.9
  • 它必须适用于 2D 数据集和更高维度的数据集(最多 10 维)

问题:鉴于上述限制,您将使用哪种方法来插值“输出”?

下面您将看到一个数据集的示例。这个似乎是结构化的(是吗?可以将其视为规则网格吗?),但其他数据集可能并非如此。

所有要点:

>>> df_out[['x', 'y', 'z', 'output']].to_dict()
{'x': {0: -8.0, 1: 8.0, 2: 10.0, 3: 20.0, 4: 36.0, 5: -8.0, 6: 8.0, 7: 10.0, 8: 20.0, 9: 36.0, 10: -8.0, 11: 8.0, 12: 10.0, 13: 20.0, 14: 36.0, 15: -8.0, 16: 8.0, 17: 10.0, 18: 20.0, 19: 36.0, 20: -8.0, 21: 8.0, 22: 10.0, 23: 20.0, 24: 36.0, 25: -8.0, 26: 8.0, 27: 10.0, 28: 20.0, 29: 36.0, 30: -8.0, 31: 8.0, 32: 10.0, 33: 20.0, 34: 36.0, 35: -8.0, 36: 8.0, 37: 10.0, 38: 20.0, 39: 36.0, 40: -8.0, 41: 8.0, 42: 10.0, 43: 20.0, 44: 36.0, 45: -8.0, 46: 8.0, 47: 10.0, 48: 20.0, 49: 36.0, 50: -8.0, 51: 8.0, 52: 10.0, 53: 20.0, 54: 36.0, 55: -8.0, 56: 8.0, 57: 10.0, 58: 20.0, 59: 36.0}, 'y': {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 0.0, 5: 0.4, 6: 0.4, 7: 0.4, 8: 0.4, 9: 0.4, 10: 0.8, 11: 0.8, 12: 0.8, 13: 0.8, 14: 0.8, 15: 0.0, 16: 0.0, 17: 0.0, 18: 0.0, 19: 0.0, 20: 0.4, 21: 0.4, 22: 0.4, 23: 0.4, 24: 0.4, 25: 0.8, 26: 0.8, 27: 0.8, 28: 0.8, 29: 0.8, 30: 0.0, 31: 0.0, 32: 0.0, 33: 0.0, 34: 0.0, 35: 0.4, 36: 0.4, 37: 0.4, 38: 0.4, 39: 0.4, 40: 0.8, 41: 0.8, 42: 0.8, 43: 0.8, 44: 0.8, 45: 0.0, 46: 0.0, 47: 0.0, 48: 0.0, 49: 0.0, 50: 0.4, 51: 0.4, 52: 0.4, 53: 0.4, 54: 0.4, 55: 0.8, 56: 0.8, 57: 0.8, 58: 0.8, 59: 0.8}, 'z': {0: -25.0, 1: -25.0, 2: -25.0, 3: -25.0, 4: -25.0, 5: -25.0, 6: -25.0, 7: -25.0, 8: -25.0, 9: -25.0, 10: -25.0, 11: -25.0, 12: -25.0, 13: -25.0, 14: -25.0, 15: -10.0, 16: -10.0, 17: -10.0, 18: -10.0, 19: -10.0, 20: -10.0, 21: -10.0, 22: -10.0, 23: -10.0, 24: -10.0, 25: -10.0, 26: -10.0, 27: -10.0, 28: -10.0, 29: -10.0, 30: 0.0, 31: 0.0, 32: 0.0, 33: 0.0, 34: 0.0, 35: 0.0, 36: 0.0, 37: 0.0, 38: 0.0, 39: 0.0, 40: 0.0, 41: 0.0, 42: 0.0, 43: 0.0, 44: 0.0, 45: 15.0, 46: 15.0, 47: 15.0, 48: 15.0, 49: 15.0, 50: 15.0, 51: 15.0, 52: 15.0, 53: 15.0, 54: 15.0, 55: 15.0, 56: 15.0, 57: 15.0, 58: 15.0, 59: 15.0}, 'output': {0: -0.02, 1: -0.02, 2: -0.02, 3: -0.1, 4: -0.11, 5: -0.02, 6: -0.02, 7: -0.02, 8: -0.1, 9: -0.11, 10: -0.02, 11: -0.02, 12: -0.02, 13: -0.1, 14: -0.11, 15: -0.018, 16: -0.018, 17: -0.018, 18: -0.1, 19: -0.1, 20: -0.018, 21: -0.018, 22: -0.018, 23: -0.1, 24: -0.1, 25: -0.018, 26: -0.018, 27: -0.018, 28: -0.1, 29: -0.1, 30: -0.014, 31: -0.014, 32: -0.014, 33: -0.09, 34: -0.09, 35: -0.014, 36: 0.014, 37: -0.014, 38: -0.09, 39: -0.09, 40: -0.014, 41: -0.0184, 42: -0.014, 43: -0.09, 44: -0.09, 45: -0.014, 46: -0.014, 47: -0.014, 48: -0.09, 49: -0.09, 50: -0.014, 51: 0.014, 52: -0.014, 53: -0.09, 54: -0.09, 55: -0.014, 56: -0.0184, 57: -0.014, 58: -0.09, 59: -0.09}}

这是我用来插入“输出”并生成绘图的代码:

import matplotlib.pyplot as plt
from scipy.interpolate import LinearNDInterpolator

# Load the data to a dataframe "df_out"
test_nd = LinearNDInterpolator(points=df_out[['x', 'y', 'z']], values=df_out.output)
# Interpolate output
print(f'output = {test_nd(13, 0.5, -20)}')

df_out['area'] = 50
fig = plt.figure()
cm = plt.cm.get_cmap('RdYlBu')
ax = fig.add_subplot(projection='3d')
sp = ax.scatter(df_out.x, df_out.y, df_out.z, s=df_out.area, c=df_out.output, cmap=cm)
ax.set_xlabel('x data')
ax.set_ylabel('y data')
ax.set_zlabel('z data')
fig.colorbar(sp, label='output')
plt.show(block=True)

非常感谢您的帮助!

python matlab scipy interpolation
1个回答
0
投票

它必须适用于 2D 数据集和更高维度的数据集(最多 10 个维度)

这并没有大幅缩小范围。 SciPy 中的许多插值器在任意数量的维度上工作。除了专门用于 1D 或 2D 数据的插值器之外,所有这些都可以处理任意数量的维度

但是考虑到这里的数据集,可以使用的插值器之一是

scipy.interpolate.interpn()

不幸的是,它期望的数据格式与您所拥有的不同。

您的格式如下:

x y out
0 0 1
0 1 2
1 0 3
1 1 4

但是 interpn 期望这样的格式:

  x 0 1
y
0   1 2
1   3 4

要在格式之间进行转换,您可以尝试这样的操作:

import pandas as pd
import numpy as np
import scipy


df = pd.DataFrame({'x': {0: -8.0, 1: 8.0, 2: 10.0, 3: 20.0, 4: 36.0, 5: -8.0, 6: 8.0, 7: 10.0, 8: 20.0, 9: 36.0, 10: -8.0, 11: 8.0, 12: 10.0, 13: 20.0, 14: 36.0, 15: -8.0, 16: 8.0, 17: 10.0, 18: 20.0, 19: 36.0, 20: -8.0, 21: 8.0, 22: 10.0, 23: 20.0, 24: 36.0, 25: -8.0, 26: 8.0, 27: 10.0, 28: 20.0, 29: 36.0, 30: -8.0, 31: 8.0, 32: 10.0, 33: 20.0, 34: 36.0, 35: -8.0, 36: 8.0, 37: 10.0, 38: 20.0, 39: 36.0, 40: -8.0, 41: 8.0, 42: 10.0, 43: 20.0, 44: 36.0, 45: -8.0, 46: 8.0, 47: 10.0, 48: 20.0, 49: 36.0, 50: -8.0, 51: 8.0, 52: 10.0, 53: 20.0, 54: 36.0, 55: -8.0, 56: 8.0, 57: 10.0, 58: 20.0, 59: 36.0}, 'y': {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 0.0, 5: 0.4, 6: 0.4, 7: 0.4, 8: 0.4, 9: 0.4, 10: 0.8, 11: 0.8, 12: 0.8, 13: 0.8, 14: 0.8, 15: 0.0, 16: 0.0, 17: 0.0, 18: 0.0, 19: 0.0, 20: 0.4, 21: 0.4, 22: 0.4, 23: 0.4, 24: 0.4, 25: 0.8, 26: 0.8, 27: 0.8, 28: 0.8, 29: 0.8, 30: 0.0, 31: 0.0, 32: 0.0, 33: 0.0, 34: 0.0, 35: 0.4, 36: 0.4, 37: 0.4, 38: 0.4, 39: 0.4, 40: 0.8, 41: 0.8, 42: 0.8, 43: 0.8, 44: 0.8, 45: 0.0, 46: 0.0, 47: 0.0, 48: 0.0, 49: 0.0, 50: 0.4, 51: 0.4, 52: 0.4, 53: 0.4, 54: 0.4, 55: 0.8, 56: 0.8, 57: 0.8, 58: 0.8, 59: 0.8}, 'z': {0: -25.0, 1: -25.0, 2: -25.0, 3: -25.0, 4: -25.0, 5: -25.0, 6: -25.0, 7: -25.0, 8: -25.0, 9: -25.0, 10: -25.0, 11: -25.0, 12: -25.0, 13: -25.0, 14: -25.0, 15: -10.0, 16: -10.0, 17: -10.0, 18: -10.0, 19: -10.0, 20: -10.0, 21: -10.0, 22: -10.0, 23: -10.0, 24: -10.0, 25: -10.0, 26: -10.0, 27: -10.0, 28: -10.0, 29: -10.0, 30: 0.0, 31: 0.0, 32: 0.0, 33: 0.0, 34: 0.0, 35: 0.0, 36: 0.0, 37: 0.0, 38: 0.0, 39: 0.0, 40: 0.0, 41: 0.0, 42: 0.0, 43: 0.0, 44: 0.0, 45: 15.0, 46: 15.0, 47: 15.0, 48: 15.0, 49: 15.0, 50: 15.0, 51: 15.0, 52: 15.0, 53: 15.0, 54: 15.0, 55: 15.0, 56: 15.0, 57: 15.0, 58: 15.0, 59: 15.0}, 'output': {0: -0.02, 1: -0.02, 2: -0.02, 3: -0.1, 4: -0.11, 5: -0.02, 6: -0.02, 7: -0.02, 8: -0.1, 9: -0.11, 10: -0.02, 11: -0.02, 12: -0.02, 13: -0.1, 14: -0.11, 15: -0.018, 16: -0.018, 17: -0.018, 18: -0.1, 19: -0.1, 20: -0.018, 21: -0.018, 22: -0.018, 23: -0.1, 24: -0.1, 25: -0.018, 26: -0.018, 27: -0.018, 28: -0.1, 29: -0.1, 30: -0.014, 31: -0.014, 32: -0.014, 33: -0.09, 34: -0.09, 35: -0.014, 36: 0.014, 37: -0.014, 38: -0.09, 39: -0.09, 40: -0.014, 41: -0.0184, 42: -0.014, 43: -0.09, 44: -0.09, 45: -0.014, 46: -0.014, 47: -0.014, 48: -0.09, 49: -0.09, 50: -0.014, 51: 0.014, 52: -0.014, 53: -0.09, 54: -0.09, 55: -0.014, 56: -0.0184, 57: -0.014, 58: -0.09, 59: -0.09}})


def dataset_to_dense(df, index_cols, value_col):
    points = []
    # Find all distinct values within index_cols.
    for col in index_cols:
        # Index must be sorted for np.searchsorted() later
        ordered_index = sorted(df[col].unique())
        points.append(np.array(ordered_index))
    values = np.full([len(points_array) for points_array in points], np.nan)
    value_indices = []
    # For each column, convert column values to array indices
    for i, col in enumerate(index_cols):
        values_along_index = points[i]
        column_values = df[col].values
        # Note: this assumes that values_along_index contains all values within this col.
        # Otherwise indices will be off by one
        column_index = np.searchsorted(values_along_index, column_values)
        # Sanity check: does indexing into points[i] with this new index produce the
        # same results as the original column?
        assert (values_along_index[column_index] == column_values).all(), \
            f"could not find all elements of {col} in indexer"
        value_indices.append(column_index)
    # Set values of array at indices found in previous step
    values[tuple(value_indices)] = df[value_col]
    return points, values
    

def interpolate_dataset(df, index_cols, value_col, xi):
    points, values = dataset_to_dense(df, index_cols, value_col)
    return scipy.interpolate.interpn(points, values, xi)


xi = [
    (13, 0.5, -20),
    (35, 0, 0),
]
interpolated = interpolate_dataset(df, ['x', 'y', 'z'], 'output', xi)
print(interpolated)

备注:

  • xi
    表示查询插值的点。
© www.soinside.com 2019 - 2024. All rights reserved.