Paddle:加载自定义数据集

Paddle提供两种方式来加载数据集:

1:加载内置数据

2:加载自定义数据

1:加载内置数据

飞桨框架在 paddle.vision.datasets 和 paddle.text 目录下内置了一些经典数据集可直接调用,通过以下代码可查看飞桨框架中的内置数据集。

import paddle
print('计算机视觉(CV)相关数据集:', paddle.vision.datasets.__all__)
print('自然语言处理(NLP)相关数据集:', paddle.text.__all__)

具体使用可参考官方文档:数据集定义与加载-使用文档-PaddlePaddle深度学习平台

2:加载自定义数据 

在实际的场景中,一般需要使用自有的数据来定义数据集,这时可以通过 paddle.io.Dataset 基类来实现自定义数据集。

可构建一个子类继承自 paddle.io.Dataset ,并且实现下面的三个函数:

(是不是很眼熟,不能说和Pytorch完全相同,只能说是一模一样。目的是降低迁移学习的难度)

1、__init__:完成数据集初始化操作,将磁盘中的样本文件路径和对应标签映射到一个列表中。2、__getitem__:定义指定索引(index)时如何获取样本数据,最终返回对应 index 的单条数据(样本数据、对应的标签)。3、__len__:返回数据集的样本总数。

直接上示例代码:

import os
import cv2
import numpy as np
from paddle.io import Dataset
from paddle.vision import transforms as T'''
paddle-API文档:https://www.paddlepaddle.org.cn/documentation/docs/zh/api/index_cn.html
'''class ListDataset(Dataset):def __init__(self, list_file, mode='train'):if mode == 'train':print("Loading train data ......")else:print("Loading test data ......")# modeself.mode = mode# load listself.data_list = []with open(list_file, "r") as f:self.data_list = f.readlines()# define img transformself.transform_train = T.Compose([T.Resize((128, 64), interpolation='nearest'),T.ContrastTransform(0.2),T.BrightnessTransform(0.2),T.RandomHorizontalFlip(0.5),T.RandomRotation(15),T.Transpose(),T.Normalize(mean=[127.5, 127.5, 127.5],  data_format='CHW', std=[127.5, 127.5, 127.5],  to_rgb=True)])self.transfrom_eval = T.Compose([T.Resize((128, 64), interpolation='nearest'),T.Transpose(),T.Normalize(mean=[127.5, 127.5, 127.5],  data_format='CHW', std=[127.5, 127.5, 127.5],  to_rgb=True)])def __getitem__(self, index):line_info = self.data_list[index].strip().split(' ')img_bgr = cv2.imread(line_info[0])img_label = [int(i) for i in line_info[1:]]if self.mode == 'train':img = self.transform_train(img_bgr)else:img = self.transfrom_eval(img_bgr)return img, img_labeldef __len__(self):return len(self.data_list)

对于遇到不清楚的API:直接翻官方文档。如果还不清楚,那就翻对应的pytorch文档。两个基本是相同的。

paddle-API文档:https://www.paddlepaddle.org.cn/documentation/docs/zh/api/index_cn.html