• An H5Dataset interface for PyTorch (modeled on the TensorDataset interface)


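    PyTorch's TensorDataset interface (the two-tensor form used in older PyTorch releases; current releases take *tensors instead)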

    from torch.utils.data import Dataset


    class TensorDataset(Dataset):
        """Dataset wrapping data and target tensors.

        Each sample will be retrieved by indexing both tensors along the first
        dimension.

        Arguments:
            data_tensor (Tensor): contains sample data.
            target_tensor (Tensor): contains sample targets (labels).
        """

        def __init__(self, data_tensor, target_tensor):
            assert data_tensor.size(0) == target_tensor.size(0)
            self.data_tensor = data_tensor
            self.target_tensor = target_tensor

        def __getitem__(self, index):
            return self.data_tensor[index], self.target_tensor[index]

        def __len__(self):
            return self.data_tensor.size(0)
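TensorDataset needs nothing from its tensors beyond indexing along the first axis, and a DataLoader needs nothing from the dataset beyond `__len__` and integer `__getitem__` (the map-style Dataset protocol). That is why a class with the same shape can wrap HDF5 data. The pure-Python sketch below is a simplification under that assumption: `ToyPairs` and `iterate_batches` are hypothetical names, and the loop ignores shuffling, worker processes and collation into tensors.

```python
# Minimal sketch of the map-style dataset protocol a loader relies on:
# any object with __len__ and __getitem__ can be walked batch by batch.
# ToyPairs and iterate_batches are hypothetical names for illustration only.


class ToyPairs:
    """(x, y) pairs stored in two parallel lists, indexed along the first axis."""

    def __init__(self, data, target):
        assert len(data) == len(target)
        self.data = data
        self.target = target

    def __getitem__(self, index):
        return self.data[index], self.target[index]

    def __len__(self):
        return len(self.data)


def iterate_batches(dataset, batch_size):
    """Yield lists of samples in index order (no shuffling; last batch may be short)."""
    batch = []
    for i in range(len(dataset)):
        batch.append(dataset[i])
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


ds = ToyPairs([10, 20, 30, 40, 50], ["a", "b", "c", "d", "e"])
batches = list(iterate_batches(ds, batch_size=2))
print(batches)  # → [[(10, 'a'), (20, 'b')], [(30, 'c'), (40, 'd')], [(50, 'e')]]
```

The real DataLoader adds a sampler (for `shuffle=True`), worker processes and a collate step that stacks samples into tensors, but it reaches the data through the same two methods.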

    The H5Dataset interface for HDF5

    from torch.utils.data import Dataset


    class H5Dataset(Dataset):
        """Dataset wrapping data and target HDF5 datasets (e.g. h5py.Dataset objects).

        Each sample is retrieved by indexing both datasets along the first
        dimension, so only the requested sample is read from disk.

        Arguments:
            data_tensor (h5py.Dataset): contains sample data.
            target_tensor (h5py.Dataset): contains sample targets (labels).
        """

        def __init__(self, data_tensor, target_tensor):
            # h5py datasets have .shape; their .size is an element count, not Tensor.size(dim)
            assert data_tensor.shape[0] == target_tensor.shape[0]
            self.data_tensor = data_tensor
            self.target_tensor = target_tensor

        def __getitem__(self, index):
            # Returns NumPy arrays; DataLoader's default collate turns them into tensors
            return self.data_tensor[index], self.target_tensor[index]

        def __len__(self):
            return self.data_tensor.shape[0]
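H5Dataset works because an h5py Dataset supports the two operations TensorDataset used: the sample count via `.shape[0]`, and integer indexing that reads a single sample into a NumPy array. The sketch below checks both on a small temporary file laid out like the one `load_data()` reads (a `train` group holding `RGB` and `MS`); the file name and array shapes are hypothetical.

```python
# Sketch: the two properties of an h5py Dataset that H5Dataset relies on,
# .shape[0] and integer indexing that reads one sample into a NumPy array.
# The file name, group names and array shapes below are hypothetical.
import os
import tempfile

import h5py
import numpy as np

path = os.path.join(tempfile.mkdtemp(), "toy.h5")

# Write a tiny file with a layout similar to the one load_data() reads:
# a "train" group holding "RGB" (inputs) and "MS" (targets).
with h5py.File(path, "w") as f:
    grp = f.create_group("train")
    grp.create_dataset("RGB", data=np.zeros((4, 3, 8, 8), dtype=np.float32))
    grp.create_dataset("MS", data=np.ones((4, 31, 8, 8), dtype=np.float32))

with h5py.File(path, "r") as f:
    rgb = f["train"]["RGB"]   # h5py.Dataset: the data stays in the file
    n_samples = rgb.shape[0]  # what H5Dataset.__len__ returns
    sample = rgb[2]           # reads just this sample into memory
    print(n_samples, type(sample).__name__, sample.shape)  # → 4 ndarray (3, 8, 8)
```

Because each `__getitem__` call reads one sample from the file, the whole dataset never has to fit in memory, which is the main reason to wrap h5py datasets instead of converting them to tensors up front.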

    The corresponding DataLoader setup (just replace TensorDataset with H5Dataset)

    import h5py
    from torch.utils.data import DataLoader


    def load_data(opt):
        # opt: parsed command-line options from the training script
        # (opt.threads, opt.batchSize, opt.testBatchSize)
        # The file stays open because H5Dataset reads from it lazily during training.
        f = h5py.File("./dataset/CAVE.h5", 'r')
        MS_train = f['train']["MS"]
        RGB_train = f['train']["RGB"]
        MS_test = f['test']["MS"]
        RGB_test = f['test']["RGB"]
        # Inputs are RGB images, targets are multispectral (MS) images
        train_set = H5Dataset(RGB_train, MS_train)
        test_set = H5Dataset(RGB_test, MS_test)
        training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads,
                                          batch_size=opt.batchSize, pin_memory=True,
                                          shuffle=True)
        testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads,
                                         batch_size=opt.testBatchSize, pin_memory=True,
                                         shuffle=False)
        return training_data_loader, testing_data_loader
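One caveat with passing open h5py datasets into H5Dataset: when `num_workers > 0`, each worker needs its own copy of the dataset object. h5py objects cannot be pickled, so this fails under the spawn start method (the default on Windows and macOS), and sharing one open file handle across forked workers is also commonly reported to be unreliable. A frequently used workaround, sketched below as an assumption rather than the author's method, stores only the file path and opens the file lazily in each process. `LazyH5Dataset` and its parameters are hypothetical names.

```python
# Sketch of a worker-safe variant (an assumption, not the original author's code):
# keep only the file path and open the HDF5 file lazily inside each process,
# so no open h5py object has to be pickled or shared between processes.
import os

import h5py
import torch
from torch.utils.data import Dataset


class LazyH5Dataset(Dataset):
    """Hypothetical name. Reads (input, target) pairs from one HDF5 group."""

    def __init__(self, path, group, input_key="RGB", target_key="MS"):
        self.path = path
        self.group = group
        self.input_key = input_key
        self.target_key = target_key
        self._file = None  # opened on first access in each process
        self._pid = None   # process that opened self._file
        # Read the sample count once, then close the file again.
        with h5py.File(path, "r") as f:
            n_in = f[group][input_key].shape[0]
            n_tgt = f[group][target_key].shape[0]
        assert n_in == n_tgt
        self._len = n_in

    def _open(self):
        # (Re)open when first used in this process, e.g. inside a forked worker.
        if self._file is None or self._pid != os.getpid():
            self._file = h5py.File(self.path, "r")
            self._pid = os.getpid()
        return self._file[self.group]

    def __getstate__(self):
        # Drop the open handle when the dataset is pickled for a spawned worker.
        state = self.__dict__.copy()
        state["_file"] = None
        state["_pid"] = None
        return state

    def __getitem__(self, index):
        grp = self._open()
        x = torch.from_numpy(grp[self.input_key][index])
        y = torch.from_numpy(grp[self.target_key][index])
        return x, y

    def __len__(self):
        return self._len
```

With this variant, `load_data` would build e.g. `LazyH5Dataset("./dataset/CAVE.h5", "train")` instead of passing the h5py datasets directly; the DataLoader arguments stay the same.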
  • Original post: https://www.cnblogs.com/nwpuxuezha/p/7846751.html