1 #include <opencv2/core/core.hpp>
2 #include <opencv2/highgui/highgui.hpp>
3 #include <opencv2/ml/ml.hpp>
4
5 using namespace cv;
6
7 int main()
8 {
9 // Data for visual representation
10 int width = 512, height = 512;
11 Mat image = Mat::zeros(height, width, CV_8UC3);
12
13 // Set up training data
14 float labels[5] = { 1.0, -1.0, -1.0, -1.0, 1.0 };
15 Mat labelsMat(5, 1, CV_32FC1, labels);
16
17
18 float trainingData[5][2] = { { 501, 10 }, { 255, 10 }, { 501, 255 }, { 10, 501 }, { 501, 128 } };
19 Mat trainingDataMat(5, 2, CV_32FC1, trainingData);
20
21 //设置支持向量机的参数
22 CvSVMParams params;
23 params.svm_type = CvSVM::C_SVC;//SVM类型:使用C支持向量机
24 params.kernel_type = CvSVM::LINEAR;//核函数类型:线性
25 params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 100, 1e-6);//终止准则函数:当迭代次数达到最大值时终止
26
27 //训练SVM
28 //建立一个SVM类的实例
29 CvSVM SVM;
30 //训练模型,参数为:输入数据、响应、XX、XX、参数(前面设置过)
31 SVM.train(trainingDataMat, labelsMat, Mat(), Mat(), params);
32
33 Vec3b green(0, 255, 0), blue(255, 0, 0);
34 //显示判决域
35 for (int i = 0; i < image.rows; ++i)
36 for (int j = 0; j < image.cols; ++j)
37 {
38 Mat sampleMat = (Mat_<float>(1, 2) << i, j);
39 //predict是用来预测的,参数为:样本、返回值类型(如果值为ture而且是一个2类问题则返回判决函数值,否则返回类标签)、
40 float response = SVM.predict(sampleMat);
41
42 if (response == 1)
43 image.at<Vec3b>(j, i) = green;
44 else if (response == -1)
45 image.at<Vec3b>(j, i) = blue;
46 }
47
48 //画出训练数据
49 int thickness = -1;
50 int lineType = 8;
51 circle(image, Point(501, 10), 5, Scalar(0, 0, 0), thickness, lineType);//画圆
52 circle(image, Point(255, 10), 5, Scalar(255, 255, 255), thickness, lineType);
53 circle(image, Point(501, 255), 5, Scalar(255, 255, 255), thickness, lineType);
54 circle(image, Point(10, 501), 5, Scalar(255, 255, 255), thickness, lineType);
55 circle(image, Point(501, 128), 5, Scalar(0, 0, 0), thickness, lineType);
56
57 //显示支持向量
58 thickness = 2;
59 lineType = 8;
60 //获取支持向量的个数
61 int c = SVM.get_support_vector_count();
62
63 for (int i = 0; i < c; ++i)
64 {
65 //获取第i个支持向量
66 const float* v = SVM.get_support_vector(i);
67 //支持向量用到的样本点,用灰色进行标注
68 circle(image, Point((int)v[0], (int)v[1]), 6, Scalar(128, 128, 128), thickness, lineType);
69 }
70
71 imwrite("result.png", image); // save the image
72
73 imshow("SVM Simple Example", image); // show it to the user
74 waitKey(0);
75
76 }