• Tensorflow解析tfrecord


    1、序列化

    #coding:utf-8
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import numpy as np
    import tensorflow as tf
    from python_speech_features import logfbank
    import scipy.io.wavfile as wav
    import struct
    
    # Root directory holding the input lists and all generated output.
    OUT_BASE_DIR="/data/asr/duantiantian/data"
    
    # wav.scp format: "<utterance id><TAB><wav path>" per line, e.g.
    # 00000210002     /netdisk1/asr_data/accented/00000210002.wav
    # 00000210003     /netdisk1/asr_data/accented/00000210003.wav
    wavscp_path=OUT_BASE_DIR+"/wav.scp.bk"
    # Transcriptions: "<utterance id> <units joined by '#'>" per line.
    text_path=OUT_BASE_DIR+"/syllables"
    # Symbol table: "<unit>#<integer id>" per line.
    symbol_path=OUT_BASE_DIR+"/symbol.txt"
    
    # Rejected entries, grouped by reason; dumped under ERROR at the end of the run.
    short_file,long_file,wrong_file,oov_file=[],[],[],[]
    # Divisor applied to the frame count before comparing it with the label count.
    STRIDE = 2
    # Directory receiving one "<id>.tfr" file per accepted utterance.
    OUT_DIR = OUT_BASE_DIR+"/tfrecord"
    # Directory receiving the rejection logs.
    ERROR = OUT_BASE_DIR+"/error_log"
    # Running per-dimension sum of features, sum of squared features and
    # total number of frames over all accepted utterances.
    accsum1 = 0.
    accsum2 = 0.
    accnfrm = 0
    
    
    def _bytes_feature(value):
        """Wrap one bytes value in a tf.train.Feature holding a one-element BytesList."""
        payload = tf.train.BytesList(value=[value])
        return tf.train.Feature(bytes_list=payload)
    
    def get_feature(wavfile):
        """Read a wav file and return its 40-filter log filterbank features.

        A file whose sample rate is not 8000 is reported and recorded in the
        module-level ``wrong_file`` list, but its features are still computed
        and returned.

        Args:
            wavfile: path of the wav file to read.

        Returns:
            The array produced by ``logfbank`` for the whole signal.
        """
        sample_rate, samples = wav.read(wavfile)
        if sample_rate != 8000:
            print("*** %s: sample rate is not 8000 ***" % wavfile)
            wrong_file.append(wavfile + '\n')
        # NOTE(review): `dither` and `wintype` are presumably provided by a
        # customised python_speech_features build -- confirm before upgrading.
        return logfbank(
            samples,
            samplerate=sample_rate,
            winlen=0.025,
            winstep=0.01,
            nfilt=40,
            nfft=512,
            lowfreq=64,
            highfreq=3800,
            preemph=0.97,
            dither=1,
            wintype='povey',
        )
    
    def write_feat_to_tfr(id,feats,labels):
        """Serialize one utterance into ``OUT_DIR/<id>.tfr``.

        The features are flattened to a 1-D float32 buffer and the labels are
        converted to an int32 buffer; both are stored as raw bytes under the
        'spectr' and 'label' keys of a single tf.train.Example.

        Args:
            id: utterance id (string), used for the output file name.
            feats: numpy array of features; it is flattened before writing.
            labels: sequence of integer label ids.

        Returns:
            None.
        """
        # Log the id that was passed in instead of relying on a module-level
        # loop variable being defined by the caller.
        print("dtt--",id,"|",labels)
        outfilename = OUT_DIR + '/' + id + '.tfr'
        flat_feats = np.reshape(feats,[-1]).astype(np.float32)
        lab = np.array(labels,dtype=np.int32)

        # tobytes() is the supported spelling of the deprecated tostring().
        example = tf.train.Example(features=tf.train.Features(feature={
            'spectr': _bytes_feature(flat_feats.tobytes()),
            'label': _bytes_feature(lab.tobytes()),
        }))
        # The context manager closes the writer even when the write fails.
        with tf.io.TFRecordWriter(outfilename) as writer:
            writer.write(example.SerializeToString())
        return
    
    # Map each modeling unit to its integer id; symbol.txt holds one
    # "<unit>#<id>" entry per line.
    ref2id = {}
    with open(symbol_path,'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line at end of file).
                continue
            ref,idx = line.split('#')
            ref2id[ref] = int(idx)
    
    
    # utterance id -> wav path, read from "<id><TAB><path>" lines.
    wavlist = {}
    # utterance id -> '#'-joined transcription units.
    textlist = {}
    with open(wavscp_path,'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # A blank line cannot be unpacked into (id, path); skip it.
                continue
            utt_id,path = line.split('\t')
            wavlist[utt_id] = path

    with open(text_path,'r') as f:
        for line in f:
            # The last character of every line is dropped -- presumably a
            # trailing '#' separator; verify against the data.
            line = line.strip()[:-1]
            if len(line.split(' ')) == 1:     # skip entries whose transcription is empty
                continue
            utt_id,sylls = line.split()
            textlist[utt_id] = sylls
    
    # Convert every utterance that has both a wav path and a transcription.
    # Rejected utterances are recorded in the short/long/wrong/oov lists;
    # accepted ones are written as TFRecords and folded into the running
    # feature statistics.
    common = {}
    filesize = []
    for key in wavlist:
        if key not in textlist:
            continue
        common[key] = (wavlist[key],textlist[key])
        try:
            curr_feature = get_feature(wavlist[key])
        except Exception:
            # Unreadable or corrupt audio: log it and move on. `Exception`
            # (not a bare except) keeps Ctrl-C / SystemExit working.
            wrong_file.append(wavlist[key]+'\n')
            continue

        try:
            labels = [ref2id[ele] for ele in textlist[key].split("#")]
        except KeyError as err:
            # A unit missing from the symbol table must not abort the whole
            # run; record the utterance as out-of-vocabulary instead.
            print("%s contains unit %s missing from symbol table" % (key, err))
            oov_file.append(key+'\n')
            continue

        frame_num = curr_feature.shape[0]
        if frame_num < 40:
            print("%s less than 0.4s" % wavlist[key])
            short_file.append(wavlist[key]+'\n')
            continue
        if frame_num > 1500:
            print("%s longer than 15s" % wavlist[key])
            long_file.append(wavlist[key]+'\n')
            continue
        if frame_num//STRIDE <= len(labels):
            # After dividing by STRIDE there must be more frames than labels.
            print("%s label size longer than frame number" % key)
            oov_file.append(key+'\n')
            continue

        write_feat_to_tfr(key,curr_feature,labels)
        # Accumulate per-dimension sum and sum of squares for the global
        # statistics computed after the loop.
        accsum1 += np.sum(curr_feature, 0)
        accsum2 += np.sum(np.square(curr_feature), 0)
        accnfrm += frame_num
        # frame_num * 40 * 4: 40 filterbank values per frame, 4 bytes per float32.
        filesize.append(key+'\t'+str(frame_num*40*4)+'\n')
    
    
    # Persist the per-utterance sizes collected during conversion.
    with open(OUT_BASE_DIR+"/tfr.size",'w') as size_out:
        size_out.writelines(filesize)
    print("tfr Done")

    # Dump every non-empty rejection list to its own file under ERROR.
    for log_name, rejected in (
        ("/short_file", short_file),
        ("/long_file", long_file),
        ("/wrong_file", wrong_file),
        ("/oov_file", oov_file),
    ):
        if rejected:
            with open(ERROR+log_name,'w') as log_out:
                log_out.writelines(rejected)
    
    
    # Global normalisation statistics over all accepted frames:
    #   fmean = negated per-dimension mean
    #   fvar  = per-dimension 1 / standard deviation
    if accnfrm == 0:
        # No utterance was accepted, so the accumulators are still scalars and
        # dividing by the frame count would raise ZeroDivisionError.
        print("no frames accumulated, mean/variance statistics skipped")
    else:
        accsum1 = -accsum1/accnfrm
        aux = np.ones(40)
        # E[x^2] - E[x]^2 is the variance; the sign of accsum1 vanishes when squared.
        accsum2 = np.divide(aux, np.sqrt(np.subtract((accsum2/accnfrm),np.square(accsum1))))

        fmean = np.array(accsum1, dtype=float)
        fvar = np.array(accsum2, dtype=float)
        print(fmean, fvar)
    print('Save done')
    #print(curr_feature.shape)
    #print(curr_feature[0,:])

    2、解析

    #coding:utf-8
    # CNN+FSMN
    # Author: Jie Ma
    
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    import os
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    
    import re
    import time
    import os
    import errno
    import sys
    import math
    import struct
    import ConfigParser
    
    # from six.moves import xrange  # pylint: disable=redefined-builtin
    import tensorflow as tf
    import numpy as np
    from tensorflow.python.framework import ops
    
    
    # TFR_LIST_FILE="/data/asr/duantiantian/data/tfrecord/eb39a10b280ecdfa04a581da2f002ebf_0114_0002.tfr"
    def read_and_decode(TFR_LIST_FILE):
        """Build queue-based ops that read one example from a TFRecord file.

        Args:
            TFR_LIST_FILE: path handed to the filename queue as its only entry.

        Returns:
            (spectr, label): the 'spectr' bytes decoded as a flat float32
            tensor and the 'label' bytes decoded as a flat int32 tensor.
        """
        print("---------------33----------------")
        # Queue producing the input file name.
        filename_queue = tf.train.string_input_producer([TFR_LIST_FILE])

        # read() returns (key, value); only the serialized record is needed.
        _, serialized_example = tf.TFRecordReader().read(filename_queue)
        feature_spec = {
            'spectr': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.string),
        }
        # Pull the raw feature and label byte strings out of the record.
        parsed = tf.parse_single_example(serialized_example, features=feature_spec)
        spectr = tf.decode_raw(parsed['spectr'], tf.float32)
        label = tf.reshape(tf.decode_raw(parsed['label'], tf.int32), [-1])
        return spectr,label
    
    def parse_exmp(serialized_example):
        """Parse one serialized tf.train.Example into (features, labels, length).

        Args:
            serialized_example: scalar string tensor holding one record.

        Returns:
            example: float32 tensor reshaped to [-1, FREQ_BIN_NUM, CHANNEL_NUM].
            label: flat int32 label tensor with every value incremented by 1.
            example_length: size of the first dimension of ``example``.
        """
        # Both keys are required, so no default values are given.
        feature_spec = {
            'spectr': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.string),
        }
        parsed = tf.parse_single_example(serialized_example, features=feature_spec)

        # Labels are shifted by +1 here; `sparse` later drops zeros and
        # subtracts 1 again, so 0 is free to mark padded positions.
        label = tf.reshape(tf.decode_raw(parsed['label'], tf.int32), [-1]) + 1

        frames = tf.decode_raw(parsed['spectr'], tf.float32)
        example = tf.reshape(frames, [-1, FREQ_BIN_NUM, CHANNEL_NUM])
        return example, label, tf.shape(example)[0]
    
    
    # Text file read line by line by tf.data.TextLineDataset; its lines are
    # then passed to tf.data.TFRecordDataset as file names.
    TFR_LIST_FILE="tmp.txt"
    # Batch size; also reused as num_parallel_reads and as the expected length
    # of the example_length vector in `sparse`.
    BATCH_SIZE=1
    # Trailing dimensions used to reshape the flat 'spectr' buffer in parse_exmp.
    FREQ_BIN_NUM=40
    CHANNEL_NUM=1
    
    
    def sparse(example, label, example_length):
        """Convert a dense label batch into a SparseTensor of targets.

        Non-zero entries of ``label`` are kept and decremented by 1; zero
        entries are left out of the sparse result.

        Args:
            example: feature batch, passed through unchanged.
            label: dense integer label batch.
            example_length: per-example lengths, reshaped to [BATCH_SIZE].

        Returns:
            (example, targets, example_length) with ``targets`` a tf.SparseTensor.
        """
        example_length = tf.reshape(example_length, [BATCH_SIZE])
        nonzero_mask = tf.not_equal(tf.cast(label, tf.float32), 0.)
        indices = tf.where(nonzero_mask)
        values = tf.gather_nd(label, indices) - 1
        dense_shape = tf.cast(tf.shape(label), tf.int64)
        targets = tf.SparseTensor(indices=indices, values=values, dense_shape=dense_shape)
        return example, targets, example_length
    
    
    # Input pipeline: file names come from the lines of TFR_LIST_FILE, records
    # are parsed, prefetched, padded into batches and finally converted so the
    # labels become sparse targets.
    tfr_paths = tf.data.TextLineDataset(TFR_LIST_FILE)
    dataset = (
        tf.data.TFRecordDataset(tfr_paths, num_parallel_reads=BATCH_SIZE)
        .map(parse_exmp, num_parallel_calls=16)
        .prefetch(buffer_size=10 * BATCH_SIZE)
        .padded_batch(BATCH_SIZE, padded_shapes=([None, FREQ_BIN_NUM, CHANNEL_NUM], [None], []))
        .map(sparse, num_parallel_calls=16)
    )
    iterator = dataset.make_one_shot_iterator()
    
    
    # with tf.Session() as sess: #开始一个会话
    #
    #     sess.run(tf.global_variables_initializer())
    #     sess.run(tf.local_variables_initializer())
        # spectr,label=read_and_decode(TFR_LIST_FILE)
        # print("-----------11-----------")
        # init_op = tf.global_variables_initializer()
        # sess.run(init_op)
        # spec, label = sess.run([spectr,label])#在会话中取出image和label
        # print("-----------22-----------")
        # print(label)
    
    
    def get_batch():
        """Return the (spect, label) tensors of the next batch from the module-level iterator.

        The third element produced by the iterator (the example lengths) is discarded.
        """
        spect, label, _ = iterator.get_next()
        return spect, label
    
    
    # Print the label tensor of every batch until the dataset is exhausted.
    spect, label = get_batch()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(iterator.make_initializer(dataset))
        while True:
            try:
                fetched = sess.run([label])
            except tf.errors.OutOfRangeError:
                # Raised once the iterator has no more batches.
                print("outOfRange")
                break
            print(fetched)
  • 相关阅读:
    Python从入门到精通之First!
    如果你不懂计算机语言,那么就请你不要说你是学计算机的!!好丢人。。。
    shell脚本-编程前奏-小工具之grep(文本处理)
    实战之授权站点漏洞挖掘-git信息泄漏
    实战之授权站点漏洞挖掘-CVE-2015-2808
    实战之授权站点漏洞挖掘-HTTP.sys远程代码执行
    实战之授权站点漏洞挖掘-CVE-1999-0554
    实战之授权站点漏洞挖掘-CORS
    实战之授权站点漏洞挖掘-越权
    实战之授权站点漏洞挖掘-url重定向
  • 原文地址:https://www.cnblogs.com/zijidan/p/16817552.html
Copyright © 2020-2023  润新知