【代码 bug 记录】PyTorch 的 Dataloader 如何加载 List 对象?
0、写在前面在记录该问题解决方案的时候,也有在 csdn 上搜到某位小伙伴遇到同样的问题,但没有说明原因。那我就记录一下吧。1、问题之前看到一份代码,在 __init__() 函数中,加载的每一条数据都是一个列表 List【长度为 len_list】,列表中的每一项是一段经过处理的视频,维度为 [C, T, H, W]。所以 dataset 中每一条数据的维度应该是 [len_list, C, T
0、写在前面
在记录该问题解决方案的时候,也有在 csdn 上搜到某位小伙伴遇到同样的问题,但没有说明原因。那我就记录一下吧。
1、问题
之前看到一份代码,在 __init__() 函数中,加载的每一条数据都是一个列表 List【长度为 len_list】,列表中的每一项是一段经过处理的视频,维度为 [C, T, H, W]。
所以 dataset 中每一条数据的维度应该是 [len_list, C, T, H, W]。
按照以往加载数据的经验,我自然而然地认为用 dataloader 返回的数据维度应该是 [B, len_list, C, T, H, W]。然而,事情不是这样的!实际上用 dataloader 返回的数据维度是 [len_list, B, C, T, H, W]。
我: ???
2、原因
幸亏同实验室的大神了解过这方面的源码,告诉了我原因:
如果 dataset 返回的 sample 是序列(Sequence)类【如:字符串(普通字符串和unicode字符串),列表和元组】的,那 dataloader 默认把 B(batch size)那个维度加在序列里每个 item 的 shape 前面。
相关部分源码:
elif isinstance(elem, collections.abc.Sequence):
# check to make sure that the elements in batch have consistent size
it = iter(batch)
elem_size = len(next(it))
if not all(len(elem) == elem_size for elem in it):
raise RuntimeError('each element in list of batch should be of equal size')
transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]
3、解决方案
1、在 __init__() 函数里,先把每一条数据转成 Tensor,而不是直接返回 List。这样用 dataloader 加载数据的维度就是我熟悉的:[B, len_list, C, T, H, W]
2、若直接返回 List,则注意在 __getitem__() 函数里处理数据时,维度是 [len_list, B, C, T, H, W]
更多推荐
所有评论(0)