PyTorch 中自定义数据集的读取方法小结

PyTorch 中自定义数据集的读取方法小结

总结常用的几种自定义数据集(Custom Dataset)的读取方式(采用 Dataloader)。

本文将涉及以下几个方面:

  • 自定义数据集基础方法
  • 使用 Torchvision Transforms
  • 换一种方法使用 Torchvision Transforms
  • 结合 Pandas 读取 csv 文件
  • 结合 Pandas 使用 __getitem__()
  • 使用 Dataloader 读取自定义数据集

一. 自定义数据集基础方法

首先要创建一个 Dataset 类:

from torch.utils.data.dataset import Dataset

class MyCustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        
    def __getitem__(self, index):
        # stuff
        return (img, label)
 
    def __len__(self):
        return count

这个代码中:

  • __init__() 一些初始化过程写在这里
  • __len__() 返回所有数据的数量
  • __getitem__() 返回数据和标签,可以这样显示调用:
img, label = MyCustomDataset.__getitem__(99)

二. 使用 Torchvision Transforms

Transform 最常见的使用方法是:

from torch.utils.data.dataset import Dataset
from torchvision import transforms
 
class MyCustomDataset(Dataset):
    def __init__(self, ..., transforms=None):
        # stuff
        ...
        self.transforms = transforms
        
    def __getitem__(self, index):
        # stuff
        ...
        data = # 一些读取的数据
        if self.transforms is not None:
            data = self.transforms(data)
        # 如果 transform 不为 None,则进行 transform 操作
        return (img, label)
 
    def __len__(self):
        return count 
        
if __name__ == \'__main__\':
    # 定义我们的 transforms (1)
    transformations = transforms.Compose([transforms.CenterCrop(100), transforms.ToTensor()])
    # 创建 dataset
    custom_dataset = MyCustomDataset(..., transformations)

三. 换一种方法使用 Torchvision Transforms

有些人不喜欢把 transform 操作写在 Dataset 外面(上面代码里的注释 1),所以还有一种写法:

from torch.utils.data.dataset import Dataset
from torchvision import transforms
 
class MyCustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        ...
        # (2) 一种方法是单独定义 transform
        self.center_crop = transforms.CenterCrop(100)
        self.to_tensor = transforms.ToTensor()
        
        # (3) 或者写成下面这样 
        self.transformations = \
            transforms.Compose([transforms.CenterCrop(100),
                                transforms.ToTensor()])
        
    def __getitem__(self, index):
        # stuff
        ...
        data = #一些读取的数据
        
        # 当第二次调用 transform 时,调用的是 __call__()
        data = self.center_crop(data)  # (2)
        data = self.to_tensor(data)  # (2)
        
        # 或者写成下面这样
        data = self.trasnformations(data)  # (3)
        
        # 注意 (2) 和 (3) 中只需要实现一种
        return (img, label)
 
    def __len__(self):
        return count
        
if __name__ == \'__main__\':
    custom_dataset = MyCustomDataset(...)

四. 结合 Pandas 读取 csv 文件

假如说我们想从一个 csv 文件中用 Pandas 读取数据。一个 csv 示例如下:

File NameLabelExtra Operation
tr_0.png5TRUE
tr_1.png0FALSE
tr_2.png4FALSE

如果我们需要在自定义数据集里从这个 csv 文件读取文件名,可以这样做:

class CustomDatasetFromImages(Dataset):
    def __init__(self, csv_path):
        """
        Args:
            csv_path (string): csv 文件路径
            img_path (string): 图像文件所在路径
            transform: transform 操作
        """
        # Transforms
        self.to_tensor = transforms.ToTensor()
        # 读取 csv 文件
        self.data_info = pd.read_csv(csv_path, header=None)
        # 文件第一列包含图像文件的名称
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # 第二列是图像的 label
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])
        # 第三列是决定是否进行额外操作
        self.operation_arr = np.asarray(self.data_info.iloc[:, 2])
        # 计算 length
        self.data_len = len(self.data_info.index)
 
    def __getitem__(self, index):
        # 从 pandas df 中得到文件名
        single_image_name = self.image_arr[index]
        # 读取图像文件
        img_as_img = Image.open(single_image_name)
 
        # 检查需不需要额外操作
        some_operation = self.operation_arr[index]
        # 如果需要额外操作
        if some_operation:
            # ...
            # ...
            pass
        # 把图像转换成 tensor
        img_as_tensor = self.to_tensor(img_as_img)
 
        # 得到图像的 label
        single_image_label = self.label_arr[index]
 
        return (img_as_tensor, single_image_label)
 
    def __len__(self):
        return self.data_len
 
if __name__ == "__main__":
    custom_mnist_from_images =  \
        CustomDatasetFromImages(\'../data/mnist_labels.csv\')

五. 结合 Pandas 使用 __getitem__()

另一种情况是 csv 文件中保存了我们需要的图像文件的像素值(比如有些 MNIST 教程就是这样的)。我们需要改动一下 __getitem__() 函数。

Labelpixel_1pixel_2...
15099...
021223...
94455...

代码如下:

class CustomDatasetFromCSV(Dataset):
    def __init__(self, csv_path, height, width, transforms=None):
        """
        Args:
            csv_path (string): csv 文件路径
            height (int): 图像高度
            width (int): 图像宽度
            transform: transform 操作
        """
        self.data = pd.read_csv(csv_path)
        self.labels = np.asarray(self.data.iloc[:, 0])
        self.height = height
        self.width = width
        self.transforms = transform
 
    def __getitem__(self, index):
        single_image_label = self.labels[index]
        # 读取所有像素值,并将 1D array ([784]) reshape 成为 2D array ([28,28]) 
        img_as_np = np.asarray(self.data.iloc[index][1:]).reshape(28,28).astype(\'uint8\')
    # 把 numpy array 格式的图像转换成灰度 PIL image
        img_as_img = Image.fromarray(img_as_np)
        img_as_img = img_as_img.convert(\'L\')
        # 将图像转换成 tensor
        if self.transforms is not None:
            img_as_tensor = self.transforms(img_as_img)
        # 返回图像及其 label
        return (img_as_tensor, single_image_label)
 
    def __len__(self):
        return len(self.data.index)
        
 
if __name__ == "__main__":
    transformations = transforms.Compose([transforms.ToTensor()])
    custom_mnist_from_csv = \
        CustomDatasetFromCSV(\'../data/mnist_in_csv.csv\', 28, 28, transformations)

六. 使用 Dataloader 读取自定义数据集

PyTorch 中的 Dataloader 只是调用 __getitem__() 方法并组合成 batch,我们可以这样调用:

...
if __name__ == "__main__":
    # 定义 transforms
    transformations = transforms.Compose([transforms.ToTensor()])
    # 自定义数据集
    custom_mnist_from_csv = \
        CustomDatasetFromCSV(\'../data/mnist_in_csv.csv\',
                             28, 28,
                             transformations)
    # 定义 data loader
    mn_dataset_loader = torch.utils.data.DataLoader(dataset=custom_mnist_from_csv,
                                                    batch_size=10,
                                                    shuffle=False)
    
    for images, labels in mn_dataset_loader:
        # 将数据传给网络模型

需要注意的是使用多卡训练时,PyTorch dataloader 会将每个 batch 平均分配到各个 GPU。所以如果 batch size 过小,可能发挥不了多卡的效果。

最后修改:2019 年 07 月 23 日 06 : 42 PM
如果觉得我的文章对你有用,请随意赞赏

发表评论