/ Tech

在 pytorch 中建立自己的图片数据集

目录


通常情况下,待处理的图片数据有两种存放方式:

  • 所有图片在同一目录下,另有一份文本文件记录了标签。
  • 不同标签的图片放在不同目录下,文件夹名就是标签。

对于这两种情况,我们有不同的解决方法。

图片文件在同一目录下

假设在 ./data/ 目录下有所需的所有的图片,以及一份标记了图片标签的文本文件(列为图片路径+标签)./labels.txt

./data/IZvVCYcuOkcu6Ufj.jpg 0
./data/2wuPp4yYoc2wJbZI.jpg 0
./data/vzlBbG4Z1KKJ4P6L.jpg 1
./data/nR8VZBPbjF92wNGC.jpg 2
......

思路是继承 torch.utils.data.Dataset,并重点重写其 __getitem__ 方法,示例代码如下:

class CustomDataset(Dataset):
    def __init__(self, label_file_path):
        with open(label_file_path, 'r') as f:
            # (image_path(str), image_label(str))
            self.imgs = list(map(lambda line: line.strip().split(' '), f))
    
    def __getitem__(self, index):
        path, label = self.imgs[index]
        img = transforms.Compose([transforms.Scale(224),
                                  transforms.CenterCrop(224),
                                  transforms.ToTensor(),])(Image.open(path).convert('RGB'))
        label = int(label)
        return img, label
    
    def __len__(self):
        return len(self.imgs)

dataset = CustomDataset('./labels.txt')
loader = DataLoader(dataset, batch_size=64, shuffle=True)

至此,可以用 enumerate(loader) 的方式迭代数据了。需要注意的是,在 __getitem__ 时要确保 batch 内图片尺寸相同(上面的例子用了 Scale+CenterCrop 的方法),否则会出现 RuntimeError: inconsistent tensor sizes at ... 的错误。

图片文件在不同目录下

当图片文件依据 label 处于不同文件下时,如:

─── data
    ├── 虾饺
    │   ├── 00856315f0df13536183d8ae6cbaf8d6a54f37ce.jpg
    │   └── 00ce9dccdf9a218d3b891e006c81f8e66524b1b3.jpg
    ├── 八宝粥
    │   ├── 055133235f649411e599ce5dba83627d58996209.jpg
    │   └── 0a72473884cb6c03191ca929a9aa0b2bbe4abb3d.jpg
    └── 钵仔糕
        ├── 1237b1e7b7e7da0ac78f9e1c8317b9462fe92803.jpg
        └── 14a7d6c1a881d1dcfe855bf783064ad2c9d5aba4.jpg

此时我们可以利用 torchvision.datasets.ImageFolder 来直接构造出 dataset,代码如下:

dataset = ImageFolder(path)
loader = DataLoader(dataset)

ImageFolder 会将目录中的文件夹名自动转化成序列,那么 loader 载入时,标签就是整数序列了。