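Deep Learning (5): Using DataLoader

Abstract: Imagine a game of cards. The deck is the dataset, and the hand that takes the cards is the neural network. The dataloader is the act of drawing: it controls how many cards are drawn at a time and how many hands do the drawing.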

I. Introduction to DataLoader

Official documentation:

torch.utils.data — PyTorch 2.0 documentation

1. The DataLoader class

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='')

As the signature shows, dataset is the only argument DataLoader requires; every other parameter has a default value.
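To make the point concrete, here is a minimal sketch that passes nothing but a dataset. The toy tensors and the names toy_dataset and default_loader are illustrative assumptions, not part of the original post; TensorDataset is used only to avoid downloading real data.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 5 samples with 2 features each, plus integer labels.
features = torch.arange(10, dtype=torch.float32).reshape(5, 2)
labels = torch.tensor([0, 1, 0, 1, 0])
toy_dataset = TensorDataset(features, labels)

# Only the dataset is passed; every other parameter keeps its default.
default_loader = DataLoader(toy_dataset)

first_x, first_y = next(iter(default_loader))
print(default_loader.batch_size)  # 1 (the default batch size)
print(first_x.shape)              # torch.Size([1, 2])
print(len(default_loader))        # 5 batches of one sample each
```

With the defaults, each batch holds a single sample and the samples come out in their original order.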

2. Parameters

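  • dataset (Dataset): the dataset object to load samples from. It encapsulates where the data is stored and how each sample is read.

  • batch_size (int): the number of samples fetched per batch. For example, with batch_size=2, two samples are taken each time. Default: 1.

  • shuffle (bool): True reshuffles the data at the start of every epoch (like shuffling the deck before a new round); False keeps the original order. The signature default is None, which behaves like False.

  • num_workers (int): the number of subprocesses used to load data; several workers are often faster. The default of 0 loads the data in the main process. On Windows, a non-zero num_workers may raise BrokenPipeError: [Errno 32] Broken pipe.

  • drop_last (bool): suppose we take 3 samples at a time from 100 samples, so 1 sample is left over at the end. With drop_last=True this last sample is discarded (only the first 99 are used); with drop_last=False it is kept as a final, smaller batch.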
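The 100-samples, 3-per-batch case from the drop_last description can be checked directly. This is a small sketch under the assumption that a TensorDataset of the integers 0..99 is an acceptable stand-in for real data; the variable names are illustrative.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy dataset: the integers 0..99 stand in for 100 samples.
samples = TensorDataset(torch.arange(100))

keep_last = DataLoader(samples, batch_size=3, drop_last=False)
drop_last = DataLoader(samples, batch_size=3, drop_last=True)

print(len(keep_last))  # 34 batches: 33 full ones plus one leftover sample
print(len(drop_last))  # 33 batches: the leftover sample is discarded

# Collect the size of every batch produced by each loader.
kept_sizes = [batch[0].shape[0] for batch in keep_last]
dropped_sizes = [batch[0].shape[0] for batch in drop_last]
print(kept_sizes[-1])      # 1 (the incomplete final batch)
print(dropped_sizes[-1])   # 3 (every batch is full)
print(sum(dropped_sizes))  # 99 samples seen in total
```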

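II. DataLoader in practice

  • The dataset is again CIFAR10, as in the previous post.

1. How DataLoader fetches data

  • First, create the dataset. Indexing it (dataset[i]) returns the img and target of one sample.

  • Then create the dataloader and set batch_size. With batch_size=4 and no shuffling, for example, the dataloader fetches dataset[0]=img0, target0; dataset[1]=img1, target1; dataset[2]=img2, target2; dataset[3]=img3, target3. It packs the 4 imgs together and the 4 targets together, and returns the packed imgs and targets.

For example, consider the following code: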

import torchvision
from torch.utils.data import DataLoader

# Test set; ToTensor converts each PIL image to a tensor
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())

# batch_size=4: take 4 samples from test_data at a time and pack them into one batch
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)

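With shuffle=False, the first batch from test_loader would consist of test_data[0], test_data[1], test_data[2] and test_data[3]; with shuffle=True, as here, four randomly chosen samples are taken instead. Either way, the imgs and the targets are packed separately, and each iteration returns two values: the packed imgs and the packed targets.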
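The packing step can be reproduced by hand. The sketch below assumes random tensors as a stand-in for CIFAR10 (fake_imgs, fake_targets and fake_data are hypothetical names) and uses shuffle=False so the first batch is known to be samples 0 to 3. In CIFAR10 the targets are plain Python ints that the default collate step converts to a tensor; here they are already tensors, so stacking is enough.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for CIFAR10: 8 random "images" of shape 3x32x32 with labels 0..7.
fake_imgs = torch.rand(8, 3, 32, 32)
fake_targets = torch.arange(8)
fake_data = TensorDataset(fake_imgs, fake_targets)

# shuffle=False so the first batch is exactly samples 0..3.
loader = DataLoader(fake_data, batch_size=4, shuffle=False)
imgs, targets = next(iter(loader))

# Pack the first four samples by hand: stack along a new leading batch dimension.
manual_imgs = torch.stack([fake_data[i][0] for i in range(4)])
manual_targets = torch.stack([fake_data[i][1] for i in range(4)])

print(imgs.shape)                            # torch.Size([4, 3, 32, 32])
print(torch.equal(imgs, manual_imgs))        # True
print(torch.equal(targets, manual_targets))  # True
```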

2. Retrieving the packed img and target data from a DataLoader

(1) Printing the packed imgs and targets

Example code:

import torchvision
from torch.utils.data import DataLoader

# Test set; ToTensor converts each PIL image to a tensor
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())

# batch_size=4: take 4 samples from test_data at a time and pack them into one batch
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)

# First image in the test set and its target
img, target = test_data[0]
print(img.shape)
print(target)

# Iterate over the batches in test_loader
for data in test_loader:
    imgs, targets = data
    print(imgs.shape)  # [Run] torch.Size([4, 3, 32, 32]): 4 images per batch, 3 channels, 32×32 pixels
    print(targets)     # [Run] e.g. tensor([3, 5, 2, 7]): the labels of the 4 images (varies between runs because shuffle=True)

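Pausing in a debugger once test_loader has been created (for example with a breakpoint on the line img, target = test_data[0]) shows that test_loader has an attribute called sampler. Because shuffle=True, it is a random sampler: with batch_size=4, the four images in each batch are drawn in random order, and each image still appears exactly once per epoch.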
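The sampler can also be inspected without a debugger. The following sketch assumes a small TensorDataset named toy as a stand-in for the CIFAR10 test set; it reads the sampler attribute of two loaders and lists the dataset indices each sampler yields.

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

# Hypothetical toy dataset with 12 samples.
toy = TensorDataset(torch.arange(12))

shuffled_loader = DataLoader(toy, batch_size=4, shuffle=True)
ordered_loader = DataLoader(toy, batch_size=4, shuffle=False)

print(type(shuffled_loader.sampler).__name__)  # RandomSampler
print(type(ordered_loader.sampler).__name__)   # SequentialSampler

# A sampler yields dataset indices; the loader groups them into batches.
ordered_indices = list(ordered_loader.sampler)
shuffled_indices = list(shuffled_loader.sampler)
print(ordered_indices)           # indices 0..11 in order
print(sorted(shuffled_indices))  # the same indices, originally drawn in random order
```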

(2) Displaying the images

TensorBoard can visualise the batches. Only the final for loop of the code above needs to change:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("dataloder")  # log directory

step = 0  # TensorBoard global step
for data in test_loader:
    imgs, targets = data
    # print(imgs.shape)  # [Run] torch.Size([4, 3, 32, 32]): 4 images per batch, 3 channels, 32×32 pixels
    # print(targets)     # [Run] e.g. tensor([3, 5, 2, 7]): the labels of the 4 images
    writer.add_images("test_data", imgs, step)  # add_images, not add_image: a whole batch of images is written at once
    step = step + 1
writer.close()

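(3) Understanding shuffle

  • Think of one for loop over the dataloader as one round of cards. After a round is finished, if shuffle=False, every draw in the next round yields the same cards as in the previous round. If shuffle=True, the deck is reshuffled first, so the draws in the next round differ from those in the previous one.

First, set shuffle to False (batch_size is also raised to 64 here):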

test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False, num_workers=0, drop_last=False)

Then modify the code from (2) as follows and run it:

for epoch in range(2):  # play two rounds and observe whether the deck is reshuffled in between
    step = 0  # TensorBoard global step, reset for each epoch
    for data in test_loader:
        imgs, targets = data
        writer.add_images("Epoch: {}".format(epoch), imgs, step)  # add_images, not add_image: each call writes a batch of up to 64 images
        step = step + 1
writer.close()

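The results show that without shuffling, the two epochs produce identical batches.

  • After setting shuffle to True and running again, the two epochs produce different batches.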
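The same comparison can be made without TensorBoard by recording the order in which samples arrive. This is a minimal sketch assuming a toy dataset of the integers 0..99; the helper epoch_order is a hypothetical name introduced only for this illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy dataset: the integers 0..99 stand in for images.
toy = TensorDataset(torch.arange(100))

def epoch_order(loader):
    """Return the sample values in the order one full pass over the loader yields them."""
    return [int(value) for (batch,) in loader for value in batch]

ordered = DataLoader(toy, batch_size=64, shuffle=False)
shuffled = DataLoader(toy, batch_size=64, shuffle=True)

# Two passes (epochs) over each loader.
ordered_epochs = [epoch_order(ordered) for _ in range(2)]
shuffled_epochs = [epoch_order(shuffled) for _ in range(2)]

print(ordered_epochs[0] == ordered_epochs[1])           # True: same order every epoch
print(shuffled_epochs[0] == shuffled_epochs[1])         # False, barring an astronomically unlikely coincidence
print(sorted(shuffled_epochs[1]) == ordered_epochs[0])  # True: same samples, different order
```

Each shuffled epoch still visits every sample exactly once; only the order changes between epochs.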
