这是我的预测功能。这有什么问题吗?预测并不稳定,每次我运行相同的数据时,我都会得到不同的预测。
def predict(model, device, inputs, batch_size=1024):
model = model.to(device)
dataset = torch.utils.data.TensorDataset(*inputs)
loader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
pin_memory=False
)
predictions = []
for i, batch in enumerate(loader):
with torch.no_grad():
pred = model(*(item.to(device) for item in batch))
pred = pred.detach().cpu().numpy()
predictions.append(pred)
return np.concatenate(predictions)
智慧大石
慕妹3242003
相关分类