TFRecords存储和读取:
什么是TFRecords:
Example结构解析:
写:
def write_to_tfrecords(self, image_batch, label_batch, path="cifar10.tfrecords"):
    """Write image/label pairs to a TFRecords file as serialized tf.train.Example records.

    Each record holds two features:
      * "image" -- the raw bytes of one image array (bytes_list)
      * "label" -- element [0] of the matching label row (int64_list)

    Every pair produced by zipping the two batches is written (the previous
    version wrote exactly 100 records regardless of the batch size, raising
    IndexError on smaller batches and dropping the rest of larger ones).

    :param image_batch: sequence of NumPy arrays, one per image
        (presumably uint8 of shape (height, width, channel) -- confirm against caller).
    :param label_batch: sequence of label rows; element [0] of each row is stored.
    :param path: output file path; defaults to "cifar10.tfrecords".
    :return: None
    """
    with tf.compat.v1.python_io.TFRecordWriter(path) as writer:
        # Build one Example per (image, label) pair and write it serialized.
        for image, label_row in zip(image_batch, label_batch):
            example = tf.train.Example(features=tf.train.Features(feature={
                # tobytes() replaces the deprecated/removed ndarray.tostring().
                "image": tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[image.tobytes()])),
                # int() turns a NumPy integer scalar into a plain Python int,
                # which Int64List accepts on every protobuf version.
                "label": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[int(label_row[0])])),
            }))
            writer.write(example.SerializeToString())
    return None
读:
def read_tfrecords(self, path="cifar10.tfrecords"):
    """Read one batch of examples back from a TFRecords file and print it.

    Builds a TF1 queue-based input pipeline in four stages:
      1. a filename queue;
      2. a TFRecordReader plus parse_single_example, with an "image" string
         feature and a "label" int64 feature;
      3. raw uint8 decoding and a reshape to [self.height, self.width, self.channel];
      4. tf.compat.v1.train.batch with batch_size=100.
    One batch is then evaluated in a session and printed.

    NOTE(review): these queue-runner APIs only work in graph mode; under TF2,
    eager execution must be disabled before this is called -- confirm the caller does so.

    :param path: TFRecords file to read; defaults to "cifar10.tfrecords".
    :return: None
    """
    # 1. Filename queue.
    file_queue = tf.compat.v1.train.string_input_producer([path])

    # 2. Read and parse one serialized Example.
    reader = tf.compat.v1.TFRecordReader()
    _, value = reader.read(file_queue)  # the record key is not needed
    feature = tf.compat.v1.parse_single_example(value, features={
        "image": tf.compat.v1.FixedLenFeature([], tf.string),
        "label": tf.compat.v1.FixedLenFeature([], tf.int64),
    })
    image = feature["image"]
    label = feature["label"]
    print("read_tf_image: ", image)
    print("read_tf_label: ", label)

    # The image was stored as raw bytes, so decode it back to uint8 ...
    image_decoded = tf.compat.v1.decode_raw(image, tf.uint8)
    print("image_decoded: ", image_decoded)
    # ... and give it the fully defined shape that tf.compat.v1.train.batch requires.
    image_reshaped = tf.reshape(image_decoded, [self.height, self.width, self.channel])
    print("image_reshaped: ", image_reshaped)

    # 3. Batching queue.
    image_batch, label_batch = tf.compat.v1.train.batch(
        [image_reshaped, label], batch_size=100, num_threads=2, capacity=100)
    print("image_batch: ", image_batch)
    print("label_batch: ", label_batch)

    # Run the pipeline in a session, with queue-runner threads feeding it.
    with tf.compat.v1.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.compat.v1.train.start_queue_runners(sess=sess, coord=coord)
        try:
            image_value, label_value = sess.run([image_batch, label_batch])
            print("image_value: ", image_value)
            print("label_value: ", label_value)
        finally:
            # Always stop and join the queue-runner threads, even if sess.run
            # raises, so they do not outlive the session.
            coord.request_stop()
            coord.join(threads)
    return None
神经网络:
感知机:
主要用途:
softmax回归:
交叉熵损失: