快速入门深度学习 PyTorch(保姆级教程)第三期

第一期:?? juejin.cn/post/724000…

第二期:?? juejin.cn/post/724406…

四、torchvision 中的数据集使用

Ⅰ. torchvision 中有哪里数据集

torchvision有coco、Caltech 101、CIFAR10、FER2013、iNaturalist等等,
更多数据集可参考文档

Ⅱ. 数据集如何使用

import torchvision


# 在torchvision.datasets加载CIFAR10数据集,并指定一些参数,当然我们也可以使用其他数据集
train_data = torchvision.datasets.CIFAR10(root="./data", train=True, transform=None, download=True)
val_data = torchvision.datasets.CIFAR10(root="./data", train=False, transform=None, download=True)
# root 数据集存储路径
# train是否是训练数据集
# transform可以对数据进行转换
# download是否自动下载

print(train_data[0])

如果运行时出现该异常urllib.error.URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl.c:xxx)> 可在下载数据集前面关闭SSL认证

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

下载后是这样的

image.png

五、DataLoader 的使用

import torchvision

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import ssl

ssl._create_default_https_context = ssl._create_unverified_context

# 加载torchvision提供的数据集
train_data = torchvision.datasets.CIFAR10(root="./data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
val_data = torchvision.datasets.CIFAR10(root="./data", train=False, transform=torchvision.transforms.ToTensor(),
                                        download=True)
# 创建DataLoader对象,可以指定下列主要参数
# dataset: 数据集
# batch_size: 一批多少条数据
# shuffle: 这次数据读完后,下次是否洗牌
# sampler: 抽取器,默认是随机抽取
# num_workers: 多少子进程装载数据,默认0表示数据将在主进程中加载。
# drop_last: 最后一批,不够batch_size的数量,是否舍掉
val_dataLoader = DataLoader(dataset=val_data, batch_size=64, shuffle=True, sampler=None, num_workers=0, drop_last=False)

# 使用上期讲的 tensorboard 看一下我们没批的数据
writer = SummaryWriter("logs")
step = 0
# 遍历我们创建的dataLoader的实例
for images, targets in val_dataLoader:
    # images, targets 是我们那边一批的数据,images是[图片,图片,图片], targets是[标签对应的target,标签对应的target,标签对应的target]
    # 将一批图片通过 add_images 输出
    writer.add_images("CIFAR10-val", images, step)
    step = step + 1

writer.close()

然后打开tensorboard web页面可以看到每一批的图片数据

image.png

六、下下期预告

1. 神经网络基本骨架Module的使用

2. 神经网络的卷积层

会尽快更新~~~ 快来关注一下~~~

© 版权声明
THE END
喜欢就支持一下吧
点赞0

Warning: mysqli_query(): (HY000/3): Error writing file '/tmp/MYUcMMY8' (Errcode: 28 - No space left on device) in /www/wwwroot/583.cn/wp-includes/class-wpdb.php on line 2345
admin的头像-五八三
评论 抢沙发
头像
欢迎您留下宝贵的见解!
提交
头像

昵称

图形验证码
取消
昵称代码图片