• Redefining TensorDataset in PyTorch to support transforms


    import torch


    class TensorsDataset(torch.utils.data.Dataset):

        '''
        A simple dataset that loads the tensors passed in as input. It behaves like
        torch.utils.data.TensorDataset, except that transformations can be applied
        to the data and target tensors. The target tensor may also be None, in which
        case only the data is returned.
        '''

        def __init__(self, data_tensor, target_tensor=None, transforms=None, target_transforms=None):
            # Data and targets must contain the same number of samples along dim 0
            if target_tensor is not None:
                assert data_tensor.size(0) == target_tensor.size(0), \
                    "data_tensor and target_tensor must have the same first dimension"
            self.data_tensor = data_tensor
            self.target_tensor = target_tensor

            # Normalize the transform arguments to lists so __getitem__ can iterate over them
            if transforms is None:
                transforms = []
            if target_transforms is None:
                target_transforms = []

            if not isinstance(transforms, (list, tuple)):
                transforms = [transforms]
            if not isinstance(target_transforms, (list, tuple)):
                target_transforms = [target_transforms]

            self.transforms = list(transforms)
            self.target_transforms = list(target_transforms)

        def __getitem__(self, index):
            # Apply the data transforms in the order they were given
            data_tensor = self.data_tensor[index]
            for transform in self.transforms:
                data_tensor = transform(data_tensor)

            # Without targets, return only the transformed sample
            if self.target_tensor is None:
                return data_tensor

            # Apply the target transforms in the order they were given
            target_tensor = self.target_tensor[index]
            for transform in self.target_transforms:
                target_tensor = transform(target_tensor)

            return data_tensor, target_tensor

        def __len__(self):
            # Number of samples along the first dimension
            return self.data_tensor.size(0)
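The same idea can also be sketched as a thin wrapper around the built-in `torch.utils.data.TensorDataset`, which keeps the indexing logic in PyTorch and only adds the transform step. This is a minimal sketch, not part of the original post: the class name `TransformedTensorDataset`, the single-callable `transform`/`target_transform` arguments and the toy lambdas are all hypothetical choices for illustration. It also shows that the transformed samples are batched normally by `DataLoader`.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


class TransformedTensorDataset(Dataset):
    """Hypothetical wrapper: applies transforms to the samples of a TensorDataset."""

    def __init__(self, base, transform=None, target_transform=None):
        self.base = base  # assumed to be a TensorDataset yielding (data, target) pairs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        # TensorDataset returns a tuple with one entry per wrapped tensor
        data, target = self.base[index]
        if self.transform is not None:
            data = self.transform(data)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return data, target

    def __len__(self):
        return len(self.base)


# Toy data: 8 samples with 3 features each, plus integer labels 0..7
x = torch.arange(24, dtype=torch.float32).reshape(8, 3)
y = torch.arange(8)

ds = TransformedTensorDataset(
    TensorDataset(x, y),
    transform=lambda t: t * 2.0,       # hypothetical transform: scale every feature
    target_transform=lambda t: t % 2,  # hypothetical transform: map labels to parity
)

# DataLoader stacks the transformed samples into batches as usual
loader = DataLoader(ds, batch_size=4, shuffle=False)
first_data, first_target = next(iter(loader))
print(first_data.shape, first_target.tolist())
```

Wrapping rather than subclassing means any existing `TensorDataset` can gain transforms without copying its tensors; the trade-off is one extra indirection per sample.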
    
  • Original article: https://www.cnblogs.com/marsggbo/p/10459235.html