这篇文章主要讲解了“Pytorch怎么实现简单的垃圾分类”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“Pytorch怎么实现简单的垃圾分类”吧!
垃圾数据都放在了名字为「垃圾图片库」的文件夹里。
首先,我们需要写个脚本根据文件夹名,生成对应的标签文件(dir_label.txt)。
前面是小分类标签,后面是大分类标签。
然后再将数据集分为训练集(train.txt)、验证集(val.txt)、测试集(test.txt)。
训练集和验证集用于训练模型,测试集用于验收最终模型效果。
此外,在使用图片训练之前还需要检查下图片质量,使用 PIL 的 Image 读取,捕获 Error 和 Warning 异常,对有问题的图片直接删除即可。
写个脚本生成三个 txt 文件,训练集 48045 张,验证集 5652 张,测试集 2826 张。
脚本很简单,代码就不贴了,直接提供处理好的文件。
处理好的四个 txt 文件可以直接下载。
编写 dataset.py 读取数据,看一下效果。
import torch from PIL import Image import os import glob from torch.utils.data import Dataset import random import torchvision.transforms as transforms from PIL import ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True class Garbage_Loader(Dataset): def __init__(self, txt_path, train_flag=True): self.imgs_info = self.get_images(txt_path) self.train_flag = train_flag self.train_tf = transforms.Compose([ transforms.Resize(224), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), transforms.ToTensor(), ]) self.val_tf = transforms.Compose([ transforms.Resize(224), transforms.ToTensor(), ]) def get_images(self, txt_path): with open(txt_path, 'r', encoding='utf-8') as f: imgs_info = f.readlines() imgs_info = list(map(lambda x:x.strip().split('\t'), imgs_info)) return imgs_info def padding_black(self, img): w, h = img.size scale = 224. / max(w, h) img_fg = img.resize([int(x) for x in [w * scale, h * scale]]) size_fg = img_fg.size size_bg = 224 img_bg = Image.new("RGB", (size_bg, size_bg)) img_bg.paste(img_fg, ((size_bg - size_fg[0]) // 2, (size_bg - size_fg[1]) // 2)) img = img_bg return img def __getitem__(self, index): img_path, label = self.imgs_info[index] img = Image.open(img_path) img = img.convert('RGB') img = self.padding_black(img) if self.train_flag: img = self.train_tf(img) else: img = self.val_tf(img) label = int(label) return img, label def __len__(self): return len(self.imgs_info) if __name__ == "__main__": train_dataset = Garbage_Loader("train.txt", True) print("数据个数:", len(train_dataset)) train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=1, shuffle=True) for image, label in train_loader: print(image.shape) print(label)
读取 train.txt 文件,加载数据。数据预处理,是将图片等比例填充到尺寸为 280 * 280 的纯黑色图片上,然后再 resize 到 224 * 224 的尺寸。
这是图片分类里,很常规的一种预处理方法。
此外,针对训练集,使用 pytorch 的 transforms 添加了水平翻转和垂直翻转的随机操作,这也是很常见的一种数据增强方法。
运行结果:
OK,搞定!开始写训练代码!
我们使用一个常规的网络 ResNet50 ,这是一个非常常见的提取特征的网络结构。
创建 train.py 文件,编写如下代码:
from dataset import Garbage_Loader from torch.utils.data import DataLoader from torchvision import models import torch.nn as nn import torch.optim as optim import torch import time import os import shutil os.environ["CUDA_VISIBLE_DEVICES"] = "0" """ Author : Jack Cui Wechat : https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA """ from tensorboardX import SummaryWriter def accuracy(output, target, topk=(1,)): """ 计算topk的准确率 """ with torch.no_grad(): maxk = max(topk) batch_size = target.size(0) _, pred = output.topk(maxk, 1, True, True) pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) class_to = pred[0].cpu().numpy() res = [] for k in topk: correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) res.append(correct_k.mul_(100.0 / batch_size)) return res, class_to def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): """ 根据 is_best 存模型,一般保存 valid acc 最好的模型 """ torch.save(state, filename) if is_best: shutil.copyfile(filename, 'model_best_' + filename) def train(train_loader, model, criterion, optimizer, epoch, writer): """ 训练代码 参数: train_loader - 训练集的 DataLoader model - 模型 criterion - 损失函数 optimizer - 优化器 epoch - 进行第几个 epoch writer - 用于写 tensorboardX """ batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() # switch to train mode model.train() end = time.time() for i, (input, target) in enumerate(train_loader): # measure data loading time data_time.update(time.time() - end) input = input.cuda() target = target.cuda() # compute output output = model(input) loss = criterion(output, target) # measure accuracy and record loss [prec1, prec5], class_to = accuracy(output, target, topk=(1, 5)) losses.update(loss.item(), input.size(0)) top1.update(prec1[0], input.size(0)) top5.update(prec5[0], input.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % 10 == 0: print('Epoch: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( epoch, i, len(train_loader), batch_time=batch_time, data_time=data_time, loss=losses, top1=top1, top5=top5)) writer.add_scalar('loss/train_loss', losses.val, global_step=epoch) def validate(val_loader, model, criterion, epoch, writer, phase="VAL"): """ 验证代码 参数: val_loader - 验证集的 DataLoader model - 模型 criterion - 损失函数 epoch - 进行第几个 epoch writer - 用于写 tensorboardX """ batch_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() # switch to evaluate mode model.eval() with torch.no_grad(): end = time.time() for i, (input, target) in enumerate(val_loader): input = input.cuda() target = target.cuda() # compute output output = model(input) loss = criterion(output, target) # measure accuracy and record loss [prec1, prec5], class_to = accuracy(output, target, topk=(1, 5)) losses.update(loss.item(), input.size(0)) top1.update(prec1[0], input.size(0)) top5.update(prec5[0], input.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % 10 == 0: print('Test-{0}: [{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( phase, i, len(val_loader), batch_time=batch_time, loss=losses, top1=top1, top5=top5)) print(' * {} Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' .format(phase, top1=top1, top5=top5)) writer.add_scalar('loss/valid_loss', losses.val, global_step=epoch) return top1.avg, top5.avg class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count if __name__ == "__main__": # -------------------------------------------- step 1/4 : 加载数据 --------------------------- train_dir_list = 'train.txt' valid_dir_list = 'val.txt' batch_size = 64 epochs = 80 num_classes = 214 train_data = Garbage_Loader(train_dir_list, train_flag=True) valid_data = Garbage_Loader(valid_dir_list, train_flag=False) train_loader = DataLoader(dataset=train_data, num_workers=8, pin_memory=True, batch_size=batch_size, shuffle=True) valid_loader = DataLoader(dataset=valid_data, num_workers=8, pin_memory=True, batch_size=batch_size) train_data_size = len(train_data) print('训练集数量:%d' % train_data_size) valid_data_size = len(valid_data) print('验证集数量:%d' % valid_data_size) # ------------------------------------ step 2/4 : 定义网络 ------------------------------------ model = models.resnet50(pretrained=True) fc_inputs = model.fc.in_features model.fc = nn.Linear(fc_inputs, num_classes) model = model.cuda() # ------------------------------------ step 3/4 : 定义损失函数和优化器等 ------------------------- lr_init = 0.0001 lr_stepsize = 20 weight_decay = 0.001 criterion = nn.CrossEntropyLoss().cuda() optimizer = optim.Adam(model.parameters(), lr=lr_init, weight_decay=weight_decay) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_stepsize, gamma=0.1) writer = SummaryWriter('runs/resnet50') # ------------------------------------ step 4/4 : 训练 ----------------------------------------- best_prec1 = 0 for epoch in range(epochs): scheduler.step() train(train_loader, model, criterion, optimizer, epoch, writer) # 在验证集上测试效果 valid_prec1, valid_prec5 = validate(valid_loader, model, criterion, epoch, writer, phase="VAL") is_best = valid_prec1 > best_prec1 best_prec1 = max(valid_prec1, best_prec1) save_checkpoint({ 'epoch': epoch + 1, 'arch': 'resnet50', 'state_dict': model.state_dict(), 'best_prec1': best_prec1, 'optimizer' : optimizer.state_dict(), }, is_best, filename='checkpoint_resnet50.pth.tar') writer.close()
代码并不复杂,网络结构直接使 torchvision 的 ResNet50 模型,并且采用 ResNet50 的预训练模型。算法采用交叉熵损失函数,优化器选择 Adam,并采用 StepLR 进行学习率衰减。
保存模型的策略是选择在验证集准确率最高的模型。
batch size 设为 64,GPU 显存大约占 8G,显存不够的,可以调整 batch size 大小。
模型训练完成,就可以写测试代码了,看下效果吧!
创建 infer.py 文件,编写如下代码:
from dataset import Garbage_Loader from torch.utils.data import DataLoader import torchvision.transforms as transforms from torchvision import models import torch.nn as nn import torch import os import numpy as np import matplotlib.pyplot as plt #%matplotlib inline os.environ["CUDA_VISIBLE_DEVICES"] = "0" def softmax(x): exp_x = np.exp(x) softmax_x = exp_x / np.sum(exp_x, 0) return softmax_x with open('dir_label.txt', 'r', encoding='utf-8') as f: labels = f.readlines() labels = list(map(lambda x:x.strip().split('\t'), labels)) if __name__ == "__main__": test_list = 'test.txt' test_data = Garbage_Loader(test_list, train_flag=False) test_loader = DataLoader(dataset=test_data, num_workers=1, pin_memory=True, batch_size=1) model = models.resnet50(pretrained=False) fc_inputs = model.fc.in_features model.fc = nn.Linear(fc_inputs, 214) model = model.cuda() # 加载训练好的模型 checkpoint = torch.load('model_best_checkpoint_resnet50.pth.tar') model.load_state_dict(checkpoint['state_dict']) model.eval() for i, (image, label) in enumerate(test_loader): src = image.numpy() src = src.reshape(3, 224, 224) src = np.transpose(src, (1, 2, 0)) image = image.cuda() label = label.cuda() pred = model(image) pred = pred.data.cpu().numpy()[0] score = softmax(pred) pred_id = np.argmax(score) plt.imshow(src) print('预测结果:', labels[pred_id][0]) plt.show()
这里需要注意的是,DataLoader 读取的数据需要进行通道转换,才能显示。
预测结果:
感谢各位的阅读,以上就是“Pytorch怎么实现简单的垃圾分类”的内容了,经过本文的学习后,相信大家对Pytorch怎么实现简单的垃圾分类这一问题有了更深刻的体会,具体使用情况还需要大家实践验证。这里是亿速云,小编将为大家推送更多相关知识点的文章,欢迎关注!
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。