• A custom MXNet dataloader for loading your own data


    I wrote a blog post a while back about loading your own data in PyTorch, but having recently picked up MXNet, I found very few tutorials on this topic.

    If you want to load your own data, studying MXNet's MNIST example is basically enough to work out how.

    Compare how PyTorch and MXNet load data. (The original post shows the two dataloader implementations side by side as screenshots: PyTorch on the left, MXNet on the right.)
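
    Since the screenshots don't reproduce here, here is a rough sketch of the two built-in entry points; the path is a placeholder, not from the original post:

        # PyTorch's entry point: point a folder dataset at a root, hand it to a DataLoader
        from torchvision.datasets import ImageFolder
        from torch.utils.data import DataLoader
        dataset = ImageFolder('/path/to/images')          # expects a folder-per-class layout
        loader = DataLoader(dataset, batch_size=16, shuffle=True)

        # MXNet/Gluon's entry point: same shape, different names
        from mxnet import gluon
        dataset = gluon.data.vision.ImageFolderDataset('/path/to/images')
        loader = gluon.data.DataLoader(dataset, batch_size=16, shuffle=True)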

    In fact, the data layers of MXNet and PyTorch are quite similar. Why do I say that? Look at the code above: in both frameworks you take in a path to the images and produce an iterable object that a loader can consume. So whether you are in PyTorch or MXNet, to feed in your own data you only need to subclass and adapt ImageFolderDataset for the data-specific part, which amounts to inheriting from the dataset.Dataset class directly.

    PyTorch uses a helper called find_class, while MXNet defines a _list_images method inside the class. Honestly, the particular helper doesn't seem to matter: all that is needed is for __getitem__ to index into a list of tuples, where each tuple holds a file name and the label that goes with it.

    You only need to inherit from this one class; a minimal sketch follows.
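
    As a minimal sketch (the class and variable names here are illustrative, not from the original code), the whole contract a custom dataset has to satisfy is these three methods:

        from mxnet import image
        from mxnet.gluon.data import dataset

        class MinimalFolderDataset(dataset.Dataset):
            def __init__(self, items, transform=None):
                # items is a prebuilt list of (file_path, label) tuples
                self._items = items
                self._transform = transform

            def __getitem__(self, idx):
                img = image.imread(self._items[idx][0])   # load one image on demand
                label = self._items[idx][1]
                if self._transform is not None:
                    return self._transform(img, label)
                return img, label

            def __len__(self):
                return len(self._items)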

    Straight to the code.

    This is a snippet from a Kaggle competition I entered. It doesn't converge, but please don't mind that detail.

    # -*- coding:utf-8 -*-
    import datetime
    import os
    import random

    import mxnet as mx
    from mxnet import autograd, gluon, image, init, nd
    from mxnet.gluon import nn
    from mxnet.gluon.data import dataset

    import utils  # helper module from the gluon tutorials (try_gpu, etc.)


    class MyImageFolderDataset(dataset.Dataset):
        def __init__(self, root, label, flag=1, transform=None):
            self._root = os.path.expanduser(root)
            self._flag = flag
            self._label = label
            self._transform = transform
            self._exts = ['.jpg', '.jpeg', '.png']
            self._list_images(self._root, self._label)

        def _list_images(self, root, label):  # label is a list of "name value" lines
            self.synsets = []
            self.synsets.append(root)
            self.items = []
            c = 0
            for line in label:
                cls = line.split()
                fn = cls.pop(0)
                fn = fn + '.jpg'
                if os.path.isfile(os.path.join(root, fn)):
                    self.items.append((os.path.join(root, fn), float(cls[0])))
                else:
                    print('missing file:', os.path.join(root, fn))
                c = c + 1
            print('the total number of images is', c)

        def __getitem__(self, idx):
            img = image.imread(self.items[idx][0], self._flag)
            label = self.items[idx][1]
            if self._transform is not None:
                return self._transform(img, label)
            return img, label

        def __len__(self):
            return len(self.items)


    def _get_batch(batch, ctx):
        """Return data and label on ctx so you can loop over batches directly."""
        if isinstance(batch, mx.io.DataBatch):
            data = batch.data[0]
            label = batch.label[0]
        else:
            data, label = batch
        return (gluon.utils.split_and_load(data, ctx),
                gluon.utils.split_and_load(label, ctx),
                data.shape[0])


    def transform_train(data, label):
        im = image.imresize(data.astype('float32') / 255, 256, 256)
        auglist = image.CreateAugmenter(data_shape=(3, 256, 256), resize=0,
                            rand_crop=False, rand_resize=False, rand_mirror=True,
                            mean=None, std=None,
                            brightness=0, contrast=0,
                            saturation=0, hue=0,
                            pca_noise=0, rand_gray=0, inter_method=2)
        for aug in auglist:
            im = aug(im)
        # change the layout from height x width x channel to channel x height x width
        im = nd.transpose(im, (2, 0, 1))
        return (im, nd.array([label]).asscalar().astype('float32'))


    def transform_test(data, label):
        im = image.imresize(data.astype('float32') / 255, 256, 256)
        im = nd.transpose(im, (2, 0, 1))  # this transpose was missing in an earlier version
        return (im, nd.array([label]).asscalar().astype('float32'))


    batch_size = 16
    root = '/home/ying/data2/shiyongjie/landmark_recognition/data/image'


    def random_choose_data(label_path):
        f = open(label_path)
        lines = f.readlines()  # fixed typo: was f.readlins()
        random.shuffle(lines)
        total_number = len(lines)
        train_number = total_number * 7 // 10  # integer index for a 7:3 split
        train_list = lines[:train_number]
        test_list = lines[train_number:]
        return (train_list, test_list)


    label_path = '/home/ying/data2/shiyongjie/landmark_recognition/data/train.txt'
    train_list, test_list = random_choose_data(label_path)
    loader = gluon.data.DataLoader
    train_ds = MyImageFolderDataset(os.path.join(root, 'image'), train_list, flag=1, transform=transform_train)
    test_ds = MyImageFolderDataset(os.path.join(root, 'Testing'), test_list, flag=1, transform=transform_test)
    train_data = loader(train_ds, batch_size, shuffle=True, last_batch='keep')
    test_data = loader(test_ds, batch_size, shuffle=False, last_batch='keep')
    softmax_cross_entropy = gluon.loss.L2Loss()  # despite the name, this is an L2 (regression) loss

    net = nn.Sequential()
    with net.name_scope():
        net.add(
            # stage 1
            nn.Conv2D(channels=96, kernel_size=11,
                      strides=4, activation='relu'),
            nn.MaxPool2D(pool_size=3, strides=2),
            # stage 2
            nn.Conv2D(channels=256, kernel_size=5,
                      padding=2, activation='relu'),
            nn.MaxPool2D(pool_size=3, strides=2),
            # stage 3
            nn.Conv2D(channels=384, kernel_size=3,
                      padding=1, activation='relu'),
            nn.Conv2D(channels=384, kernel_size=3,
                      padding=1, activation='relu'),
            nn.Conv2D(channels=256, kernel_size=3,
                      padding=1, activation='relu'),
            nn.MaxPool2D(pool_size=3, strides=2),
            # stage 4
            nn.Flatten(),
            nn.Dense(4096, activation="relu"),
            nn.Dropout(.5),
            # stage 5
            nn.Dense(4096, activation="relu"),
            nn.Dropout(.5),
            # stage 6
            nn.Dense(14950)  # output layer with 14950 units
        )

    ctx = utils.try_gpu()
    net.initialize(ctx=ctx, init=init.Xavier())

    mse_loss = gluon.loss.L2Loss()


    def train(net, train_data, valid_data, num_epochs, lr, wd, ctx, lr_period, lr_decay):
        """Train a network (adapted from utils.train; valid_data is accepted but unused here)."""
        print("Start training on ", ctx)
        trainer = gluon.Trainer(net.collect_params(), 'sgd',
                                {'learning_rate': lr, 'momentum': 0.9, 'wd': wd})
        prev_time = datetime.datetime.now()
        for epoch in range(num_epochs):
            train_loss = 0.0
            if epoch > 0 and epoch % lr_period == 0:
                trainer.set_learning_rate(trainer.learning_rate * lr_decay)
            for data, label in train_data:
                label = label.as_in_context(ctx)
                with autograd.record():
                    output = net(data.as_in_context(ctx))
                    loss = mse_loss(output, label)
                loss.backward()
                trainer.step(batch_size)  # Trainer needs the batch size to normalize
                # the gradient by 1/batch_size
                train_loss += nd.mean(loss).asscalar()
                print(nd.mean(loss).asscalar())
            cur_time = datetime.datetime.now()
            h, remainder = divmod((cur_time - prev_time).seconds, 3600)
            m, s = divmod(remainder, 60)
            time_str = "Time %02d:%02d:%02d" % (h, m, s)
            epoch_str = ('Epoch %d. Train loss: %f, ' % (epoch, train_loss / len(train_data)))
            prev_time = cur_time
            print(epoch_str + time_str + ', lr ' + str(trainer.learning_rate))
        net.collect_params().save('./model/alexnet.params')


    ctx = utils.try_gpu()
    num_epochs = 100
    learning_rate = 0.001
    weight_decay = 5e-4
    lr_period = 10
    lr_decay = 0.1

    train(net, train_data, test_data, num_epochs, learning_rate,
          weight_decay, ctx, lr_period, lr_decay)

    Look at this part:

    class MyImageFolderDataset(dataset.Dataset):
        def __init__(self, root, label, flag=1, transform=None):
            self._root = os.path.expanduser(root)
            self._flag = flag
            self._label = label
            self._transform = transform
            self._exts = ['.jpg', '.jpeg', '.png']
            self._list_images(self._root, self._label)

        def _list_images(self, root, label):  # label is a list of "name value" lines
            self.synsets = []
            self.synsets.append(root)
            self.items = []
            c = 0
            for line in label:
                cls = line.split()
                fn = cls.pop(0)
                fn = fn + '.jpg'
                if os.path.isfile(os.path.join(root, fn)):
                    self.items.append((os.path.join(root, fn), float(cls[0])))
                else:
                    print('missing file:', os.path.join(root, fn))
                c = c + 1
            print('the total number of images is', c)

        def __getitem__(self, idx):
            img = image.imread(self.items[idx][0], self._flag)
            label = self.items[idx][1]
            if self._transform is not None:
                return self._transform(img, label)
            return img, label

        def __len__(self):
            return len(self.items)

    batch_size = 16
    root = '/home/ying/data2/shiyongjie/landmark_recognition/data/image'

    def random_choose_data(label_path):
        f = open(label_path)
        lines = f.readlines()  # fixed typo: was f.readlins()
        random.shuffle(lines)
        total_number = len(lines)
        train_number = total_number * 7 // 10  # integer index for a 7:3 split
        train_list = lines[:train_number]
        test_list = lines[train_number:]
        return (train_list, test_list)

    label_path = '/home/ying/data2/shiyongjie/landmark_recognition/data/train.txt'
    train_list, test_list = random_choose_data(label_path)

    loader = gluon.data.DataLoader
    train_ds = MyImageFolderDataset(os.path.join(root, 'image'), train_list, flag=1, transform=transform_train)
    test_ds = MyImageFolderDataset(os.path.join(root, 'Testing'), test_list, flag=1, transform=transform_test)
    train_data = loader(train_ds, batch_size, shuffle=True, last_batch='keep')
    test_data = loader(test_ds, batch_size, shuffle=False, last_batch='keep')

    MyImageFolderDataset is a subclass of dataset.Dataset. The main point is to overload the indexing operator __getitem__ so that it returns an image together with its label; the _list_images method before it just has to be able to build the items list. (Operator overloading itself is a topic I'll save for a later post.)
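
    To see what that overloading buys you, here is a quick usage sketch; the label line and path are made up, and it assumes the corresponding .jpg file actually exists under the root:

        # each label line is "name value"; _list_images turns it into ('/path/to/images/img_001.jpg', 3.5)
        ds = MyImageFolderDataset('/path/to/images', ['img_001 3.5'], flag=1)
        img, label = ds[0]   # ds[0] calls ds.__getitem__(0)
        print(len(ds))       # len(ds) calls ds.__len__()
        # DataLoader relies only on these two methods to shuffle and batch:
        for batch_img, batch_label in gluon.data.DataLoader(ds, batch_size=4, last_batch='keep'):
            print(batch_img.shape, batch_label.shape)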

    You could say this is very close to PyTorch. Even Mu Li said in his lectures that a lot of MXNet's design borrowed from PyTorch.

  • Original post: https://www.cnblogs.com/yongjieShi/p/8681666.html