我一直在尝试使用未标记的数据测试我的级联模型。该过程涉及遍历特定目录中的每个数据实例,模型应为它们分配标签,然后继续下一个。但是,我遇到了数据似乎重复的问题,导致张量形状不正确。此外,对于第二个数据实例,返回一个空数组。这是我目前用于参考的代码。
class Hfdata(Dataset):
    """Dataset that eagerly loads the entire "data" array of one HDF5 file.

    The whole array is read into memory once at construction time, so
    ``__getitem__`` is a plain in-memory index along the first axis.
    """

    def __init__(self, data_dir, file_name):
        # Keep the location around for reference; the array itself is cached.
        self.data_dir = data_dir
        self.file_name = file_name
        # Context manager guarantees the HDF5 handle is closed after the read.
        with h5py.File(f"{data_dir}/{file_name}", "r") as h5file:
            self.data = h5file["data"][:]

    def __len__(self):
        # One sample per entry along the first axis of the cached array.
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        # Debug aid: show the raw per-sample shape the loader will batch.
        print(sample.shape)
        return sample
# Collect the spectral (sp) and time-domain (tm) HDF5 files.
# glob.glob returns paths in arbitrary filesystem order, so both lists are
# sorted — otherwise zip() below can pair an sp file with the WRONG tm file,
# which makes the data look duplicated/mismatched across runs.
folder_path_sp = 'test_data'
file_list_sp = sorted(glob.glob(os.path.join(folder_path_sp, '*.h5')))
folder_path_tm = 'test_data_tm'
file_list_tm = sorted(glob.glob(os.path.join(folder_path_tm, '*.h5')))

# Rebuild the ensemble and restore the trained weights.
# map_location='cpu' lets the checkpoint load even when it was saved on a GPU
# and this machine has none; the state_dict is then loaded into the CPU model.
model = Ensemble(ModelA(), ModelB())
model.load_state_dict(torch.load('Ensamble_model.pt', map_location='cpu'))

# predict on the new data
model.eval()
predictions = []
for file_path_sp, file_path_tm in zip(file_list_sp, file_list_tm):
    file_name_sp = os.path.basename(file_path_sp)
    file_name_tm = os.path.basename(file_path_tm)
    print(f"Prediction on {file_name_sp} and {file_name_tm}")
    # One dataset/loader pair per file; shuffle=False keeps sp/tm batches aligned.
    dataset_sp_p = Hfdata(folder_path_sp, file_name_sp)
    dataloader_sp_p = DataLoader(dataset_sp_p, batch_size=batch_size, shuffle=False)
    dataset_tm_p = Hfdata(folder_path_tm, file_name_tm)
    dataloader_tm_p = DataLoader(dataset_tm_p, batch_size=batch_size, shuffle=False)
    for data_sp, data_tm in zip(dataloader_sp_p, dataloader_tm_p):
        # prepare the data: move the trailing channel axis to the front,
        # (B, D, H, W, C) -> (B, C, D, H, W) — assumed layout, TODO confirm
        print(data_sp.shape)
        data_sp = data_sp.permute(0, 4, 1, 2, 3).float()
        print(data_tm.shape)
        # Flatten each tm sample into a single 609-length channel.
        data_tm = data_tm.view(-1, 1, 609).float()
        # pass the data through the model without tracking gradients
        with torch.no_grad():
            outputs = model(data_sp, data_tm)
        # obtain the predicted labels (argmax over the class dimension);
        # plain `outputs` is used — `.data` is a deprecated autograd bypass.
        _, predicted_labels = torch.max(outputs, 1)
        # append the predicted labels to the list of predictions
        predictions.append(predicted_labels)
        print(predicted_labels)
# print the predictions
print(predictions)
这只是打印以下内容:
(79, 95, 79, 1)
(79, 95, 79, 1)
()
()
torch.Size([2, 79, 95, 79, 1])
torch.Size([2])
所以,你能帮我解决这个问题吗?非常感谢。
我期望第一个数据的 torch 大小为 ([batch_size, channel, depth, height, width]),而第二个数据为 ([1, 609]),但这里不是这种情况。