在 pytorch 中建立自己的图片数据集
目录
通常情况下,待处理的图片数据有两种存放方式:
- 所有图片在同一目录下,另有一份文本文件记录了标签。
- 不同标签的图片放在不同目录下,文件夹名就是标签。
对于这两种情况,我们有不同的解决方法。
图片文件在同一目录下
假设在 ./data/
目录下有所需的所有的图片,以及一份标记了图片标签的文本文件(列为图片路径+标签)./labels.txt
./data/IZvVCYcuOkcu6Ufj.jpg 0
./data/2wuPp4yYoc2wJbZI.jpg 0
./data/vzlBbG4Z1KKJ4P6L.jpg 1
./data/nR8VZBPbjF92wNGC.jpg 2
......
思路是继承 torch.utils.data.Dataset
,并重点重写其 __getitem__
方法,示例代码如下:
class CustomDataset(Dataset):
def __init__(self, label_file_path):
with open(label_file_path, 'r') as f:
# (image_path(str), image_label(str))
self.imgs = list(map(lambda line: line.strip().split(' '), f))
def __getitem__(self, index):
path, label = self.imgs[index]
img = transforms.Compose([transforms.Scale(224),
transforms.CenterCrop(224),
transforms.ToTensor(),])(Image.open(path).convert('RGB'))
label = int(label)
return img, label
def __len__(self):
return len(self.imgs)
dataset = CustomDataset('./labels.txt')
loader = DataLoader(dataset, batch_size=64, shuffle=True)
至此,可以用 enumerate(loader)
的方式迭代数据了。需要注意的是,在 __getitem__
时要确保 batch 内图片尺寸相同(上面的例子用了 Scale+CenterCrop 的方法),否则会出现 RuntimeError: inconsistent tensor sizes at ...
的错误。
图片文件在不同目录下
当图片文件依据 label 处于不同文件下时,如:
─── data
├── 虾饺
│ ├── 00856315f0df13536183d8ae6cbaf8d6a54f37ce.jpg
│ └── 00ce9dccdf9a218d3b891e006c81f8e66524b1b3.jpg
├── 八宝粥
│ ├── 055133235f649411e599ce5dba83627d58996209.jpg
│ └── 0a72473884cb6c03191ca929a9aa0b2bbe4abb3d.jpg
└── 钵仔糕
├── 1237b1e7b7e7da0ac78f9e1c8317b9462fe92803.jpg
└── 14a7d6c1a881d1dcfe855bf783064ad2c9d5aba4.jpg
此时我们可以利用 torchvision.datasets.ImageFolder
来直接构造出 dataset,代码如下:
dataset = ImageFolder(path)
loader = DataLoader(dataset)
ImageFolder 会将目录中的文件夹名自动转化成序列,那么 loader
载入时,标签就是整数序列了。