• [Repost] Converting small training-data files into large files for reading/writing under PyTorch (with a comparison of storage formats)


     


    Copyright notice: this is an original article by CSDN blogger "Liekkas Kono", released under the CC 4.0 BY-SA license. Please include the original source link and this notice when reposting.
    Original link:

    https://blog.csdn.net/shiwanghualuo/article/details/120778553

    =======================================================

    Introduction

    TensorFlow has a dedicated data format, TFRecord, which reads neural-network training data efficiently enough to keep the GPU fully fed.

    Caffe reads its data from LMDB, which is likewise very efficient.

    PyTorch reads data through DataLoader, but it is relatively slow, especially when the dataset consists of many small files.

    So the key question is: how do we read data efficiently in PyTorch and make full use of the GPU?

    TFRecord

    • Can we simply borrow TensorFlow's TFRecord format? Certainly.
    • There is already a third-party implementation; see the tfrecord project (a usage sketch follows below).
    • On Kaggle there is also a hand-rolled loader; see PyTorch TFRecord-Loader.
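    Before the full writing/reading code, here is a minimal sketch of plugging the pip tfrecord package mentioned above into a plain torch DataLoader. It assumes that package's documented TFRecordDataset class, its transform hook and its tfrecord2idx index tool; the feature names and types match the records written below.

    import cv2
    import numpy as np
    import torch
    from tfrecord.torch.dataset import TFRecordDataset


    def decode(features):
        # the record stores the encoded JPEG/PNG bytes, so decode them here
        img = cv2.imdecode(np.asarray(features['image_raw'], dtype=np.uint8),
                           cv2.IMREAD_COLOR)
        img = cv2.resize(img, (270, 388))  # fixed size so default collation works
        img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
        label = int(np.ravel(features['label'])[0])
        return img, label

    # an index file (recommended for shuffling/sharding) can be built with:
    #   python -m tfrecord.tools.tfrecord2idx temp/val.tfrecords temp/val.index
    description = {'image_raw': 'byte', 'label': 'int'}
    dataset = TFRecordDataset('temp/val.tfrecords', index_path=None,
                              description=description, transform=decode)
    loader = torch.utils.data.DataLoader(dataset, batch_size=32)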
     
    TFRecord writing code:
        from pathlib import Path

        import cv2
        import numpy as np
        import tensorflow as tf
        from tqdm import tqdm


        def read_txt(txt_path):
            with open(txt_path, 'r', encoding='utf-8') as f:
                data = f.readlines()
            data = list(map(lambda x: x.rstrip('\n'), data))
            return data
         
         
        def bytes_to_numpy(image_bytes):
            image_np = np.frombuffer(image_bytes, dtype=np.uint8)
            image_np2 = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
            return image_np2  
         
        def list_record_features(tfrecords_path):
            """查看tfrecords结构
            https://stackoverflow.com/questions/63562691/reading-a-tfrecord-file-where-features-that-were-used-to-encode-is-not-known
         
            Args:
                tfrecords_path (str): tfrecords路径
         
            Returns:
                dict: 结构信息
            """
            features = {}
            dataset = tf.data.TFRecordDataset([str(tfrecords_path)])
            data = next(iter(dataset))
         
            example = tf.train.Example()
            example_bytes = data.numpy()
            example.ParseFromString(example_bytes)
         
            for key, value in example.features.feature.items():
                kind = value.WhichOneof('kind')
                size = len(getattr(value, kind).value)
                if key in features:
                    kind2, size2 = features[key]
                    if kind != kind2:
                        kind = None
         
                    if size != size2:
                        size = None
                features[key] = (kind, size)
            return features
         
        class TFRecorder(object):
            def __init__(self) -> None:
                super().__init__()
                self.feature_dict = {
                    'height': None,
                    'width': None,
                    'depth': None,
                    'label': None,
                    'image_raw': None
                }
                self.AUTO = tf.data.experimental.AUTOTUNE
         
            def image_to_feature(self, image_string, label):
                height, width, channel = tf.image.decode_image(image_string).shape
                self.feature_dict = {
                    'height': self._int64_feature(height),
                    'width': self._int64_feature(width),
                    'depth': self._int64_feature(channel),
                    'label': self._int64_feature(label),
                    'image_raw': self._bytes_feature(image_string)
                }
                return tf.train.Example(features=tf.train.Features(feature=self.feature_dict))
         
            def write(self, save_path, img_label_dict):
                with tf.io.TFRecordWriter(save_path) as writer:
                    for file_name, label in tqdm(img_label_dict.items()):
                        img_string = open(file_name, 'rb').read()
                        feature = self.image_to_feature(img_string, label)
                        writer.write(feature.SerializeToString())
         
            def read(self, tfrecord_path):
                reader = tf.data.TFRecordDataset(tfrecord_path)
                dataset = reader.map(self._parse_image_function,
                                     num_parallel_calls=self.AUTO)
                return dataset
         
            def _parse_image_function(self, example_proto):
                self.feature_dict = {
                    'height': tf.io.FixedLenFeature([], tf.int64),
                    'width': tf.io.FixedLenFeature([], tf.int64),
                    'depth': tf.io.FixedLenFeature([], tf.int64),
                    'label': tf.io.FixedLenFeature([], tf.int64),
                    'image_raw': tf.io.FixedLenFeature([], tf.string)
                }
                example = tf.io.parse_single_example(example_proto,
                                                     self.feature_dict)
                return example
         
            @staticmethod
            def _bytes_feature(value):
                """Returns a bytes_list from a string / byte."""
                if isinstance(value, type(tf.constant(0))):
                    # BytesList won't unpack a string from an EagerTensor.
                    value = value.numpy()
                return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
         
            @staticmethod
            def _float_feature(value):
                """Returns a float_list from a float / double."""
                return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
         
            @staticmethod
            def _int64_feature(value):
                """Returns an int64_list from a bool / enum / int / uint."""
                return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
         
         
        if __name__ == '__main__':
            tfrecorder = TFRecorder()

            # val.txt holds the relative paths of the images, one per line
            img_paths = read_txt('dataset/val.txt')

            # Path(v).parent.name is the image's label (its parent directory name)
            img_label_dict = {v: int(Path(v).parent.name) for v in img_paths}

            save_path = 'temp/val.tfrecords'
            tfrecorder.write(save_path, img_label_dict)

            dataset = tfrecorder.read(save_path)
            for v in dataset:
                # each element is a dict of the parsed features
                img = bytes_to_numpy(v['image_raw'].numpy())
                label = v['label'].numpy()
                print('ok')

            # inspect the structure of an unknown tfrecords file
            list_record_features('xxxx.tfrecords')
    TFRecord reading code for use with PyTorch
    import cv2
    import numpy as np
    import tensorflow as tf
    import tensorflow_datasets as tfds
    
    AUTO = tf.data.experimental.AUTOTUNE
    
    
    def bytes_to_numpy(image_bytes):
        image_np = np.frombuffer(image_bytes, dtype=np.uint8)
        image_np2 = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
        return image_np2
    
    
    def read_labeled_tfrecord(example_proto):
        feature_dict = {
            'height': tf.io.FixedLenFeature([], tf.int64),
            'width': tf.io.FixedLenFeature([], tf.int64),
            'depth': tf.io.FixedLenFeature([], tf.int64),
            'label': tf.io.FixedLenFeature([], tf.int64),
            'image_raw': tf.io.FixedLenFeature([], tf.string)
        }
        example = tf.io.parse_single_example(example_proto,
                                             feature_dict)
        img = tf.io.decode_image(example['image_raw'], channels=3,
                                 expand_animations=False)
        img = tf.image.resize_with_crop_or_pad(img,
                                               target_height=388,
                                               target_width=270)
        return img, example['label']
    
    
    def get_dataset(files, batch_size=16, repeat=False,
                    cache=False, shuffle=False):
        ds = tf.data.TFRecordDataset(files, num_parallel_reads=AUTO)
        if cache:
            ds = ds.cache()
    
        if repeat:
            ds = ds.repeat()
    
        if shuffle:
            ds = ds.shuffle(1024 * 2)
            opt = tf.data.Options()
            opt.experimental_deterministic = False
            ds = ds.with_options(opt)
    
        ds = ds.map(read_labeled_tfrecord, num_parallel_calls=AUTO)
        ds = ds.batch(batch_size)
        ds = ds.prefetch(AUTO)
        return tfds.as_numpy(ds)
    
    
    def count_data_items(file):
        num_ds = tf.data.TFRecordDataset(file, num_parallel_reads=AUTO)
        num_ds = num_ds.map(read_labeled_tfrecord, num_parallel_calls=AUTO)
        num_ds = num_ds.repeat(1)
        num_ds = num_ds.batch(1)
    
        c = 0
        for _ in num_ds:
            c += 1
        del num_ds
        return c
    
    
    class TFRecordDataLoader:
        def __init__(self, files, batch_size=32, cache=False, train=True,
                     repeat=False, shuffle=False, labeled=True,
                     return_image_ids=True):
            self.ds = get_dataset(
                files,
                batch_size=batch_size,
                cache=cache,
                repeat=repeat,
                shuffle=shuffle,)
    
            if train:
                self.num_examples = count_data_items(files)
    
            self.batch_size = batch_size
            self.labeled = labeled
            self.return_image_ids = return_image_ids
            self._iterator = None
    
        def __iter__(self):
            if self._iterator is None:
                self._iterator = iter(self.ds)
            else:
                self._reset()
            return self._iterator
    
        def _reset(self):
            self._iterator = iter(self.ds)
    
        def __next__(self):
            batch = next(self._iterator)
            return batch
    
        def __len__(self):
            n_batches = self.num_examples // self.batch_size
            if self.num_examples % self.batch_size == 0:
                return n_batches
            else:
                return n_batches + 1
    
    # Usage (batch_size is defined elsewhere)
    train_txt_path = 'dataset/minist/train.tfrecords'
    train_dataloader = TFRecordDataLoader(train_txt_path,
                                          batch_size=batch_size,
                                          shuffle=True)
    for v in train_dataloader:
        pass
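    Because get_dataset wraps everything in tfds.as_numpy, the loader yields NumPy batches; a minimal sketch of turning them into tensors inside a training loop, assuming the images come back as HWC uint8 arrays as produced by read_labeled_tfrecord:

    import torch

    for imgs, labels in train_dataloader:
        # imgs: (batch, 388, 270, 3) uint8, labels: (batch,) int64 NumPy arrays
        imgs = torch.from_numpy(imgs).permute(0, 3, 1, 2).float() / 255.0  # NCHW, [0, 1]
        labels = torch.from_numpy(labels).long()
        # forward/backward pass goes here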

    LMDB

    • Across the forums, whenever speeding up small-file reads in PyTorch comes up, LMDB (Lightning Memory-Mapped Database) is the usual answer. I gave it a try as well; the final verdict is given at the end.
    Writing to LMDB
    import os
    import pickle
    from pathlib import Path
     
    import cv2
    import lmdb
    import numpy as np
    from PIL import Image
    from torch.utils.data import DataLoader, Dataset
    from torchvision import transforms
    from tqdm import tqdm
     
    import utils
     
     
    class SimpleDataset(Dataset):
        def __init__(self, txt_path, transform=None) -> None:
            self.img_paths = utils.read_txt(txt_path)
            self.transform = transform
     
        def __getitem__(self, index: int):
            img_path = self.img_paths[index]
            label = int(Path(img_path).parent.name)
            try:
                img = Image.open(img_path)
                img = img.convert('RGB')
            except Exception:
                img = cv2.imread(img_path)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = Image.fromarray(img)
     
            if self.transform:
                img = self.transform(img)
            img = np.array(img)
            return img, label
     
        def __len__(self) -> int:
            return len(self.img_paths)
     
     
    class LMDB_Image:
        def __init__(self, image, label):
            # Dimensions of image for reconstruction - not really necessary
            # for this dataset, but some datasets may include images of
            # varying sizes
            self.channels = image.shape[2]
            self.size = image.shape[:2]
     
            self.image = image.tobytes()
            self.label = label
     
        def get_image(self):
            """ Returns the image as a numpy array. """
            image = np.frombuffer(self.image, dtype=np.uint8)
            return image.reshape(*self.size, self.channels)
     
     
    def data2lmdb(dpath, name="train", txt_path=None,
                  write_frequency=10, num_workers=4):
        dataset = SimpleDataset(txt_path=txt_path)
        data_loader = DataLoader(dataset, num_workers=num_workers,
                                 collate_fn=lambda x: x)
     
        lmdb_path = os.path.join(dpath, "%s.lmdb" % name)
        isdir = os.path.isdir(lmdb_path)
     
        print("Generate LMDB to %s" % lmdb_path)
        db = lmdb.open(lmdb_path, subdir=isdir,
                       map_size=1099511627776,  # map size in bytes (1 TB here)
                       readonly=False,
                       meminit=False,
                       map_async=True)
     
        txn = db.begin(write=True)
        for idx, data in enumerate(tqdm(data_loader)):
            image, label = data[0]
            temp = LMDB_Image(image, label)
            txn.put(u'{}'.format(idx).encode('ascii'), pickle.dumps(temp))
     
            if idx % write_frequency == 0:
                print("[%d/%d]" % (idx, len(data_loader)))
                txn.commit()
                txn = db.begin(write=True)
     
        # finish iterating through dataset
        txn.commit()
     
        keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
        with db.begin(write=True) as txn:
            txn.put(b'__keys__', pickle.dumps(keys))
            txn.put(b'__len__', pickle.dumps(len(keys)))
     
        print("Flushing database ...")
        db.sync()
        db.close()
     
    if __name__ == '__main__':
        save_dir = 'dataset/minist'
        data2lmdb(save_dir, name='val', txt_path='dataset/minist/val.txt')
    Reading from LMDB
    class DatasetLMDB(Dataset):
        def __init__(self, db_path, transform=None):
            self.db_path = db_path
            self.env = lmdb.open(db_path,
                                 subdir=os.path.isdir(db_path),
                                 readonly=True, lock=False,
                                 readahead=False, meminit=False)
            with self.env.begin() as txn:
                self.length = pickle.loads(txn.get(b'__len__'))
                self.keys = pickle.loads(txn.get(b'__keys__'))
            self.transform = transform
     
        def __getitem__(self, index):
            with self.env.begin() as txn:
                byteflow = txn.get(self.keys[index])
     
            IMAGE = pickle.loads(byteflow)
            img, label = IMAGE.get_image(), IMAGE.label
            return Image.fromarray(img).convert('RGB'), label
     
        def __len__(self):
            return self.length
     
    # Usage (train_db_path, normalize, batch_size and n_worker are defined elsewhere)
    train_transforms = transforms.Compose([
        transforms.Resize((388, 270)),
        transforms.RandomChoice([
            transforms.RandomRotation(10),
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomGrayscale(p=0.3),
            transforms.RandomPerspective(distortion_scale=0.6, p=0.5),
            transforms.ColorJitter(brightness=.5, hue=.3),
        ]),
        transforms.ToTensor(),
        normalize,
        transforms.RandomErasing(),
        ])
     
    train_dataset = DatasetLMDB(train_db_path, train_transforms)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=n_worker,
                                  pin_memory=True)
    # do other things
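    The conclusions below note that this layout blew the dataset up from 4.2 GB to 96 GB; that is because LMDB_Image pickles the decoded uint8 array. A variant that stores the original encoded file bytes keeps the database close to the source size and decodes in __getitem__ instead. A minimal sketch reusing the __len__ convention from above (write_encoded_lmdb is an illustrative name):

    import pickle
    from io import BytesIO

    import lmdb
    from PIL import Image


    def write_encoded_lmdb(lmdb_path, img_label_pairs, map_size=1 << 40):
        """Store the encoded JPEG/PNG bytes as they sit on disk, not decoded arrays."""
        db = lmdb.open(lmdb_path, map_size=map_size)
        with db.begin(write=True) as txn:
            for idx, (img_path, label) in enumerate(img_label_pairs):
                with open(img_path, 'rb') as f:
                    txn.put(str(idx).encode('ascii'),
                            pickle.dumps((f.read(), label)))
            txn.put(b'__len__', pickle.dumps(len(img_label_pairs)))
        db.sync()
        db.close()

    # in a Dataset's __getitem__, decode on the fly:
    #     img_bytes, label = pickle.loads(txn.get(str(index).encode('ascii')))
    #     img = Image.open(BytesIO(img_bytes)).convert('RGB')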

    Large binary file

    • Reading the existing dataset as raw bytes and packing everything into one large bin file is another viable option.
    Writing the bin file
    import cv2
    import numpy as np
    from tqdm import tqdm
     
     
    def write_bin(save_bin_path, save_index_path, data):
        """将现有基于文件的数据集写为bin大文件
        写入到save_index_path中的索引位置和标签,中间以\t分割
     
        Args:
            save_bin_path (str): 保存bin的位置
            save_index_path (str): 保存bin中索引和对应标签
            data (str): 存放图像路径和对应标签的list,
                        e.g. [['xxx/1.jpg', 'cat'], ['xxx/2.jpg', 'dog']]
        """
        with open(save_bin_path, 'wb') as f_w, \
                open(save_index_path, 'w') as f_index:
            start_index = 0
     
            for img_path, label in tqdm(data):
                with open(img_path, 'rb') as f:
                    img_bin = f.read()
     
                f_w.write(img_bin)
     
                len_bin = len(img_bin)
                f_index.write(f'{start_index}\t{len_bin}\t{label}\n')
     
                start_index += len_bin
     
     
    def read_bin(bin_path, index_path):
        """Read the large bin file together with its index/label txt.

        Args:
            bin_path (str): path to the bin file
            index_path (str): path to the index/label txt
        """
        with open(bin_path, 'rb') as f_bin, open(index_path, 'r') as f_index:
            index_lines = list(map(lambda x: x.strip(), f_index.readlines()))
            index_lines = list(map(lambda x: x.split('\t'), index_lines))

            for i, (start_index, length, label) in enumerate(index_lines):
                start_index = int(start_index)
                length = int(length)

                # seek to this sample's start offset
                f_bin.seek(start_index)

                # read length bytes
                img_bytes = f_bin.read(length)

                img = np.frombuffer(img_bytes, dtype='uint8')
                img = cv2.imdecode(img, -1)  # -1: cv2.IMREAD_UNCHANGED

                # convert to PIL if needed
                # img = Image.fromarray(img)
                # img = img.convert('RGB')

                # or save the image to check it
                # cv2.imwrite(f'temp/images/{i}.jpg', img)
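    For training, the same index can back a PyTorch Dataset that seeks into the bin file per sample. A minimal sketch, assuming integer labels in the index file (BinDataset is an illustrative name); a single file handle opened eagerly in __init__ would be shared by all DataLoader workers after fork, which is a likely cause of the multi-process corruption mentioned in the conclusions, so the sketch opens the file lazily per worker:

    import random
    from io import BytesIO

    from PIL import Image
    from torch.utils.data import Dataset


    class BinDataset(Dataset):
        def __init__(self, index_path, bin_path, transform=None):
            with open(index_path, 'r') as f:
                self.index_info = [line.strip().split('\t') for line in f]
            self.bin_path = bin_path
            self.f_bin = None  # opened lazily so each worker gets its own handle
            self.transform = transform

        def __getitem__(self, index):
            if self.f_bin is None:
                self.f_bin = open(self.bin_path, 'rb')

            start_index, length, label = map(int, self.index_info[index])

            # read exactly this sample's bytes from the big file
            self.f_bin.seek(start_index)
            img_bytes = self.f_bin.read(length)

            try:
                img = Image.open(BytesIO(img_bytes)).convert('RGB')
            except Exception:
                # fall back to a random sample if the bytes cannot be decoded
                return self.__getitem__(random.randint(0, len(self) - 1))

            if self.transform:
                img = self.transform(img)
            return img, label

        def __len__(self):
            return len(self.index_info)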

    SQLite

    • Using Python's built-in sqlite3 module as the storage format is also a good choice.
    Writing into an SQLite database
    import sqlite3
    from pathlib import Path
     
    from tqdm import tqdm
     
     
    def read_txt(txt_path):
        with open(txt_path, 'r', encoding='utf-8-sig') as f:
            data = list(map(lambda x: x.rstrip('\n'), f))
        return data
     
     
    def img_to_bytes(img_path):
        with open(img_path, 'rb') as f:
            img_bytes = f.read()
            return img_bytes
     
     
    class SQLiteWriter(object):
        def __init__(self, db_path):
            self.conn = sqlite3.connect(db_path)
            self.cursor = self.conn.cursor()
     
        def execute(self, sql, value=None):
            if value:
                self.cursor.execute(sql, value)
            else:
                self.cursor.execute(sql)
     
        def __enter__(self):
            return self
     
        def __exit__(self, exc_type, exc_val, exc_tb):
            self.cursor.close()
            self.conn.commit()
            self.conn.close()
     
     
    if __name__ == '__main__':
        dataset_dir = Path('datasets/minist')
     
        save_db_dir = dataset_dir / 'sqlite'
        save_db_path = str(save_db_dir / 'val.db')
     
        # each line of val.txt is: image path\tcorresponding text label, e.g. xxxx.jpg\txxxxxx
        img_paths = read_txt(str(dataset_dir / 'val.txt'))
     
        with SQLiteWriter(save_db_path) as db_writer:
            # create the table
            table_name = 'minist'

            # Define the table columns according to your own dataset.
            # For the mapping between Python and SQLite types see:
            # https://docs.python.org/zh-cn/3/library/sqlite3.html#sqlite-and-python-types
            # The demo dataset here is a text-recognition dataset: each sample is an image
            # and its label is the corresponding text, so the Python types below are enough,
            # e.g. img_path: str ('xxxx.jpg'), img_data: the image as bytes, img_label: str.
            create_table_sql = f'create table {table_name} (img_path TEXT primary key, img_data BLOB, img_label TEXT)'
            db_writer.execute(create_table_sql)

            # insert rows into the table, using ? placeholders for the values
            insert_sql = f'insert into {table_name} (img_path, img_data, img_label) values(?, ?, ?)'
            for img_info in tqdm(img_paths):
                img_path, label = img_info.split('\t')
     
                img_full_path = str(dataset_dir / 'images' / img_path)
                img_data = img_to_bytes(img_full_path)
     
                db_writer.execute(insert_sql, (img_path, img_data, label))
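    Row-by-row execute calls are acceptable here because the writer commits only once on exit, but for very large datasets sqlite3's executemany can batch all inserts in one call. A minimal sketch reusing create_table_sql and insert_sql from above (it holds every row in memory, so it only suits datasets that fit in RAM):

    rows = []
    for img_info in tqdm(img_paths):
        img_path, label = img_info.split('\t')
        rows.append((img_path, img_to_bytes(str(dataset_dir / 'images' / img_path)), label))

    with SQLiteWriter(save_db_path) as db_writer:
        db_writer.execute(create_table_sql)
        # executemany runs the same parameterized INSERT once per tuple in rows
        db_writer.cursor.executemany(insert_sql, rows)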
    Reading from the database
    class SimpleDataset(Dataset):
        def __init__(self, db_path, transform=None) -> None:
            self.db_path = db_path
            self.conn = None
            self.establish_conn()
     
            # name of the table inside the database
            self.table_name = 'Synthetic_chinese_dataset'
     
            self.cursor.execute(f'select max(rowid) from {self.table_name}')
            self.nums = self.cursor.fetchall()[0][0]
            self.transform = transform
     
        def __getitem__(self, index: int):
            self.establish_conn()
     
            # look the row up by rowid
            search_sql = f'select * from {self.table_name} where rowid=?'
            self.cursor.execute(search_sql, (index+1, ))
            img_path, img_bytes, label = self.cursor.fetchone()
     
            # restore the image and label
            img = Image.open(BytesIO(img_bytes))
            img = img.convert('RGB')
            img = scale_resize_pillow(img, (320, 32))
     
            if self.transform:
                img = self.transform(img)
            return img, label
     
        def __len__(self) -> int:
            return self.nums
     
        def establish_conn(self):
            if self.conn is None:
                self.conn = sqlite3.connect(self.db_path,
                                            check_same_thread=False,
                                            cached_statements=1024)
                self.cursor = self.conn.cursor()
            return self
     
        def close_conn(self):
            if self.conn is not None:
                self.cursor.close()
                self.conn.close()
     
                del self.conn
                self.conn = None
            return self  
     
    # --------------------------------------------------
    train_dataset = SimpleDataset(train_db_path, train_transforms)
    # ✧✧ Usage: close the database connection by hand here; each DataLoader worker
    # then re-opens its own connection lazily through establish_conn() in __getitem__
    train_dataset.close_conn()
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  num_workers=n_worker,
                                  pin_memory=True,
                                  sampler=train_sampler)
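    Calling close_conn() by hand covers the usual fork start method; when the DataLoader workers are spawned, the dataset object is pickled into each worker instead, and a live sqlite3 connection cannot be pickled. A hedged sketch of two extra methods for the class above that drop the connection on pickling so each worker reconnects on its own:

        def __getstate__(self):
            state = self.__dict__.copy()
            # never ship a live sqlite connection/cursor to another process
            state['conn'] = None
            state.pop('cursor', None)
            return state

        def __setstate__(self, state):
            self.__dict__.update(state)
            self.establish_conn()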

    Final conclusions

    TFRecord

    Storage size stays the same after conversion, and the GPU can be kept fully utilized.

    TFRecord pipelines cannot plug into other augmentation libraries (imgaug, OpenCV), so the available augmentations are very limited.

    LMDB

    Storage size grows dramatically after conversion (4.2 GB originally → 96 GB afterwards).

    With multi-process reading in PyTorch, images sometimes cannot be restored to the originals; no fix found yet.

    Read throughput is high; the GPU can be fully utilized.

    Large binary file

    Storage size stays the same after conversion.

    Likewise, multi-process reading in PyTorch can fail to restore images correctly; no fix found yet.

    SQLite (recommended)

    Storage size stays the same after conversion.

    Multi-process reading works correctly.

    References

    • pytorch-sqlite
    • sqlite_dataset

