• transformer中 数据预处理代码理解


    今天师兄将transformer中的数据预处理部分讲了一下。

    数据准备: train.en train.cn  一个英文的语料,一个中文的语料 语料中是一些一行行的语句

    目标:将语料中的词抽取出来,放在一个词表里。词表里是序号+词 其次,将train中的语句形成数字序列  比如:today在词表中的id是1 is 在词表中的id是4 good的id是5

                                                      today is good --->145  

    details:

    #step1:创建10个文件(train) /dev (1个文件)
    1.
    tf.logging.info('create 10 train file name')
    suffix = '-train'
    train_paths = self.filepaths(data_dir, self.num_shards, suffix)
    suffix = '-dev'
    dev_paths = self.filepaths(data_dir, self.num_dev_shards, suffix)
    10个文件的作用在于:避免产生的语句序列过多放于一个文件夹中运行时时间慢长,比如产生12万句子,平分到10个文件夹中
    #step2: 对训练文件中单词统计词频信息,抽取前size大小的词作为词表(3特殊符号)
    with tf.gfile.GFile(filepaths, mode='r') as source_file:
    for line in source_file:
    yield line #读取训练文件,返回每一行


    tf.logging.info('step1: count word number dict')
    token_counts = defaultdict(int) #将词表存于词典中,以K,V的形式存放
    for item in generate(): #item为返回的每一行
    words = native_to_unicode(item).strip().split()#此作用就是将每一行的词以空格来split,获取每个词
    for tok in words: #记录词出现的次数-词频
    token_counts[tok] += 1
     
    #抽取target_size的大小,形成词表
    self._alphabet = chain(six.iterkeys(token_counts), [native_to_unicode(t) for t in RESERVED_TOKENS])
    new_subtoken_strings = []
    #这一步是计算RESERVED_TOKENS = [PAD, EOS, UNK] 三个特殊字符在token_counts出现的次数,一般都是0
    #PAD--补位 EOS--结束符 UNK---不在所选取的30000个单词里面
    new_subtoken_strings.extend((token_counts.get(a, 0), a) for a in self._alphabet)
    new_subtoken_strings.sort(reverse=True)
    new_subtoken_strings = [subtoken for _, subtoken in new_subtoken_strings]
    new_subtoken_strings = new_subtoken_strings[:target_size]#抽取size的词
    new_subtoken_strings = RESERVED_TOKENS + new_subtoken_strings
    self._subtoken_string_to_id = {s: i for i, s in enumerate(new_subtoken_strings) if s}#添加id
     
    #step3: 根据形成的词表,对训练文件中每句话转为ID序列
    #转成ID序列
    tokens = native_to_unicode(raw_text).strip().split()
    ret=[]
    for tok in tokens:
    if tok in self._subtoken_string_to_id:
    ret.extend([self._subtoken_string_to_id[tok]])
    else:
    ret.extend([UNK_ID])
    #调用上面的encode,转成序列函数,来表示语句序列
    source_ints = source_vocab.encode(source.strip()) + eos_list
    target_ints = target_vocab.encode(target.strip()) + eos_list
    yield {"inputs": source_ints, "targets": target_ints}
    source, target = source_file.readline(), target_file.readline()
    #step4:将ID序列保存到10个不同文件中,并shuffle文件
    for case in generator:generator是第三步的函数
    if counter > 0 and counter % 100000 == 0: counter是第几句话
    tf.logging.info("Generating case %d." % counter)
    counter += 1
    features = {}
    for (k, v) in six.iteritems(case):
    if isinstance(v[0], six.integer_types):
    features[k] = tf.train.Feature(int64_list = tf.train.Int64List(value=v))
    sequence_example = tf.train.Example(features=tf.train.Features(feature=features))
    writers[shard].write(sequence_example.SerializeToString())
    shard = (shard + 1) % num_shards
    -------今天才发现博客可以插入代码------------好吧----------好丑的排版--------------------------
  • 相关阅读:
    Lua学习笔记(二):基本语法
    Lua学习笔记(一):搭建开发环境
    C#学习笔记(十六):Attribute
    [U3D Demo] 手机FPS射击游戏
    C#学习笔记(十五):预处理指令
    js 树菜单 ztree
    jquery flexslider 轮播插件
    浏览器 本地预览图片 window.url.createobjecturl
    mouseover mouseenter mouseout mouseleave
    jquery checkbox问题
  • 原文地址:https://www.cnblogs.com/Shaylin/p/9864016.html
Copyright © 2020-2023  润新知