TF-slim 模块是TensorFLow中比较实用的API之一,是一个用于模型构建、训练、评估复杂模型的轻量化库。
最近,在使用TF-slim API编写了一些项目模型后,发现TF-slim模块在搭建网络模型时具有相同的编写模式。这个编写模式主要包含四个部分:
- __init__():
- build_model():
- fit():
- predict():
1. __init__():
这部分相当于是一个main()函数,其中包含参数的设置,模型整体的连接等操作。具体来说:
a. 设置参数
由于是类的构造函数,所以需要在其中设置一些模型网络结构的参数、模型训练时的参数等等。例如
- 学习率
- batch_size
- 训练代数
- 各种文件的存放地址
- ...
- 对于网络结构复杂的模型,还可以将网络结构的table以列表的形式进行保存。便于后续建立模型时可以循环获取每层的超参数。
1 self.lr = lr 2 self.batch_size = batch_size 3 self.epoch = epoch 4 self.checkpoint_dir_load = checkpoint_dir 5 self.checkpoint_dir = os.path.join(checkpoint_dir, filename + ".ckpt") 6 self.logdir = logdir 7 self.result_dir = result_dir
b. 设置输入、输出的占位符placeholder
由于TF-slim框架仍然采用的是tensorflow的那一套,不像tf.keras可以使用keras.layer.Input(),所以还需要使用占位符。例如
1 self.input_image = tf.placeholder(tf.float32, shape=[None, 6000]) 2 self.input_image_raw = tf.reshape(self.input_image, shape=[-1, 6000, 1]) 3 4 self.input_image_label = tf.placeholder(tf.float32, shape=[None, 1, 10]) 5 self.input_label = tf.reshape(self.input_image_label, shape=[-1, 10])
c. 初始化网络结构,生成训练输出和测试输出
用于后续损失的计算以及优化器的生成,以及训练结果和测试结果的调用。
此处会涉及到网络参数的重用,需要使用tf.variable_scope()来管理参数。
1 with tf.variable_scope("Network_Structure") as scope: 2 self.train_digits = self.build_model(is_trained=True) 3 scope.reuse_variables() 4 self.test_digits = self.build_model(is_trained=False)
d. 损失函数和优化器的声明
此处损失声明使用的是 输出的占位符和训练的输出。例如:
1 self.loss = slim.losses.softmax_cross_entropy(logits=self.train_digits, onehot_labels=self.input_label, scope="loss") 2 3 self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(loss=self.loss)
e. 最终训练输出结果和测试输出结果的计算
由于网络输出的结果不一定是最终的结果。对于多分类问题,需要将one_hot编码的结果显示为类值;对于回归问题,输出结果可能会需要反归一化。等等..
如下述代码,多分类问题的one_hot转化为类标签,并进行准确率的计算。
1 # result and accuracy of test 2 self.predicts = tf.math.argmax(self.test_digits, 1) # 将one_hot转化为类标签 3 self.test_correction = tf.equal(self.predicts, tf.math.argmax(self.input_label, 1)) 4 self.accuracy = tf.reduce_mean(tf.cast(self.test_correction, "float")) 5 tf.summary.scalar("test_accuracy", self.accuracy) 6 7 # result and accuracy of train 8 self.train_result = tf.math.argmax(self.train_digits, 1) 9 self.train_correlation = tf.equal(self.train_result, tf.math.argmax(self.input_label, 1)) 10 self.train_accuracy = tf.reduce_mean(tf.cast(self.train_correlation, "float")) 11 tf.summary.scalar("train_accuracy", self.accuracy)
2. build_model():【可以是别的名字】
这部分是为了使用tf-slim搭建网络模型结构。有些模型可能一个函数实现不了,需要多个函数。例如具有共享层的Siamese Network,在共享层后还有其他层。
这一部分也实现了如同tf.keras搭建的模型"乐高式"堆叠,不需要手动为各层生成权重、偏执等参数。也是代码瘦身的重要环节。
1 with slim.arg_scope([slim.conv1d], padding="SAME", stride=2, activation_fn=tf.nn.relu, 2 weights_initializer=tf.truncated_normal_initializer(stddev=0.01), 3 weights_regularizer=slim.l2_regularizer(0.005) 4 ): 5 net = slim.conv1d(self.input_image_raw, num_outputs=16, kernel_size=8, padding="VALID", scope='conv_1') 6 tf.summary.histogram("conv_1", net) 7 net = slim.conv1d(net, num_outputs=16, kernel_size=8, scope='conv_2') 8 tf.summary.histogram("conv_2", net) 9 def_max_pool = tf.layers.MaxPooling1D(pool_size=2, strides=2, padding="VALID", name="max_pool_3") 10 net = def_max_pool(net) 11 # net = slim.nn.max_pool1d(net, ksize=2, strides=None, padding="VALID", data_format="NWC", name="max_pool_3") 12 tf.summary.histogram("max_pool_3", net) 13 net = slim.conv1d(net, num_outputs=64, kernel_size=4, scope="conv_4") 14 tf.summary.histogram("conv_4", net) 15 net = slim.conv1d(net, num_outputs=64, kernel_size=4, scope="conv_5") 16 tf.summary.histogram("conv_5", net) 17 def_max_pool = tf.layers.MaxPooling1D(pool_size=2, strides=2, padding="VALID", name="max_pool_6") 18 net = def_max_pool(net) 19 # net = slim.nn.max_pool1d(net, ksize=2, strides=1, padding="VALID", name="max_pool_6") 20 tf.summary.histogram("max_pool_6", net) 21 net = slim.conv1d(net, num_outputs=256, kernel_size=4, scope="conv_7") 22 tf.summary.histogram("conv_7", net) 23 net = slim.conv1d(net, num_outputs=256, kernel_size=4, scope="conv_8") 24 tf.summary.histogram("conv_8", net) 25 def_max_pool = tf.layers.MaxPooling1D(pool_size=2, strides=2, padding="VALID", name="max_pool_9") 26 net = def_max_pool(net) 27 # net = slim.nn.max_pool1d(net, ksize=1, strides=1, padding="VALID", name="max_pool_9") 28 tf.summary.histogram("max_pool_9", net) 29 net = slim.conv1d(net, num_outputs=512, kernel_size=2, stride=1, scope="conv_10") 30 tf.summary.histogram("conv_10", net) 31 net = slim.conv1d(net, num_outputs=512, kernel_size=2, stride=1, scope="conv_11") 32 tf.summary.histogram("conv_11", net) 33 def_max_pool = tf.layers.MaxPooling1D(pool_size=2, strides=2, padding="VALID", name="max_pool_12") 34 net = def_max_pool(net) 35 # net = slim.nn.max_pool1d(net, ksize=1, strides=1, padding="VALID", name="max_pool_12") 36 tf.summary.histogram("max_pool_12", net) 37 net = tf.reduce_mean(net, axis=1, name="global_max_pool_13") # 起全局平均池化的作用 38 tf.summary.histogram("global_max_pool_13", net) 39 net = slim.dropout(net, keep_prob=0.5, scope="dropout") 40 tf.summary.histogram("dropout", net) 41 digits = slim.fully_connected(net, num_outputs=num_class, activation_fn=tf.nn.softmax, scope="fully_connected_14") 42 tf.summary.histogram("fully_connected_14", digits) 43 return digits
3. fit():
看名字就知道这一部分需要完成的是训练部分的代码。
这一部分需要包含会话的启动、模型保存器的初始化、循环迭代、batch设置、数据集输入、输出数据获取、喂到网络中、保存模型、会话关闭等操作。如下述代码
1 sess = tf.Session() # 启动会话 2 3 merge_summary_op = tf.summary.merge_all() 4 summary_writer = tf.summary.FileWriter(self.logdir, sess.graph) 5 6 saver = tf.train.Saver(max_to_keep=1) # 生成保存器 7 sess.run(tf.global_variables_initializer()) # 变量激活 8 9 for step in range(self.epoch): # 迭代 10 print("Epoch:%d"%step) 11 avg_cost = 0 12 acc = 0 13 total_batch = int(input_x.shape[0]/self.batch_size) # 划分batch 14 for batch_num in range(total_batch): # batch迭代 15 # 获取数据 16 batch_xs = input_x[batch_num*self.batch_size:(batch_num+1)*self.batch_size, :] 17 batch_ys = input_y[batch_num*self.batch_size:(batch_num+1)*self.batch_size, :] 18 batch_ys = sess.run(tf.one_hot(batch_ys, depth=10)) 19 # 喂到损失 优化器等等 20 _, loss, acc = sess.run([self.optimizer, self.loss, self.train_accuracy], 21 feed_dict={self.input_image: batch_xs, 22 self.input_image_label: batch_ys}) 23 avg_cost += loss / total_batch 24 acc += acc /total_batch 25 26 summary_str = sess.run(merge_summary_op, feed_dict={self.input_image: batch_xs, 27 self.input_image_label: batch_ys}) 28 summary_writer.add_summary(summary_str, global_step=step) 29 print("Epoch:%d, batch: %d, avg_cost: %g, accuracy: %g" % (step, batch_num, avg_cost, acc)) 30 # 保存模型 31 saver.save(sess, self.checkpoint_dir, global_step=step) 32 sess.close() # 会话关闭
4. predict():
从函数名可以知道这一部分是实现预测部分的代码。其相对于训练的过程要更简单。主要包括会话的启动、保存器的生成、权重的导入(模型的恢复)、预测、关闭会话。如下述代码
1 sess = tf.Session() # 会话的启动 2 3 saver = tf.train.Saver() # 保存器的生成 4 5 module_file = tf.train.latest_checkpoint(self.checkpoint_dir_load) 6 saver.restore(sess, module_file) # 模型的恢复 7 8 input_y = sess.run(tf.one_hot(input_y, depth=10)) # 获取输出 9 # 获取预测结果和预测精度 10 predicts, acc_test = sess.run([self.predicts, self.accuracy], feed_dict={self.input_image: input_x, 11 # 关闭会话 self.input_image_label: input_y}) 12 sess.close() 13 # print("test_accuracy: %f" %acc_test) 14 return predicts, acc_test
上述四步完成后,便可以编写一个main函数来调用这个类,实现需要的功能。.fit()和.predict()主要是在main()函数来调用。