TensorFlow high-level machine learning API (tf.contrib.learn)

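    The workflow has five steps:

    1. tf.contrib.learn.datasets.base.load_csv_with_header: load data in CSV format
    2. tf.contrib.learn.DNNClassifier: build a DNN (deep neural network) model, i.e. the classifier
    3. classifier.fit: train the model
    4. classifier.evaluate: evaluate the model
    5. classifier.predict: predict the classes of new samples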
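Step 1 depends on a specific file layout. As far as I recall the tf.contrib implementation, load_csv_with_header reads a header row whose first two fields are the number of samples and the number of features, then treats each following row as feature values plus an integer label (the last column by default). The pure-Python sketch below mimics that parsing to show what `training_set.data` and `training_set.target` end up holding. The function name `load_csv_with_header_sketch`, the `Dataset` tuple and the sample text are hypothetical; this is not TensorFlow's code, and the row-count checks are my own addition.

```python
import csv
import io
from collections import namedtuple

# Hypothetical container mirroring the (data, target) pair the real loader returns.
Dataset = namedtuple("Dataset", ["data", "target"])


def load_csv_with_header_sketch(csv_text, target_column=-1):
    """Parse CSV text in the assumed load_csv_with_header layout.

    Header row: number of samples, number of features, then class names (ignored).
    Data rows: feature values, with the integer label in `target_column`.
    """
    reader = csv.reader(io.StringIO(csv_text))
    header = next(reader)
    n_samples, n_features = int(header[0]), int(header[1])
    data, target = [], []
    for row in reader:
        if not row:
            continue  # tolerate blank lines
        target.append(int(row.pop(target_column)))  # label is removed from the row
        features = [float(v) for v in row]           # remaining values are features
        if len(features) != n_features:
            raise ValueError("row has %d features, header says %d"
                             % (len(features), n_features))
        data.append(features)
    if len(data) != n_samples:
        raise ValueError("found %d rows, header says %d" % (len(data), n_samples))
    return Dataset(data=data, target=target)


# Two illustrative rows in the Iris layout: 4 features, label in the last column.
sample_csv = """2,4,setosa,versicolor,virginica
6.4,2.8,5.6,2.2,2
5.0,2.3,3.3,1.0,1
"""
ds = load_csv_with_header_sketch(sample_csv)
print(ds.data)    # [[6.4, 2.8, 5.6, 2.2], [5.0, 2.3, 3.3, 1.0]]
print(ds.target)  # [2, 1]
```

The real loader returns NumPy arrays typed by `features_dtype` and `target_dtype`; plain lists are used here only to keep the sketch dependency-free.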

    Complete code:

    # Requires TensorFlow 1.x: tf.contrib (including tf.contrib.learn) was
    # removed in TensorFlow 2.0.
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function

    import tensorflow as tf
    import numpy as np

    # Data sets (both CSV files are expected in the working directory)
    IRIS_TRAINING = "iris_training.csv"
    IRIS_TEST = "iris_test.csv"

    # Load datasets. np.int was removed in NumPy 1.24, so labels use np.int32.
    training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
        filename=IRIS_TRAINING,
        target_dtype=np.int32,
        features_dtype=np.float32)
    test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
        filename=IRIS_TEST,
        target_dtype=np.int32,
        features_dtype=np.float32)

    # All four features are real-valued
    feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

    # Build a DNN with 3 hidden layers of 10, 20 and 10 units respectively.
    # Checkpoints are written to model_dir.
    classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                                hidden_units=[10, 20, 10],
                                                n_classes=3,
                                                model_dir="/tmp/iris_model")

    # Fit model.
    classifier.fit(x=training_set.data,
                   y=training_set.target,
                   steps=2000)

    # Evaluate accuracy on the test set.
    accuracy_score = classifier.evaluate(x=test_set.data,
                                         y=test_set.target)["accuracy"]
    print('Accuracy: {0:f}'.format(accuracy_score))

    # Classify two new flower samples (same float32 dtype as the training features).
    new_samples = np.array(
        [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=np.float32)
    y = list(classifier.predict(x=new_samples, as_iterable=True))
    print('Predictions: {}'.format(str(y)))
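With 4 input features, hidden_units=[10, 20, 10] and n_classes=3, the network built in step 2 is a stack of fully connected layers whose shapes and parameter count can be worked out by hand. The pure-Python sketch below is an illustration under assumptions: the helpers `layer_sizes`, `count_params`, `init_layers` and `forward` are hypothetical, the hidden layers use ReLU (which I believe is DNNClassifier's default activation), and the output layer applies softmax, whose argmax is the predicted class id, as I understand classifier.predict to return per sample. The weights are random and untrained and the initialisation is not TensorFlow's, so the predicted class here is meaningless; only the shapes matter.

```python
import math
import random

# Assumed sizes taken from the script: 4 features -> [10, 20, 10] -> 3 classes.
N_FEATURES = 4
HIDDEN_UNITS = [10, 20, 10]
N_CLASSES = 3


def layer_sizes(n_in, hidden, n_out):
    """Return (fan_in, fan_out) for each fully connected layer."""
    dims = [n_in] + hidden + [n_out]
    return list(zip(dims[:-1], dims[1:]))


def count_params(n_in, hidden, n_out):
    """Total weights plus biases over all fully connected layers."""
    return sum(fan_in * fan_out + fan_out
               for fan_in, fan_out in layer_sizes(n_in, hidden, n_out))


def init_layers(n_in, hidden, n_out, seed=0):
    """Random (weights, biases) per layer; untrained, for illustration only."""
    rng = random.Random(seed)
    layers = []
    for fan_in, fan_out in layer_sizes(n_in, hidden, n_out):
        w = [[rng.uniform(-0.5, 0.5) for _ in range(fan_out)] for _ in range(fan_in)]
        b = [0.0] * fan_out
        layers.append((w, b))
    return layers


def dense(x, w, b):
    """Compute x @ w + b for a single sample x."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) + b[j]
            for j in range(len(b))]


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def forward(x, layers):
    """ReLU on the hidden layers, softmax on the output layer."""
    h = x
    for w, b in layers[:-1]:
        h = [max(0.0, v) for v in dense(h, w, b)]
    w, b = layers[-1]
    return softmax(dense(h, w, b))


print(layer_sizes(N_FEATURES, HIDDEN_UNITS, N_CLASSES))
# [(4, 10), (10, 20), (20, 10), (10, 3)]
print(count_params(N_FEATURES, HIDDEN_UNITS, N_CLASSES))  # 513

layers = init_layers(N_FEATURES, HIDDEN_UNITS, N_CLASSES)
probs = forward([6.4, 3.2, 4.5, 1.5], layers)  # first new sample from the script
predicted_class = max(range(N_CLASSES), key=lambda k: probs[k])
```

The 513 parameters break down as 4×10+10, 10×20+20, 20×10+10 and 10×3+3.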

    Result of running the complete code:

    Accuracy: 0.966667
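The "accuracy" value returned by classifier.evaluate is the fraction of test samples whose predicted class equals the label. Assuming iris_test.csv holds 30 samples (the header of the tutorial's copy says 30, as far as I know), any accuracy is a multiple of 1/30, and 0.966667 corresponds to 29 of 30 correct. The sketch below reproduces only that arithmetic with a hypothetical `accuracy` helper and made-up labels; it is not TensorFlow's metric implementation.

```python
def accuracy(predictions, labels):
    """Fraction of samples whose predicted class equals the true label."""
    if not labels or len(predictions) != len(labels):
        raise ValueError("need two non-empty sequences of equal length")
    correct = sum(1 for p, t in zip(predictions, labels) if p == t)
    return correct / len(labels)


# Hypothetical labels for a 30-sample test set (class ids 0, 1, 2).
labels = [0] * 10 + [1] * 10 + [2] * 10
predictions = list(labels)
predictions[15] = 2  # exactly one sample misclassified

acc = accuracy(predictions, labels)
print('Accuracy: {0:f}'.format(acc))  # Accuracy: 0.966667

# With 30 samples, neighbouring achievable accuracies are 28/30 and 30/30.
for n_correct in (28, 29, 30):
    print(n_correct, '{0:f}'.format(n_correct / 30))
```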

    Original article: https://www.cnblogs.com/bonelee/p/7903436.html