• 读取tfrecord,并写入h5文件


    import tfrecord as tfr
    import h5py
    import os,sys
    import numpy as np
    import glob
    import pandas as pd
    from tqdm import tqdm
    class TfrecordWorker():
        """Copy a list of TFRecord files into a single HDF5 file, one row per file.

        The feature schema (names, element types, shapes) is taken from the first
        record of the first file; every other file is expected to follow the same
        schema. Each feature becomes one HDF5 dataset whose first axis indexes the
        input files, and an extra ``name`` dataset stores each file's base name.

        Typical use::

            worker = TfrecordWorker(files)
            worker.create_h5f("out.h5")
            worker.write_h5f()
            worker.close_h5f()
        """

        # Handle of the HDF5 file being written; stays None until create_h5f() runs
        # so that close_h5f() is always safe to call.
        h5f = None

        def __init__(self, tfr_list, description_csv="label_type.csv"):
            """Read the feature schema from the first tfrecord file.

            Args:
                tfr_list: Non-empty sequence of tfrecord file paths.
                description_csv: CSV file with ``label`` and ``type`` columns that
                    describes the features stored in the tfrecord files.

            Raises:
                ValueError: If ``tfr_list`` is empty.
            """
            # Fail early and clearly, before touching any file on disk.
            if len(tfr_list) == 0:
                raise ValueError("tfr_list is empty: at least one tfrecord file is required")

            self.info = {"label": [], "typee": [], "shape": []}
            self.data_dir = "raw_data"
            self.tfr_list = tfr_list
            self.tfr_description = self._parse_description(description_csv)

            loader = tfr.tfrecord_loader(self.tfr_list[0], None, self.tfr_description)
            for record in loader:
                for key, value in record.items():
                    self.info['label'].append(key)
                    # dtype.type is the scalar type of the array and, unlike
                    # type(value[0]), also works when the feature array is empty.
                    self.info['typee'].append(value.dtype.type)
                    self.info['shape'].append(value.shape)
                # The schema comes from the first record only; reading further
                # records would register every label more than once.
                break

            self.attr_size = len(self.info['label'])
            self.data_size = len(self.tfr_list)
            print(f"总共有{self.attr_size}个属性")
            print(f"总共有{self.data_size}个tfrecord文件")

        def create_h5f(self, h5path="./data.h5"):
            """Create the output HDF5 file and one dataset per feature.

            Every feature dataset has shape ``(number of files, *feature shape)``.
            The additional ``name`` dataset holds one variable-length string per
            file.

            Args:
                h5path: Path of the HDF5 file to create (an existing file is
                    overwritten because it is opened in ``'w'`` mode).
            """
            # Do not leak a handle left over from an earlier create_h5f() call.
            self.close_h5f()
            self.h5f = h5py.File(h5path, 'w')
            self.dset = {}
            schema = zip(self.info["label"], self.info["typee"], self.info["shape"])
            for label, typee, shape in schema:
                self.dset[label] = self.h5f.create_dataset(
                    label,
                    shape=(self.data_size, *shape),
                    compression=None,
                    dtype=typee,
                )

            self.dset["name"] = self.h5f.create_dataset(
                "name",
                shape=[self.data_size],
                compression=None,
                dtype=h5py.special_dtype(vlen=str),
            )

        def write_h5f(self):
            """Write every tfrecord file into its row, showing a progress bar."""
            # total= is required: enumerate() has no length, so without it tqdm
            # cannot display a percentage or an ETA.
            progress = tqdm(enumerate(self.tfr_list), total=self.data_size)
            for idx, tfr_path in progress:
                self._write_one_item(tfr_path, idx)

        def close_h5f(self):
            """Close the output HDF5 file; a no-op if none is currently open."""
            if self.h5f is not None:
                self.h5f.close()
                self.h5f = None

        def _write_one_item(self, tfr_path, idx):
            """Store the features of one tfrecord file in row ``idx``.

            NOTE: if the file holds several records, each one is written to the
            same row, so the last record is the one that remains.

            Args:
                tfr_path: Path of the tfrecord file to read.
                idx: Row index (position of the file in ``tfr_list``).
            """
            loader = tfr.tfrecord_loader(tfr_path, None, self.tfr_description)
            for record in loader:
                for key, content in record.items():
                    self.dset[key][idx] = content
            # basename() strips the directory with the platform's separator rules
            # instead of assuming "/".
            self.dset["name"][idx] = os.path.basename(tfr_path)

        def _parse_description(self, csv_path):
            """Build the ``{feature name: feature type}`` mapping from a CSV file.

            Args:
                csv_path: CSV file containing at least ``label`` and ``type``
                    columns; other columns are ignored.

            Returns:
                dict mapping each stripped label to its stripped type string.
            """
            label_type = pd.read_csv(csv_path, usecols=["label", "type"])
            return {
                str(label).strip(): str(kind).strip()
                for label, kind in zip(label_type["label"], label_type["type"])
            }
    
    
    
    def start(files, savename):
        """Convert the tfrecord files in ``files`` into the HDF5 file ``savename``.

        Args:
            files: Sequence of tfrecord file paths to convert.
            savename: Path of the HDF5 file to create.
        """
        worker = TfrecordWorker(files)
        worker.create_h5f(savename)
        try:
            worker.write_h5f()
        finally:
            # Close the file even when writing fails, so that the handle is
            # released and whatever was written so far is flushed.
            worker.close_h5f()
    
    # Convert each cross-validation fold (0-3) into its own HDF5 file.
    for fold in range(4):
        # glob() returns paths in arbitrary order; sorting makes the row order of
        # the resulting HDF5 file reproducible between runs.
        start(sorted(glob.glob(f"raw_data/*fold{fold}*.tfrecord")), f"fold{fold}.h5")

    # Sanity check: list the datasets of the first fold and print the stored file
    # names. The with-block guarantees the file is closed afterwards.
    with h5py.File('fold0.h5', 'r') as f:
        print('--iterms: ', len(f.keys()), f.keys())
        name = f['name']
        print(name[:])
    
  • 相关阅读:
    预习笔记 多态 --S2 4.3
    织梦CMS标签生成器
    socketCluster 使用
    JS工具库之Lodash
    socketcluster 客户端请求
    AngularJS自定义指令directive:scope属性 (转载)
    angularjs报错问题记录
    Angularjs中的事件广播 —全面解析$broadcast,$emit,$on
    angularJS中directive与controller之间的通信
    AngularJs Type error : Cannot read property 'childNodes' of undefined
  • 原文地址:https://www.cnblogs.com/geoli/p/15983442.html
Copyright © 2020-2023  润新知