• BankNote


     1 # coding=utf-8
     2 import pandas as pd
     3 import numpy as np
     4 from sklearn import cross_validation
     5 import tensorflow as tf
     6 
     # Module-level switch read by GetInputs(): 0 selects the training split,
     # any other value selects the test split. Main() sets it to 1 before
     # evaluating the classifier.
     flag = 0
     9 
    def DataPreprocessing(path="ceshi.csv", test_size=0.25, random_state=0):
        """Load the banknote CSV and return a stratified train/test split.

        The first four columns are taken as features and all remaining
        columns as labels; the labels are also used to stratify the split.

        Args:
            path: CSV file with a header row (defaults to "ceshi.csv").
            test_size: Fraction of rows held out for testing.
            random_state: Seed for the shuffle; a fixed value makes repeated
                calls on the same file return the same split.

        Returns:
            (X_train, X_test, Y_train, Y_test) as NumPy arrays, in the order
            produced by ``train_test_split``.
        """
        try:
            # sklearn.cross_validation was deprecated in 0.18 and removed in
            # 0.20; prefer its replacement whenever it is available.
            from sklearn.model_selection import train_test_split
        except ImportError:
            train_test_split = cross_validation.train_test_split

        data = pd.read_csv(path, sep=",", header=0,
                           keep_default_na=True, na_values=[])
        features = np.array(data.iloc[:, :4])
        # Slicing with `4:` keeps the labels 2-D: shape (n_rows, n_label_cols).
        labels = np.array(data.iloc[:, 4:])

        return train_test_split(features, labels, test_size=test_size,
                                random_state=random_state, stratify=labels)
    36 
    37 
    def GetInputs(use_test=None):
        """Input function for the classifier: return (features, labels) tensors.

        Args:
            use_test: When None (the default), the module-level ``flag``
                decides: 0 selects the training split, anything else the test
                split. Pass True/False to choose explicitly, e.g. through
                ``functools.partial``, instead of mutating the global.

        Returns:
            A ``(features, labels)`` pair of ``tf.constant`` tensors.
        """
        if use_test is None:
            use_test = flag != 0

        # DataPreprocessing splits with a fixed random_state, so the training
        # and test halves stay consistent between the fit and evaluate calls.
        X_train, X_test, Y_train, Y_test = DataPreprocessing()

        # Only build the two constants that are actually returned.
        if use_test:
            return tf.constant(X_test), tf.constant(Y_test)
        return tf.constant(X_train), tf.constant(Y_train)
    68 
    69 
    def Main(model_dir="/home/jiangjing/TensorflowModel/banknote", steps=2000):
        """Train a DNN classifier on the training split and print test accuracy.

        Args:
            model_dir: Directory the estimator uses for its checkpoints;
                defaults to the original hard-coded location.
            steps: Number of training steps passed to ``fit``.
        """
        global flag

        # One real-valued column of dimension 4 under the empty name "";
        # presumably chosen so it matches the bare (non-dict) feature tensor
        # that GetInputs returns.
        feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

        clf = tf.contrib.learn.DNNClassifier(
            feature_columns=feature_columns,
            hidden_units=[10, 20, 10],
            n_classes=2,
            model_dir=model_dir,
        )

        # GetInputs serves the training split while flag == 0. Reset it
        # explicitly so a second call to Main() does not train on test data.
        flag = 0
        clf.fit(input_fn=GetInputs, steps=steps)

        # Switch GetInputs to the held-out split for evaluation.
        flag = 1
        accuracy_score = clf.evaluate(input_fn=GetInputs, steps=1)["accuracy"]

        print("\nTest Accuracy:{0:f}".format(accuracy_score))
    84 
    if __name__ == "__main__":
        # Run only when executed as a script: importing this module must
        # neither start training nor terminate the interpreter.
        Main()
  • 相关阅读:
    Java程序员必知的8大排序
    java提高篇-----理解java的三大特性之封装
    树莓派学习笔记——GPIO功能学习
    “SQL 服务没有及时响应启动或控制请求”的解决方法
    http://blog.csdn.net/u011001723/article/details/45621027
    error
    Spring @Conditional注解的使用
    Python [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed 解决方法
    python
    local class incompatible: stream classdesc serialVersionUID = -2897844985684768944, local class serialVersionUID = 7350468743759137184
  • 原文地址:https://www.cnblogs.com/acm-jing/p/9097373.html
Copyright © 2020-2023  润新知