0、写在前面

 在记录该问题解决方案的时候,也有在 csdn 上搜到某位小伙伴遇到同样的问题,但没有说明原因。那我就记录一下吧。

1、问题

之前看到一份代码,在 __init__() 函数中,加载的每一条数据都是一个列表 List【长度为 len_list】,列表中的每一项是一段经过处理的视频,维度为 [C, T, H, W]。

所以 dataset 中每一条数据的维度应该是 [len_list, C, T, H, W]。

按照以往加载数据的经验,我自然而然地认为用 dataloader 返回的数据维度应该是 [B, len_list, C, T, H, W]。然而,事情不是这样的!实际上用 dataloader 返回的数据维度是 [len_list, B, C, T, H, W]。

我: ???

2、原因

幸亏同实验室的大神了解过这方面的源码,告诉了我原因

如果 dataset 返回的 sample 是序列(Sequence)类【如:字符串(普通字符串和unicode字符串),列表和元组】的,那 dataloader 默认把 B(batch size)那个维度加在序列里每个 item 的 shape 前面。

相关部分源码

elif isinstance(elem, collections.abc.Sequence):
    # check to make sure that the elements in batch have consistent size
    it = iter(batch)
    elem_size = len(next(it))
    if not all(len(elem) == elem_size for elem in it):
        raise RuntimeError('each element in list of batch should be of equal size')
    transposed = zip(*batch)
    return [default_collate(samples) for samples in transposed]

源码地址https://github.com/pytorch/pytorch/blob/ca666982028f32ddf3606c1d6e45a3a83f274d5d/torch/utils/data/_utils/collate.py#L77

3、解决方案

1、在 __init__() 函数里,先把每一条数据转成 Tensor,而不是直接返回 List。这样用 dataloader 加载数据的维度就是我熟悉的:[B, len_list, C, T, H, W]

2、若直接返回 List,则注意在 __getitem__() 函数里处理数据时,维度是 [len_list, B, C, T, H, W]

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐