温馨提示×

Torch中怎么加载和使用预训练模型

小亿
147
2024-03-18 13:35:50
栏目: 深度学习

在Torch中加载和使用预训练模型通常通过使用torchvision.models模块来实现。以下是一个简单的示例,演示如何加载预训练的ResNet模型并使用它对图像进行预测:

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# 加载预训练的ResNet模型
model = models.resnet18(pretrained=True)
model.eval()

# 定义图片预处理步骤
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 加载并预处理图像
img = Image.open('image.jpg')
img = transform(img).unsqueeze(0)

# 使用模型进行预测
output = model(img)

# 获取预测结果
_, predicted = torch.max(output, 1)

print('Predicted class:', predicted.item())

在上面的示例中,我们首先加载了预训练的ResNet模型,并将其设置为评估模式。然后定义了图像预处理步骤,并加载并预处理了一个示例图像。最后,我们使用模型对图像进行预测,并输出预测结果。

请注意,这只是一个简单的示例,实际应用中可能会有更复杂的预处理步骤和模型的使用方式,具体取决于你的应用场景和需求。

0