• CNN 手写汉字识别:使用原生 TensorFlow 进行训练与预测


    100个汉字,放在data目录下。直接将下述文件和data存在同一个目录下运行即可。

    关键参数:

    run_mode = "train" 表示训练模型;修改为 "validation" 表示验证 data 目录下(100 个汉字类别)测试图片的预测精度;修改为 "inference" 表示对 './data/00098/102544.png' 这张图片进行手写识别,返回 top3 结果。

    charset_size = 100 表示汉字数目。如果是全量数据,则为3755.

    代码参考了:https://github.com/burness/tensorflow-101/blob/master/chinese_hand_write_rec/src/chinese_rec.py

    在此基础上加入了:(1) 图像随机左右旋转 30 度的数据增强;(2) 从 checkpoint 断点续训;(3) 为了达到更高精度,增加了一个卷积层,参见 https://github.com/AmemiyaYuko/HandwrittenChineseCharacterRecognition

    import tensorflow as tf
    import os
    import random
    import math
    import tensorflow.contrib.slim as slim
    import time
    import logging
    import numpy as np
    import pickle
    from PIL import Image
     
     
    # Console handler: emits INFO-and-above records to stderr (StreamHandler's
    # default stream). No formatter is attached, so only the message is printed.
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)

    # Script-wide logger used by train(), validation(), inference() and main().
    logger = logging.getLogger('Training a chinese write char recognition')
    logger.setLevel(logging.INFO)
    logger.addHandler(ch)
    
    # Default run settings; they only seed the defaults of the command-line flags
    # defined below (e.g. run_mode is the default for --mode).
    run_mode = "train"
    charset_size = 100  # number of character classes; 3755 for the full data set
    max_steps, save_steps = 12002, 2000

    # Paths used for the online, full 3755-character training run:
    #   checkpoint_dir = '/aiml/dfs/checkpoint/'
    #   train_data_dir = '/aiml/data/train/'
    #   test_data_dir = '/aiml/data/test/'
    #   log_dir = '/aiml/dfs/'

    # Local demo paths: training and evaluation both read images from ./data/.
    checkpoint_dir = './checkpoint2/'
    train_data_dir = test_data_dir = './data/'
    log_dir = './'
    
    
    # Command-line flags. Defaults come from the module-level settings above, so
    # the script runs without arguments; override with e.g. --mode=validation.
    tf.app.flags.DEFINE_string('mode', run_mode, 'Running mode. One of {"train", "validation", "inference"}')
    # NOTE: despite its name, this flag switches on a random *rotation* in
    # DataIterator.data_augmentation, not an up/down flip.
    tf.app.flags.DEFINE_boolean('random_flip_up_down', True, "Whether to random flip up down")
    tf.app.flags.DEFINE_boolean('random_brightness', True, "whether to adjust brightness")
    tf.app.flags.DEFINE_boolean('random_contrast', True, "whether to random constrast")

    tf.app.flags.DEFINE_integer('charset_size', charset_size, "Choose the first `charset_size` character to conduct our experiment.")
    tf.app.flags.DEFINE_integer('image_size', 64, "Needs to provide same value as in training.")
    # NOTE(review): 'gray' is not read by the code shown here.
    tf.app.flags.DEFINE_boolean('gray', True, "whether to change the rbg to gray")
    tf.app.flags.DEFINE_integer('max_steps', max_steps, 'the max training steps ')
    tf.app.flags.DEFINE_integer('eval_steps', 50, "the step num to eval")
    tf.app.flags.DEFINE_integer('save_steps', save_steps, "the steps to save")

    tf.app.flags.DEFINE_string('checkpoint_dir', checkpoint_dir, 'the checkpoint dir')
    tf.app.flags.DEFINE_string('train_data_dir', train_data_dir, 'the train dataset dir')
    tf.app.flags.DEFINE_string('test_data_dir', test_data_dir, 'the test dataset dir')
    tf.app.flags.DEFINE_string('log_dir', log_dir, 'the logging dir')

    # Resume training from the latest checkpoint in checkpoint_dir when True.
    tf.app.flags.DEFINE_boolean('restore', True, 'whether to restore from checkpoint')

    # These two are integer-valued. They were previously declared as boolean
    # flags, which cannot hold 10 / 128 and would reject a value such as
    # --batch_size=64. batch_size is used for both training and validation.
    # NOTE(review): 'epoch' is not read by the code shown here.
    tf.app.flags.DEFINE_integer('epoch', 10, 'Number of epoches')
    tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size used for training and validation')
    FLAGS = tf.app.flags.FLAGS
     
     
    class DataIterator:
        """Collects labelled character images from disk and builds a TF1 input queue.

        Images are read from ``data_dir/<label>/<file>``. ``<label>`` is a
        zero-padded decimal folder name (e.g. ``00098``), and its integer value is
        used as the class label. Only folders that sort below
        ``'%05d' % FLAGS.charset_size`` are used, i.e. the first ``charset_size``
        classes.
        """

        def __init__(self, data_dir):
            # Set FLAGS.charset_size to a small value if available computation power is limited.
            # os.path.join keeps this correct whether or not data_dir ends with a separator.
            truncate_path = os.path.join(data_dir, '%05d' % FLAGS.charset_size)
            print(truncate_path)
            self.image_names = []
            for root, _sub_folders, file_list in os.walk(data_dir):
                # Files directly inside data_dir have no class folder and cannot be
                # labelled (int() on their name would fail), so skip them.
                if os.path.relpath(root, data_dir) == os.curdir:
                    continue
                # Plain string comparison works because class folders are zero-padded.
                if root < truncate_path:
                    self.image_names += [os.path.join(root, file_path) for file_path in file_list]
            random.shuffle(self.image_names)
            self.labels = [self._label_of(file_name, data_dir) for file_name in self.image_names]

        @staticmethod
        def _label_of(file_name, data_dir):
            """Return the integer label encoded by the first folder of file_name below data_dir."""
            return int(os.path.relpath(file_name, data_dir).split(os.sep)[0])

        @property
        def size(self):
            """Number of images (and labels) found."""
            return len(self.labels)

        @staticmethod
        def data_augmentation(images):
            """Randomly rotate and adjust brightness/contrast of an image tensor, per the FLAGS switches."""
            if FLAGS.random_flip_up_down:
                # Despite the flag name this is a random rotation of up to 30 degrees
                # either way. The angle must be a graph op: a Python random value
                # would be drawn once at graph-build time and then reused for every
                # image of the whole training run.
                angle = tf.random_uniform([], minval=-30.0, maxval=30.0) * math.pi / 180
                images = tf.contrib.image.rotate(images, angle, interpolation='BILINEAR')
            if FLAGS.random_brightness:
                images = tf.image.random_brightness(images, max_delta=0.3)
            if FLAGS.random_contrast:
                images = tf.image.random_contrast(images, 0.8, 1.2)
            return images

        def input_pipeline(self, batch_size, num_epochs=None, aug=False):
            """Return shuffled (image_batch, label_batch) tensors fed by a queue.

            Each image is decoded as a 1-channel PNG, converted to float32 in
            [0, 1], optionally augmented, and resized to FLAGS.image_size.
            With num_epochs=None the queue cycles forever. With a finite
            num_epochs it eventually raises OutOfRangeError. shuffle_batch is not
            given allow_smaller_final_batch here, so a final partial batch is
            dropped.
            """
            images_tensor = tf.convert_to_tensor(self.image_names, dtype=tf.string)
            labels_tensor = tf.convert_to_tensor(self.labels, dtype=tf.int64)
            input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor], num_epochs=num_epochs)

            labels = input_queue[1]
            images_content = tf.read_file(input_queue[0])
            images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32)
            if aug:
                images = self.data_augmentation(images)
            new_size = tf.constant([FLAGS.image_size, FLAGS.image_size], dtype=tf.int32)
            images = tf.image.resize_images(images, new_size)
            image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size, capacity=50000,
                                                              min_after_dequeue=10000)
            return image_batch, label_batch
     
     
    def build_graph(top_k):
        """Build the CNN classifier graph and return its key tensors and ops.

        Args:
            top_k: k used by the top-k prediction and top-k accuracy ops.

        Returns:
            dict with the input placeholders ('images', 'labels', 'keep_prob'),
            'top_k', training ops ('global_step', 'train_op'), metrics ('loss',
            'accuracy', 'accuracy_top_k'), 'merged_summary_op', and prediction
            tensors ('predicted_distribution', 'predicted_index_top_k',
            'predicted_val_top_k').
        """
        keep_prob = tf.placeholder(dtype=tf.float32, shape=[], name='keep_prob')
        # Sized from FLAGS.image_size so the graph matches the input pipeline, which
        # resizes every image to that size (previously hard-coded to 64).
        images = tf.placeholder(dtype=tf.float32, shape=[None, FLAGS.image_size, FLAGS.image_size, 1],
                                name='image_batch')
        labels = tf.placeholder(dtype=tf.int64, shape=[None], name='label_batch')

        # Four conv + max-pool stages; conv4 additionally uses stride 2.
        # Scope names must stay unchanged so existing checkpoints still restore.
        conv_1 = slim.conv2d(images, 64, [3, 3], 1, padding='SAME', scope='conv1')
        max_pool_1 = slim.max_pool2d(conv_1, [2, 2], [2, 2], padding='SAME')
        conv_2 = slim.conv2d(max_pool_1, 128, [3, 3], padding='SAME', scope='conv2')
        max_pool_2 = slim.max_pool2d(conv_2, [2, 2], [2, 2], padding='SAME')
        conv_3 = slim.conv2d(max_pool_2, 256, [3, 3], padding='SAME', scope='conv3')
        max_pool_3 = slim.max_pool2d(conv_3, [2, 2], [2, 2], padding='SAME')
        conv_4 = slim.conv2d(max_pool_3, 512, [3, 3], [2, 2], scope="conv4", padding="SAME")
        max_pool_4 = slim.max_pool2d(conv_4, [2, 2], [2, 2], padding="SAME")

        flatten = slim.flatten(max_pool_4)

        # Dropout (keep probability fed via keep_prob) before each fully connected layer.
        fc1 = slim.fully_connected(slim.dropout(flatten, keep_prob), 1024, activation_fn=tf.nn.tanh, scope='fc1')
        logits = slim.fully_connected(slim.dropout(fc1, keep_prob), FLAGS.charset_size, activation_fn=None, scope='fc2')
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
        accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32))

        # global_step is a float32 variable (0.0 initializer, default dtype). Changing
        # its dtype would likely make existing checkpoints incompatible.
        global_step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False)
        # Learning rate 2e-4, multiplied by 0.97 every 2000 steps.
        rate = tf.train.exponential_decay(2e-4, global_step, decay_steps=2000, decay_rate=0.97, staircase=True)
        train_op = tf.train.AdamOptimizer(learning_rate=rate).minimize(loss, global_step=global_step)
        probabilities = tf.nn.softmax(logits)

        tf.summary.scalar('loss', loss)
        tf.summary.scalar('accuracy', accuracy)
        merged_summary_op = tf.summary.merge_all()
        predicted_val_top_k, predicted_index_top_k = tf.nn.top_k(probabilities, k=top_k)
        accuracy_in_top_k = tf.reduce_mean(tf.cast(tf.nn.in_top_k(probabilities, labels, top_k), tf.float32))

        return {'images': images,
                'labels': labels,
                'keep_prob': keep_prob,
                'top_k': top_k,
                'global_step': global_step,
                'train_op': train_op,
                'loss': loss,
                'accuracy': accuracy,
                'accuracy_top_k': accuracy_in_top_k,
                'merged_summary_op': merged_summary_op,
                'predicted_distribution': probabilities,
                'predicted_index_top_k': predicted_index_top_k,
                'predicted_val_top_k': predicted_val_top_k}
     
     
    def train():
        """Train the CNN on FLAGS.train_data_dir.

        Every FLAGS.eval_steps steps one test batch is evaluated and logged. Every
        FLAGS.save_steps steps a checkpoint is written. When FLAGS.restore is set,
        training resumes from the latest checkpoint in FLAGS.checkpoint_dir.
        Training stops once the global step exceeds FLAGS.max_steps, after writing
        a final checkpoint.
        """
        print('Begin training')
        train_feeder = DataIterator(FLAGS.train_data_dir)
        test_feeder = DataIterator(FLAGS.test_data_dir)
        # Saver.save() refuses to write into a directory that does not exist.
        if not os.path.isdir(FLAGS.checkpoint_dir):
            os.makedirs(FLAGS.checkpoint_dir)
        checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'my-model')
        with tf.Session() as sess:
            train_images, train_labels = train_feeder.input_pipeline(batch_size=FLAGS.batch_size, aug=True)
            test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size)
            graph = build_graph(top_k=1)
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            saver = tf.train.Saver()

            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
            test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/val')
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    # Restoring also restores global_step, so the step counter and the
                    # learning-rate schedule derived from it continue from here.
                    saver.restore(sess, ckpt)
                    print("restore from the checkpoint {0}".format(ckpt))

            logger.info(':::Training Start:::')
            try:
                while not coord.should_stop():
                    start_time = time.time()
                    # Pull a batch from the input queue, then feed it to the placeholders.
                    train_images_batch, train_labels_batch = sess.run([train_images, train_labels])
                    feed_dict = {graph['images']: train_images_batch,
                                 graph['labels']: train_labels_batch,
                                 graph['keep_prob']: 0.8}
                    _, loss_val, train_summary, step = sess.run(
                        [graph['train_op'], graph['loss'], graph['merged_summary_op'], graph['global_step']],
                        feed_dict=feed_dict)
                    train_writer.add_summary(train_summary, step)
                    end_time = time.time()
                    logger.info("the step {0} takes {1} loss {2}".format(step, end_time - start_time, loss_val))
                    if step > FLAGS.max_steps:
                        # Persist the final weights. The periodic save below only fires
                        # when step % save_steps == 1, so the last steps would be lost.
                        logger.info('Save the ckpt of {0}'.format(step))
                        saver.save(sess, checkpoint_prefix, global_step=graph['global_step'])
                        break
                    if step % FLAGS.eval_steps == 1:
                        # Quick sanity check on a single test batch, dropout disabled.
                        test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                        feed_dict = {graph['images']: test_images_batch,
                                     graph['labels']: test_labels_batch,
                                     graph['keep_prob']: 1.0}
                        accuracy_test, test_summary = sess.run(
                            [graph['accuracy'], graph['merged_summary_op']],
                            feed_dict=feed_dict)
                        test_writer.add_summary(test_summary, step)
                        logger.info('===============Eval a batch=======================')
                        logger.info('the step {0} test accuracy: {1}'
                                    .format(step, accuracy_test))
                        logger.info('===============Eval a batch=======================')
                    if step % FLAGS.save_steps == 1:
                        logger.info('Save the ckpt of {0}'.format(step))
                        saver.save(sess, checkpoint_prefix, global_step=graph['global_step'])
            except tf.errors.OutOfRangeError:
                logger.info('==================Train Finished================')
                saver.save(sess, checkpoint_prefix, global_step=graph['global_step'])
            finally:
                coord.request_stop()
                # Flush pending summaries and release the event files.
                train_writer.close()
                test_writer.close()
            coord.join(threads)
     
     
    def validation():
        """Evaluate the latest checkpoint on the images in FLAGS.test_data_dir (one pass).

        Logs per-batch and overall top-1 / top-3 accuracy.

        Returns:
            dict with 'prob' (top-3 probabilities per sample), 'indices' (top-3
            predicted class indices per sample) and 'groundtruth' (true labels).
            Each is a list in evaluation order.
        """
        print('validation')
        test_feeder = DataIterator(FLAGS.test_data_dir)

        final_predict_val = []
        final_predict_index = []
        groundtruth = []

        with tf.Session() as sess:
            # num_epochs=1: the input queue raises OutOfRangeError after one pass.
            test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size, num_epochs=1)
            graph = build_graph(top_k=3)

            sess.run(tf.global_variables_initializer())
            # The num_epochs counter of the input producer is a local variable.
            sess.run(tf.local_variables_initializer())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            saver = tf.train.Saver()
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print("restore from the checkpoint {0}".format(ckpt))
            else:
                logger.warning('No checkpoint found in {0}; evaluating untrained weights'
                               .format(FLAGS.checkpoint_dir))

            print(':::Start validation:::')
            # Running counts of correctly classified samples (top-1 / top-k).
            correct_top_1, correct_top_k = 0.0, 0.0
            try:
                i = 0
                while not coord.should_stop():
                    i += 1
                    start_time = time.time()
                    test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                    feed_dict = {graph['images']: test_images_batch,
                                 graph['labels']: test_labels_batch,
                                 graph['keep_prob']: 1.0}
                    batch_labels, probs, indices, acc_1, acc_k = sess.run([graph['labels'],
                                                                           graph['predicted_val_top_k'],
                                                                           graph['predicted_index_top_k'],
                                                                           graph['accuracy'],
                                                                           graph['accuracy_top_k']], feed_dict=feed_dict)
                    final_predict_val += probs.tolist()
                    final_predict_index += indices.tolist()
                    groundtruth += batch_labels.tolist()
                    # acc_1 / acc_k are batch means; turn them back into counts so
                    # batches of different sizes are weighted correctly.
                    correct_top_1 += acc_1 * len(batch_labels)
                    correct_top_k += acc_k * len(batch_labels)
                    end_time = time.time()
                    logger.info("the batch {0} takes {1} seconds, accuracy = {2}(top_1) {3}(top_k)"
                                .format(i, end_time - start_time, acc_1, acc_k))

            except tf.errors.OutOfRangeError:
                logger.info('==================Validation Finished================')
                # Normalise by the samples actually evaluated rather than
                # test_feeder.size: the input pipeline drops a final partial batch.
                num_evaluated = len(groundtruth)
                if num_evaluated:
                    acc_top_1 = correct_top_1 / num_evaluated
                    acc_top_k = correct_top_k / num_evaluated
                else:
                    logger.warning('No complete batch was evaluated')
                    acc_top_1 = acc_top_k = 0.0
                logger.info('top 1 accuracy {0} top k accuracy {1}'.format(acc_top_1, acc_top_k))
            finally:
                coord.request_stop()
            coord.join(threads)
        return {'prob': final_predict_val, 'indices': final_predict_index, 'groundtruth': groundtruth}
     
     
    def inference(image):
        """Return the top-3 predictions of the latest checkpoint for one image file.

        The image is converted to 8-bit grayscale, resized to
        FLAGS.image_size x FLAGS.image_size and scaled to [0, 1].

        Args:
            image: path of the image file to classify.

        Returns:
            (predict_val, predict_index): arrays of shape [1, 3] with the three
            highest softmax probabilities and their class indices.

        Raises:
            ValueError: if no checkpoint exists in FLAGS.checkpoint_dir.
        """
        print('inference')
        size = FLAGS.image_size
        with Image.open(image) as raw_image:
            # convert() loads the pixels, so the file can be closed afterwards.
            gray_image = raw_image.convert('L')
        # LANCZOS is the filter formerly aliased as ANTIALIAS (removed in Pillow 10).
        # NOTE(review): training resizes with tf.image.resize_images (bilinear), so
        # this preprocessing is not bit-identical to training.
        gray_image = gray_image.resize((size, size), Image.LANCZOS)
        temp_image = np.asarray(gray_image) / 255.0
        temp_image = temp_image.reshape([-1, size, size, 1])
        with tf.Session() as sess:
            logger.info('========start inference============')
            graph = build_graph(top_k=3)
            saver = tf.train.Saver()
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if not ckpt:
                # Without a checkpoint the variables are uninitialised and sess.run
                # would fail with a far less helpful error.
                raise ValueError('No checkpoint found in {0}'.format(FLAGS.checkpoint_dir))
            saver.restore(sess, ckpt)
            # Labels are not needed for prediction, so only images and keep_prob are fed.
            predict_val, predict_index = sess.run([graph['predicted_val_top_k'], graph['predicted_index_top_k']],
                                                  feed_dict={graph['images']: temp_image, graph['keep_prob']: 1.0})
        return predict_val, predict_index
     
     
    def main(_):
        """Dispatch on FLAGS.mode: "train", "validation" or "inference"."""
        print(FLAGS.mode)
        if FLAGS.mode == "train":
            train()
        elif FLAGS.mode == 'validation':
            dct = validation()
            result_file = 'result.dict'
            logger.info('Write result into {0}'.format(result_file))
            with open(result_file, 'wb') as f:
                pickle.dump(dct, f)
            logger.info('Write file ends')
        elif FLAGS.mode == 'inference':
            image_path = './data/00098/102544.png'
            # The true label is encoded by the image's parent folder name (00098 -> 98),
            # the same convention DataIterator uses; it was previously hard-coded as 190.
            expected_label = int(os.path.basename(os.path.dirname(image_path)))
            final_predict_val, final_predict_index = inference(image_path)
            logger.info('the result info label {0} predict index {1} predict_val {2}'.format(expected_label,
                                                                                             final_predict_index,
                                                                                             final_predict_val))
        else:
            logger.error('Unknown mode {0}; expected "train", "validation" or "inference"'.format(FLAGS.mode))
     
    if __name__ == "__main__":
        # tf.app.run parses the command-line flags and then invokes main(argv).
        tf.app.run(main=main)
    
  • 相关阅读:
    安卓手机的弱网工具
    渗透测试工具之sqlmap
    渗透测试基础之sql注入
    去哪儿网2017校招在线笔试(前端工程师)编程题及JavaScript代码
    滴滴出行2017秋招工程岗笔试题(0918)编程题
    【面试经历】再惠网络、远景能源、东软集团
    二叉树前序、中序、后序遍历相互求法
    58集团2017校招笔试-前端岗
    途牛前端工程师在线笔试题(含答案和全面解析)
    【经典面试题二】二叉树的递归与非递归遍历(前序、中序、后序)
  • 原文地址:https://www.cnblogs.com/bonelee/p/8952748.html
Copyright © 2020-2023  润新知