温馨提示×

温馨提示×

您好,登录后才能下订单哦!

密码登录×
登录注册×
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》

python中如何使用CIFAR10数据集

发布时间:2023-02-03 09:27:01 来源:亿速云 阅读:147 作者:iii 栏目:开发技术

这篇“python中如何使用CIFAR10数据集”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“python中如何使用CIFAR10数据集”文章吧。

    关于CIFAR10数据集的使用

    主要解决了如何把数据集与transforms结合在一起的问题。

    CIFAR10的官方解释

    torchvision.datasets.CIFAR10(
    root: str, 
    train: bool = True, 
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    download: bool = False)

    注释:

    • root (string)存在 cifar-10-batches-py 目录的数据集的根目录,如果下载设置为 True,则将保存到该目录。

    • train (bool, optional)如果为True,则从训练集创建数据集, 如果为False,从测试集创建数据集。

    • transform (callable, optional)它接受一个 PIL 图像并返回一个转换后的版本。 例如,transforms.RandomCrop/transforms.ToTensor

    • target_transform (callable, optional) 接收目标并对其进行转换的函数/转换。

    • download (bool, optional)如果为 true,则从 Internet 下载数据集并将其放在根目录中。 如果数据集已经下载,则不会再次下载。

    实战操作

    1.CIAFR10数据集的下载

    代码如下:

    import torchvision   #导入torchvision这个类
    
    train_set = torchvision.datasets.CIFAR10(root = "./dataset", train = True, 
    download= True)  #从训练集创建数据集
    test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False,
     download=True)    #从测试集创建数据集

    root = "./dataset",将下载的数据集保存在这个文件夹下;download= True,从 Internet 下载数据集并将其放在根目录中,这里就是在相对路径中,创建dataset文件夹,将数据集保存在dataset中。

    2.查看下载的CIAFR10数据集

    运行程序,开始下载数据集。下载成功后,可以进行一些查看。代码如下:

    接着输入:

    print(train_set[0])  #查看train_set训练集中的第一个数据
    print(train_set.classes)   #查看train_set训练集中有多少个类别
     
    img, target = train_set[0]
    print(img)
    print(target)
    print(train_set.classes[target])
    img.show()  #显示图片

    输出结果:

    (<PIL.Image.Image image mode=RGB size=32x32 at 0x161E924B8D0>, 6)
    ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship',
    'truck']
    <PIL.Image.Image image mode=RGB size=32x32 at 0x161E924B710>
    6
    frog

    注释:可以看见,train_set数据集中有10个类别,train_set中第0个元素的target是6,也就是说,这个元素是属于第7个类别frog的。

    3.数据转换

    因为这些图片类型都是PIL Image,如果要供给pytorch使用的话,需要将数据全都转化成tensor类型。

    完整代码如下:

    import torchvision   #导入torchvision这个类
    from torch.utils.tensorboard import SummaryWriter
    
    from torchvision import transforms
    dataset_transforms = transforms.ToTensor()
    
    # dataset_transforms = torchvision.transforms.Compose([
    #     torchvision.transforms.ToTensor()
    # ])    第3  4 行代码可以用compose直接写
    train_set = torchvision.datasets.CIFAR10(root = "./dataset", train = True, transform=dataset_transforms, download= True) #训练集
    test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transforms, download=True)   #测试集
    
    writer = SummaryWriter("logs")
    
    # print(train_set[0])  #查看train_set训练集中的第一个数据
    # print(train_set.classes)   #查看train_set训练集中有多少个类别
    
    # img, target = train_set[0]
    # print(img)
    # print(target)
    # print(train_set.classes[target])
    # img.show()
    for i in range(20):
        img, target = train_set[i]
        writer.add_image("cifar10_test2", img, i)
    
    writer.close()

    小结:CIFAR10数据集内存很小,只有100多m,下载方便。对我们学习数据集非常友好,练习的时候,我们可以使用SummaryWriter来将数据写入tensorboard中。

    CIFAR-10 数据集简介

    复现代码的过程中,简单了解了作者使用的数据集CIFAR-10 dataset ,简单记录一下。

    CIFAR-10数据集是8000万微小图片的标签子集,它的收集者是:Alex Krizhevsky, Vinod Nair, Geoffrey Hinton。

    数据集由6万张32*32的彩色图片组成,一共有10个类别。每个类别6000张图片。其中有5万张训练图片及1万张测试图片。

    数据集被划分为5个训练块和1个测试块,每个块1万张图片。

    测试块包含了1000张从每个类别中随机选择的图片。训练块包含随机的剩余图像,但某些训练块可能对于一个类别的包含多于其他类别,训练块包含来自各个类别的5000张图片。

    这些类是完全互斥的,及在一个类别中出现的图片不会出现在其它类中。

    数据集版本

    作者提供了3个版本的数据集:python version; Matlab version; binary version。

    可根据自己的需求选择。

    数据集布置

    以python version进行介绍,Matlab version与之相同。

    下载后获得文件 data_batch_1, data_batch_2,&hellip;, data_batch_5。测试块相同。这些文件中的每一个都是用cPickle生成的python pickled对象。

    具体使用方法:

    def unpickle(file):
        import pickle
        with open(file, 'rb') as fo:
            dict = pickle.load(fo, encoding='bytes')
        return dict

    返回字典类,每个块的文件包含一个字典类,包含以下元素:

    • data: 一个100003072的numpy数组(unit8)每个行存储3232的彩色图片,3072=1024*3,分别是red, green, blue。存储方式以行为主。

    • labels:使用0-9进行索引。

    数据集包含的另一个文件batches.meta同样包含python字典,用于加载label_names。如:label_names[0] == “airplane”, label_names[1] == “automobile”

    以上就是关于“python中如何使用CIFAR10数据集”这篇文章的内容,相信大家都有了一定的了解,希望小编分享的内容对大家有帮助,若想了解更多相关的知识内容,请关注亿速云行业资讯频道。

    向AI问一下细节

    免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

    AI