Dataset是一个包装类，用来将数据包装为Dataset类，然后传入DataLoader中；我们再使用DataLoader这个类来更加快捷地对数据进行操作。

DataLoader是一个比较重要的类，它为我们提供的常用参数有：batch_size（每个batch的大小）、shuffle（是否打乱数据顺序）、num_workers（加载数据时使用几个子进程）。
当我们继承了 Dataset 类之后，需要重写 `__len__` 方法，该方法返回数据集的大小；以及 `__getitem__` 方法，该方法支持从 0 到 len(self)-1 的整数索引，返回对应的样本。

import os
import warnings

import pandas as pd
from torch.utils.data import Dataset


def _sliding_windows(values, windows, windows_move):
    """Return every full-length window of ``windows`` consecutive rows of *values*.

    Windows start at row 0 and advance by ``windows_move`` rows; a trailing
    partial window is never produced. Returns an empty list when *values*
    has fewer than ``windows`` rows.
    """
    last_start = len(values) - windows
    return [
        values[start:start + windows]
        for start in range(0, last_start + 1, windows_move)
    ]


class PTB(Dataset):
    """Battery dataset built from sliding windows over the CSV files in a directory.

    Every regular file in ``data_dir`` is read with ``pandas.read_csv``
    (GBK encoding). Each file with at least ``windows`` rows is cut into
    windows of ``windows`` consecutive rows, advancing by ``windows_move``
    rows; each window is one sample (a slice of the file's
    ``DataFrame.to_numpy()`` array, so it has ``windows`` rows).
    """

    def __init__(self, data_dir, split, battery_dataset=None, *,
                 windows=32, windows_move=1, **kwargs):
        """
        Args:
            data_dir (str): Directory containing the CSV files.
            split (str): Split name. Currently unused; kept for interface
                compatibility.
            battery_dataset (list | None): Optional list to collect the
                windows into (windows are appended to it in place). A fresh
                list is used when omitted.
            windows (int): Number of rows per sample window (default 32).
            windows_move (int): Step, in rows, between the starts of
                consecutive windows (default 1).
            **kwargs: Accepted and ignored, for interface compatibility.

        Raises:
            ValueError: If ``windows`` or ``windows_move`` is less than 1.
        """
        super().__init__()
        if windows < 1 or windows_move < 1:
            raise ValueError(
                f"windows and windows_move must be >= 1, "
                f"got {windows} and {windows_move}"
            )
        self.data_dir = data_dir
        # A mutable default ([]) would be shared by every instance, so each
        # new dataset would also contain the windows of all earlier ones.
        self.battery_dataset = [] if battery_dataset is None else battery_dataset

        # Sorted so sample indices are reproducible; listdir order is arbitrary.
        for file in sorted(os.listdir(self.data_dir)):
            path = os.path.join(self.data_dir, file)
            if not os.path.isfile(path):
                continue  # read_csv would fail on sub-directories
            try:
                df = pd.read_csv(path, encoding="gbk")
            except RuntimeError as err:
                # Still best-effort, but report the skipped file instead of
                # silently dropping it and every file listed after it.
                warnings.warn(f"skipping {path!r}: {err}", RuntimeWarning,
                              stacklevel=2)
                continue
            frame = df.to_numpy()
            if len(frame) >= windows:
                # Kept for backward compatibility: values of the last file
                # that produced at least one window.
                self.battery_frame = frame
                self.battery_dataset.extend(
                    _sliding_windows(frame, windows, windows_move)
                )

    def __len__(self):
        """Return the number of sample windows."""
        return len(self.battery_dataset)

    def __getitem__(self, idx):
        """Return the sample window at position *idx*."""
        return self.battery_dataset[idx]