I am trying to use TabCBM on my own tabular dataset (provided as a CSV file). The provided model requires some mandatory sub-models, for example:
feature_to_concept_model
concept_to_feature_model
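So what do these parts mean? The paper does not describe them in detail. For example, suppose we have tabular training data (features) with 1000 rows and 7 columns, plus a 1000-element vector with binary entries. How do we use TabCBM to train a model and apply it to test data?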
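To make the setup concrete, here is a minimal sketch of how data of that shape could be turned into arrays before any model is involved. It uses only NumPy. The column names (`f0` to `f6`, `label`), the in-memory CSV buffer, and the 80/20 split are assumptions for illustration, not something taken from the TabCBM repository.

```python
import io
import numpy as np

# Hypothetical stand-in for a real CSV file: 7 feature columns plus a binary "label" column.
rng = np.random.default_rng(0)
n_rows, n_features = 1000, 7
features = rng.normal(size=(n_rows, n_features))
labels = rng.integers(0, 2, size=n_rows)
header = ",".join([f"f{i}" for i in range(n_features)] + ["label"])
csv_buffer = io.StringIO()
np.savetxt(csv_buffer, np.column_stack([features, labels]),
           delimiter=",", header=header, comments="")
csv_buffer.seek(0)

# Load the table; with a real file, pass its path instead of the buffer.
table = np.loadtxt(csv_buffer, delimiter=",", skiprows=1)
X = table[:, :-1].astype(np.float32)  # features as float32
y = table[:, -1].astype(np.int32)     # binary labels as integers

# Shuffle the rows and hold out 20% of them as a test set.
perm = rng.permutation(n_rows)
n_test = n_rows // 5
test_idx, train_idx = perm[:n_test], perm[n_test:]
X_train, y_train = X[train_idx], y[train_idx]
X_test, y_test = X[test_idx], y[test_idx]

# Per-feature mean of the training data, computed the same way as the
# `mean_inputs` argument in the snippet in this post.
mean_inputs = X_train.mean(axis=0)

print(X_train.shape, X_test.shape, mean_inputs.shape)
```

The held-out `X_test` and `y_test` arrays are only touched after training, which is what "training a model for the test data" would normally amount to.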
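I tried the following code as a simple example, based on what is provided in the repository code and the explanation in the paper: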
from models.tabcbm import TabCBM
import numpy as np
import tensorflow as tf
# Generate random data for X_train and binary labels for y_train
X_train = np.random.randint(7, 100, size=(100, 7))
y_train = np.random.randint(2, size=100) # Binary labels
latent_dims = 4 # Number of latent concepts
# Define your feature-to-concepts model
features_to_concepts_model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(7,)),  # Input shape should match your data
    tf.keras.layers.Dense(latent_dims),
    tf.keras.layers.Softmax()
])
# Define your concepts-to-labels model
concepts_to_labels_model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(latent_dims,)),
    tf.keras.layers.Dense(1),
    tf.keras.layers.Softmax()
])
# Define your features-to-embeddings model with the correct input shape
features_to_embeddings_model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(7,)),  # Input shape should match your data
    tf.keras.layers.Dense(latent_dims),
    tf.keras.layers.Softmax()
])
# TabCBM parameters
tab_cbm_params = dict(
    features_to_concepts_model=features_to_concepts_model,
    features_to_embeddings_model=features_to_embeddings_model,
    concepts_to_labels_model=concepts_to_labels_model,
    mean_inputs=np.mean(X_train, axis=0),
    loss_fn=tf.keras.losses.BinaryCrossentropy(),  # Binary classification loss
    latent_dims=latent_dims,
    n_concepts=4,
    n_supervised_concepts=0,
    coherence_reg_weight=0.1,
    diversity_reg_weight=0.1,
    feature_selection_reg_weight=0.1,
    prob_diversity_reg_weight=0.1,
    concept_prediction_weight=0.1,
)
# Create and compile the TabCBM model
ss_tabcbm = TabCBM(
    self_supervised_mode=True,
    **tab_cbm_params,
)
# Compile the model
ss_tabcbm.compile(optimizer='adam', loss='binary_crossentropy')
# Print the model summary
ss_tabcbm.summary()
# Train the model (Ensure X_train and y_train have the correct shapes)
# ss_tabcbm.fit(X_train, y_train, validation_split=0.2, epochs=10, batch_size=256)
Why am I getting an error with this code?
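Thank you for raising this question. I have replied to it in the GitHub issue opened in the TabCBM repository. I have also added an "example" there that will hopefully clarify how to adapt the model to a custom dataset. If you still run into problems with this repository after these clarifications, please let me know.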
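As a side note that does not depend on TabCBM itself, two details of the snippet in the question may be worth checking. The post does not say which error was raised, so this is not a diagnosis. First, a softmax applied to a single output unit always returns 1, so `Dense(1)` followed by `Softmax()` cannot distinguish the two classes; a sigmoid is the usual choice for a single binary output. Second, `np.random.randint` produces integer arrays, while Keras layers generally work with floating-point inputs. The sketch below illustrates both points in plain NumPy; the helper functions `softmax` and `sigmoid` are written here for illustration only.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def sigmoid(z):
    # Element-wise logistic function.
    return 1.0 / (1.0 + np.exp(-z))

# Logits of shape (batch, 1), i.e. what a Dense(1) layer would output.
logits = np.array([[-3.0], [0.0], [2.5]])

softmax_out = softmax(logits)
sigmoid_out = sigmoid(logits)

print(softmax_out.ravel())  # [1. 1. 1.]: the normalisation cancels the only logit
print(sigmoid_out.ravel())  # values strictly between 0 and 1, increasing with the logit

# Integer features, generated as in the question, and a float32 copy
# that is safer to feed to a neural network.
X_train = np.random.randint(7, 100, size=(100, 7))
X_train_float = X_train.astype(np.float32)
print(X_train.dtype, X_train_float.dtype)
```

In Keras terms, the first point would correspond to using `tf.keras.layers.Dense(1, activation="sigmoid")` as the last layer of the label predictor, which matches the probabilities that `tf.keras.losses.BinaryCrossentropy()` expects by default.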