要使用PyTorch训练出的模型,您需要执行以下步骤:
保存模型:在训练完成后,您可以将模型保存到本地文件系统中。使用torch.save()
函数可以将模型和优化器状态一起保存。例如:
torch.save(model.state_dict(), 'model.pth')
加载模型:在需要使用模型的地方,您可以使用torch.load()
函数加载模型。例如:
model = YourModelClass() # 创建模型实例
model.load_state_dict(torch.load('model.pth')) # 加载模型参数
model.eval() # 将模型设置为评估模式
进行推理:使用加载的模型对新的数据进行推理。例如:
with torch.no_grad(): # 关闭梯度计算
input_data = torch.randn(1, 3, 224, 224) # 创建输入数据张量
output = model(input_data) # 进行推理
print(output) # 输出推理结果
请注意,您需要根据您的具体需求修改代码中的模型类名和输入数据形状。