Fitting a straight line through the points in the figure above is known as linear regression, the simplest introductory problem in machine learning, so the theory is not belabored here. The loss function for linear regression is the ordinary mean squared error (MSE).
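For reference, with m training pairs (x_i, y_i) and the hypothesis h(x) = Wx + b, the cost minimized below and the gradient-descent update applied by the optimizer are

\mathrm{cost}(W, b) = \frac{1}{m} \sum_{i=1}^{m} \left( W x_i + b - y_i \right)^2, \qquad W \leftarrow W - \alpha \frac{\partial\,\mathrm{cost}}{\partial W}, \quad b \leftarrow b - \alpha \frac{\partial\,\mathrm{cost}}{\partial b},

with learning rate \alpha = 0.01 in the code below.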
The TensorFlow (1.x) implementation is as follows:
# Linear regression in TensorFlow (1.x graph API)
import tensorflow as tf

# training data: points on the line y = x
x_train = [1, 2, 3]
y_train = [1, 2, 3]

# model parameters, randomly initialized
W = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.random_normal([1]), name="bias")

# hypothesis: W*x + b
hypothesis = x_train * W + b

# loss function: mean squared error
cost = tf.reduce_mean(tf.square(hypothesis - y_train))

# minimize the cost with gradient descent
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train = optimizer.minimize(cost)

# launch the computation graph in a session
sess = tf.Session()
# initialize global variables in the graph
sess.run(tf.global_variables_initializer())

# training loop: report cost, W, b every 20 steps
for step in range(2001):
    sess.run(train)
    if step % 20 == 0:
        print(step, sess.run(cost), sess.run(W), sess.run(b))
0 35.765697 [-0.9668263] [-1.8271425]
20 0.40806618 [1.1168443] [-0.865325]
40 0.07979807 [1.3009574] [-0.74125075]
60 0.0698372 [1.3048675] [-0.6984728]
80 0.06340335 [1.292259] [-0.66489094]
100 0.057583764 [1.2786876] [-0.6335716]
... (steps 120 through 1960 omitted; the cost decreases steadily) ...
1980 6.762168e-06 [1.0030203] [-0.00686568]
2000 6.1414844e-06 [1.0028784] [-0.00654307]
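By step 2000 the cost has fallen to about 6.1e-06, with W ≈ 1.003 and b ≈ −0.0065; the training points lie exactly on the line y = x, so the true parameters are W = 1 and b = 0. As a quick cross-check (an addition to the original code, assuming NumPy is available), an ordinary least-squares fit recovers them directly:

import numpy as np
# np.polyfit with deg=1 fits y = W*x + b by least squares; returns [slope, intercept]
W_ls, b_ls = np.polyfit([1, 2, 3], [1, 2, 3], deg=1)
print(W_ls, b_ls)  # ≈ 1.0 and ≈ 0.0, the values gradient descent is approaching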
A version using placeholders follows. Instead of hard-coding the training data into the graph, tf.placeholder lets the data be supplied at run time through feed_dict:
X = tf.placeholder(tf.float32)  # equivalently: tf.placeholder(tf.float32, shape=[None])
Y = tf.placeholder(tf.float32)
# rebuild hypothesis, cost, and train op on the placeholders
hypothesis1 = X * W + b
cost1 = tf.reduce_mean(tf.square(hypothesis1 - Y))
train1 = optimizer.minimize(cost1)
# W and b keep their trained values, since the same session and variables are reused
for step in range(2001):
    cost_val, W_val, b_val, _ = sess.run([cost1, W, b, train1],
                                         feed_dict={X: [1, 2, 3], Y: [1, 2, 3]})
    if step % 20 == 0:
        print(step, cost_val, W_val, b_val)
0 6.1414844e-06 [1.0028715] [-0.00652734]
20 5.577856e-06 [1.0027366] [-0.00622063]
40 5.066255e-06 [1.002608] [-0.00592834]
60 4.6013033e-06 [1.0024855] [-0.0056498]
80 4.179101e-06 [1.0023687] [-0.00538436]
100 3.7957805e-06 [1.0022573] [-0.00513137]
... (steps 120 through 1960 omitted; the cost keeps shrinking) ...
1980 4.5211834e-10 [1.0000247] [-5.5997367e-05]
2000 4.1271164e-10 [1.0000235] [-5.337e-05]
The trained model can now be used for prediction. hypothesis is the version without placeholders, so it always evaluates on x_train; hypothesis1 is the placeholder version, so a new input such as X = [5] can be fed in:

print(sess.run(hypothesis))                       # predictions for x_train
print(sess.run(hypothesis1, feed_dict={X: [5]}))  # prediction for a new input
[0.99997014 1.9999936 3.0000172 ]
[5.000064]
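The code above targets the TensorFlow 1.x graph API (tf.Session, tf.placeholder, tf.train.GradientDescentOptimizer), which was removed from the top-level API in TensorFlow 2.x. For readers on a newer version, a minimal sketch of the same model under eager execution, assuming TensorFlow 2.x is installed, might look like this:

import tensorflow as tf

x_train = tf.constant([1.0, 2.0, 3.0])
y_train = tf.constant([1.0, 2.0, 3.0])
W = tf.Variable(tf.random.normal([1]), name="weight")
b = tf.Variable(tf.random.normal([1]), name="bias")
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

for step in range(2001):
    with tf.GradientTape() as tape:      # record ops for automatic differentiation
        hypothesis = x_train * W + b
        cost = tf.reduce_mean(tf.square(hypothesis - y_train))
    grads = tape.gradient(cost, [W, b])  # d(cost)/dW, d(cost)/db
    optimizer.apply_gradients(zip(grads, [W, b]))
    if step % 20 == 0:
        print(step, cost.numpy(), W.numpy(), b.numpy())

# no placeholders needed: just evaluate the model on a new input tensor
print((tf.constant([5.0]) * W + b).numpy())

Under eager execution the feed_dict mechanism disappears entirely; data is passed around as ordinary tensors.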