24

TensorFlow读写数据

 5 years ago
source link: https://segmentfault.com/a/1190000018530098?amp%3Butm_medium=referral
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

前言

只有光头才能变强。

文本已收录至我的GitHub仓库,欢迎Star: https://github.com/ZhongFuCheng3y/3y

回顾前面:

众所周知,要训练出一个模型,首先我们得有数据。我们第一个例子中,直接使用dataset的api去加载mnist的数据。(minst的数据要么我们是提前下载好,放在对应的目录上,要么就根据他给的url直接从网上下载)。

一般来说,我们使用TensorFlow是从TFRecord文件中读取数据的。

TFRecord 文件格式是一种面向记录的简单 二进制格式 ,很多 TensorFlow 应用采用此格式来训练数据

所以,这篇文章来聊聊怎么 读取 TFRecord文件的数据。

一、入门对数据集的数据进行读和写

首先,我们来体验一下怎么造一个TFRecord文件,怎么从TFRecord文件中读取数据,遍历(消费)这些数据。

1.1 造一个TFRecord文件

现在,我们还没有TFRecord文件,我们可以自己简单写一个:

def write_sample_to_tfrecord():
    gmv_values = np.arange(10)
    click_values = np.arange(10)
    label_values = np.arange(10)

    with tf.python_io.TFRecordWriter("/Users/zhongfucheng/data/fashin/demo.tfrecord", options=None) as writer:
        for _ in range(10):
            feature_internal = {
                "gmv": tf.train.Feature(float_list=tf.train.FloatList(value=[gmv_values[_]])),
                "click": tf.train.Feature(int64_list=tf.train.Int64List(value=[click_values[_]])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label_values[_]]))
            }
            features_extern = tf.train.Features(feature=feature_internal)

            # 使用tf.train.Example将features编码数据封装成特定的PB协议格式
            # example = tf.train.Example(features=tf.train.Features(feature=features_extern))
            example = tf.train.Example(features=features_extern)

            # 将example数据系列化为字符串
            example_str = example.SerializeToString()

            # 将系列化为字符串的example数据写入协议缓冲区
            writer.write(example_str)


if __name__ == '__main__':
    write_sample_to_tfrecord()

我相信大家代码应该是能够看得懂的,其实就是分了几步:

  • 生成TFRecord Writer
  • tf.train.Feature生成协议信息
  • 使用tf.train.Example将features编码数据封装成特定的PB协议格式
  • 将example数据系列化为字符串
  • 将系列化为字符串的example数据写入协议缓冲区

参考资料:

ok,现在我们就有了一个TFRecord文件啦。

1.2 读取TFRecord文件

tf.data.TFRecordDataset

demo代码如下:

import tensorflow as tf


def read_tensorflow_tfrecord_files():
    # 定义消费缓冲区协议的parser,作为dataset.map()方法中传入的lambda:
    def _parse_function(single_sample):
        features = {
            "gmv": tf.FixedLenFeature([1], tf.float32),
            "click": tf.FixedLenFeature([1], tf.int64),  # ()或者[]没啥影响
            "label": tf.FixedLenFeature([1], tf.int64)
        }
        parsed_features = tf.parse_single_example(single_sample, features=features)

        # 对parsed 之后的值进行cast.
        gmv = tf.cast(parsed_features["gmv"], tf.float64)
        click = tf.cast(parsed_features["click"], tf.float64)
        label = tf.cast(parsed_features["label"], tf.float64)

        return gmv, click, label

    # 开始定义dataset以及解析tfrecord格式
    filenames = tf.placeholder(tf.string, shape=[None])

    # 定义dataset 和 一些列trasformation method
    dataset = tf.data.TFRecordDataset(filenames)
    parsed_dataset = dataset.map(_parse_function)  # 消费缓冲区需要定义在dataset 的map 函数中
    batchd_dataset = parsed_dataset.batch(3)

    # 创建Iterator
    sample_iter = batchd_dataset.make_initializable_iterator()
    # 获取next_sample
    gmv, click, label = sample_iter.get_next()
    training_filenames = [
        "/Users/zhongfucheng/data/fashin/demo.tfrecord"]
    with tf.Session() as session:
        # 初始化带参数的Iterator
        session.run(sample_iter.initializer, feed_dict={filenames: training_filenames})
        # 读取文件
        print(session.run(gmv))


if __name__ == '__main__':
    read_tensorflow_tfrecord_files()

无意外的话,我们可以输出这样的结果:

[[0.]
 [1.]
 [2.]]

ok,现在我们已经大概知道怎么写一个TFRecord文件,以及怎么读取TFRecord文件的数据,并且消费这些数据了。

二、epoch和batchSize术语解释

我在学习TensorFlow翻阅资料时,经常看到一些机器学习的术语,由于自己没啥机器学习的基础,所以很多时候看到一些专业名词就开始懵逼了。

2.1epoch

当一个完整的数据集通过了神经网络一次并且返回了一次,这个过程称为一个epoch。

这可能使我们跟 dataset.repeat() 方法联系起来,这个方法可以使当前数据集 重复 一遍。比如说,原有的数据集是 [1,2,3,4,5] ,如果我调用 dataset.repeat(2) 的话,那么我们的数据集就变成了 [1,2,3,4,5],[1,2,3,4,5]

  • 所以会有个说法:假设原先的数据是一个epoch,使用repeat(5)就可以将之变成5个epoch

2.2batchSize

一般来说我们的数据集都是比较大的, 无法一次性 将整个数据集的数据喂进神经网络中,所以我们会将数据集分成好几个部分。每次喂多少条样本进神经网络,这个叫做batchSize。

在TensorFlow也提供了方法给我们设置: dataset.batch() ,在API中是这样介绍batchSize的:

representing the number of consecutive elements of this dataset to combine in a single batch

我们一般在每次训练之前,会将 整个数据集的顺序打乱 ,提高我们模型训练的效果。这里我们用到的api是: dataset.shffle();

三、再来聊聊dataset

我从官网的介绍中截了一个dataset的方法图(部分):

2IRrmy6.jpg!web

dataset的功能主要有以下三种:

  • 创建dataset实例

    • 通过文件创建(比如TFRecord)
    • 通过内存创建
  • 对数据集的数据进行变换

    • 比如上面的batch(),常见的 map(),flat_map(),zip(),repeat() 等等
    • 文档中一般都有给出 例子 ,跑一下一般就知道对应的意思了。
  • 创建迭代器,遍历数据集的数据

3.1 聊聊迭代器

迭代器可以分为四种:

  • 单次。对数据集进行一次迭代,不支持参数化
  • 可初始化迭代

    • 使用前需要进行初始化, 支持传入参数 。面向的是同一个DataSet
  • 可重新初始化:同一个Iterator从不同的DataSet中读取数据

    • DataSet的对象具有相同的结构,可以使用 tf.data.Iterator.from_structure 来进行初始化
    • 问题: 每次 Iterator 切换时,数据都从头开始打印了
  • 可馈送(也是通过对象相同的结果来创建的迭代器)

    • 可让您在 两个数据集之间切换 的可馈送迭代器
    • 通过一个string handler来实现。
    • 可馈送的 Iterator 在不同的 Iterator 切换的时候, 可以做到不从头开始

简单总结:

  • 1、 单次 Iterator ,它最简单,但无法重用,无法处理数据集参数化的要求。
  • 2、 可以初始化的 Iterator ,它可以满足 Dataset 重复加载数据,满足了参数化要求。
  • 3、可重新初始化的 Iterator,它可以对接不同的 Dataset,也就是可以从不同的 Dataset 中读取数据。
  • 4、可馈送的 Iterator,它可以通过 feeding 的方式,让程序在运行时候选择正确的 Iterator,它和可重新初始化的 Iterator 不同的地方就是它的数据在不同的 Iterator 切换时, 可以做到不重头开始读取数据

string handler(可馈送的 Iterator)这种方式是最常使用的,我当时也写了一个Demo来使用了一下,代码如下:

def read_tensorflow_tfrecord_files():
    # 开始定义dataset以及解析tfrecord格式.
    train_filenames = tf.placeholder(tf.string, shape=[None])
    vali_filenames = tf.placeholder(tf.string, shape=[None])

    # 加载train_dataset   batch_inputs这个方法每个人都不一样的,这个方法我就不给了。
    train_dataset = batch_inputs([
        train_filenames], batch_size=5, type=False,
        num_epochs=2, num_preprocess_threads=3)
    # 加载validation_dataset  batch_inputs这个方法每个人都不一样的,这个方法我就不给了。
    validation_dataset = batch_inputs([vali_filenames
                                       ], batch_size=5, type=False,
                                      num_epochs=2, num_preprocess_threads=3)

    # 创建出string_handler()的迭代器(通过相同数据结构的dataset来构建)
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_dataset.output_types, train_dataset.output_shapes)

    # 有了迭代器就可以调用next方法了。
    itemid = iterator.get_next()

    # 指定哪种具体的迭代器,有单次迭代的,有初始化的。
    training_iterator = train_dataset.make_initializable_iterator()
    validation_iterator = validation_dataset.make_initializable_iterator()

    # 定义出placeholder的值
    training_filenames = [
        "/Users/zhongfucheng/tfrecord_test/data01aa"]
    validation_filenames = ["/Users/zhongfucheng/tfrecord_validation/part-r-00766"]

    with tf.Session() as sess:
        # 初始化迭代器
        training_handle = sess.run(training_iterator.string_handle())
        validation_handle = sess.run(validation_iterator.string_handle())

        for _ in range(2):
            sess.run(training_iterator.initializer, feed_dict={train_filenames: training_filenames})
            print("this is training iterator ----")

            for _ in range(5):
                print(sess.run(itemid, feed_dict={handle: training_handle}))

            sess.run(validation_iterator.initializer,
                     feed_dict={vali_filenames: validation_filenames})

            print("this is validation iterator ")
            for _ in range(5):
                print(sess.run(itemid, feed_dict={vali_filenames: validation_filenames, handle: validation_handle}))


if __name__ == '__main__':
    read_tensorflow_tfrecord_files()

参考资料:

3.2 dataset参考资料

在翻阅资料时,发现写得不错的一些博客:

最后

乐于输出 干货 的Java技术公众号:Java3y。公众号内有200多篇 原创 技术文章、海量视频资源、精美脑图,不妨来 关注 一下!

下一篇文章打算讲讲如何理解axis~

aERBfaR.jpg!web

觉得我的文章写得不错,不妨点一下


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK