• Redefining PyTorch's TensorDataset to support transforms


    import torch


    class TensorsDataset(torch.utils.data.Dataset):
        '''
        A simple dataset that serves the tensors passed to it. It behaves like
        torch.utils.data.TensorDataset, except that you can attach transforms to the data and target tensors.
        The target tensor may be None, in which case only the data is returned.
        '''
    
        def __init__(self, data_tensor, target_tensor=None, transforms=None, target_transforms=None):
            if target_tensor is not None:
                assert data_tensor.size(0) == target_tensor.size(0)
            self.data_tensor = data_tensor
            self.target_tensor = target_tensor
    
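            # No transforms given: fall back to an empty list so the loops in __getitem__ are no-ops.
            if transforms is None:
                transforms = []
            if target_transforms is None:
                target_transforms = []
    
            # A single callable is wrapped in a list so both forms are handled the same way.
            if not isinstance(transforms, list):
                transforms = [transforms]
            if not isinstance(target_transforms, list):
                target_transforms = [target_transforms]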
    
            self.transforms = transforms
            self.target_transforms = target_transforms
    
        def __getitem__(self, index):
    
            data_tensor = self.data_tensor[index]
            for transform in self.transforms:
                data_tensor = transform(data_tensor)
    
            if self.target_tensor is None:
                return data_tensor
    
            target_tensor = self.target_tensor[index]
            for transform in self.target_transforms:
                target_tensor = transform(target_tensor)
    
            return data_tensor, target_tensor
    
        def __len__(self):
            return self.data_tensor.size(0)
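
A minimal usage sketch follows, showing how the class might be combined with `torch.utils.data.DataLoader`. The class body is a condensed copy of the one above, repeated only so the snippet runs on its own; the `_as_list` helper and the transforms `scale_to_unit` and `add_one` are hypothetical names made up for illustration and are not part of the original post.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TensorsDataset(Dataset):
    """Condensed copy of the class above, repeated so this snippet is self-contained."""

    def __init__(self, data_tensor, target_tensor=None, transforms=None, target_transforms=None):
        if target_tensor is not None:
            assert data_tensor.size(0) == target_tensor.size(0)
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor
        self.transforms = self._as_list(transforms)
        self.target_transforms = self._as_list(target_transforms)

    @staticmethod
    def _as_list(value):
        # Hypothetical helper: None -> [], single callable -> [callable], list -> unchanged.
        if value is None:
            return []
        return value if isinstance(value, list) else [value]

    def __getitem__(self, index):
        data = self.data_tensor[index]
        for transform in self.transforms:
            data = transform(data)
        if self.target_tensor is None:
            return data
        target = self.target_tensor[index]
        for transform in self.target_transforms:
            target = transform(target)
        return data, target

    def __len__(self):
        return self.data_tensor.size(0)


def scale_to_unit(x):
    """Hypothetical data transform: map values from [0, 255] to [0, 1]."""
    return x / 255.0


def add_one(y):
    """Hypothetical target transform: shift every label by one."""
    return y + 1


data = torch.full((6, 3), 255.0)  # 6 samples with 3 features each
targets = torch.arange(6)         # integer labels 0..5

# A single callable and a list of callables are both accepted.
dataset = TensorsDataset(data, targets, transforms=scale_to_unit, target_transforms=[add_one])
loader = DataLoader(dataset, batch_size=4, shuffle=False)

batches = [(x, y) for x, y in loader]
first_x, first_y = batches[0]

# Without a target tensor, indexing yields only the transformed data.
unlabeled = TensorsDataset(data, transforms=scale_to_unit)
sample = unlabeled[0]
```

Because the transforms run inside `__getitem__`, they are applied per sample each time an item is fetched, before the `DataLoader` collates samples into a batch.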
    
  • Original article: https://www.cnblogs.com/marsggbo/p/10459235.html