• 使用SVM对于许多类型的多维数据分类


    最近,我做了一件小事,使用SVM正确8三维级数据分类,在线搜索,我们发现二分的问题大家都在讨论二维数据,一些决定自己的研究。我首先参考opencvtutorial。这也是二维数据的二分类问题。然后通过学习研究,发现别有洞天,遂实现之前的目标。在这里将代码贴出来。这里实现了对三维数据进行三类划分。以供大家相互学习。

    #include "stdafx.h"
    #include <iostream>
    #include <opencv2/core/core.hpp>
    #include <opencv2/highgui/highgui.hpp>
    #include <opencv2/ml/ml.hpp>
    
    using namespace cv;
    using namespace std;
    
    int main()
    {
    
        //--------------------- 1. Set up training data randomly ---------------------------------------
        Mat trainData(100, 3, CV_32FC1);
        Mat labels   (100, 1, CV_32FC1);
    
        RNG rng(100); // Random value generation class
    
        // Generate random points for the class 1
        Mat trainClass = trainData.rowRange(0, 40);
        // The x coordinate of the points is in [0, 0.4)
        Mat c = trainClass.colRange(0, 1);
        rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(0.4 * 100));
        // The y coordinate of the points is in [0, 0.4)
        c = trainClass.colRange(1, 2);
        rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(0.4 * 100));
    	// The z coordinate of the points is in [0, 0.4)
        c = trainClass.colRange(2, 3);
        rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(0.4 * 100));
    
        // Generate random points for the class 2
        trainClass = trainData.rowRange(60, 100);
        // The x coordinate of the points is in [0.6, 1]
        c = trainClass.colRange(0, 1);
        rng.fill(c, RNG::UNIFORM, Scalar(0.6*100), Scalar(100));
        // The y coordinate of the points is in [0.6, 1)
        c = trainClass.colRange(1, 2);
        rng.fill(c, RNG::UNIFORM, Scalar(0.6*100), Scalar(100));
    	 // The z coordinate of the points is in [0.6, 1]
        c = trainClass.colRange(2, 3);
        rng.fill(c, RNG::UNIFORM, Scalar(0.6*100), Scalar(100));
    
        
    
        // Generate random points for the classes 3
        trainClass = trainData.rowRange(  40, 60);
        // The x coordinate of the points is in [0.4, 0.6)
        c = trainClass.colRange(0,1);
        rng.fill(c, RNG::UNIFORM, Scalar(0.4*100), Scalar(0.6*100));
        // The y coordinate of the points is in [0.4, 0.6)
        c = trainClass.colRange(1,2);
        rng.fill(c, RNG::UNIFORM, Scalar(0.4*100), Scalar(0.6*100));
    	// The z coordinate of the points is in [0.4, 0.6)
        c = trainClass.colRange(2,3);
        rng.fill(c, RNG::UNIFORM, Scalar(0.4*100), Scalar(0.6*100));
    
    
    
        //------------------------- Set up the labels for the classes ---------------------------------
        labels.rowRange( 0,  40).setTo(1);  // Class 1
        labels.rowRange(60, 100).setTo(2);  // Class 2
    	labels.rowRange(40, 60).setTo(3);  // Class 3
    
    
        //------------------------ 2. Set up the support vector machines parameters --------------------
        CvSVMParams params;
        params.svm_type    = SVM::C_SVC;
        params.C           = 0.1;
        params.kernel_type = SVM::LINEAR;
        params.term_crit   = TermCriteria(CV_TERMCRIT_ITER, (int)1e7, 1e-6);
    
        //------------------------ 3. Train the svm ----------------------------------------------------
        cout << "Starting training process" << endl;
        CvSVM svm;
        svm.train(trainData, labels, Mat(), Mat(), params);
        cout << "Finished training process" << endl;
    
    	 Mat sampleMat = (Mat_<float>(1,3) << 50, 50,10);
         float response = svm.predict(sampleMat);
    	 cout<<response<<endl;
    
    	 sampleMat = (Mat_<float>(1,3) << 50, 50,100);
         response = svm.predict(sampleMat);
    	 cout<<response<<endl;
    
    	 sampleMat = (Mat_<float>(1,3) << 50, 50,60);
         response = svm.predict(sampleMat);
    	 cout<<response<<endl;
    	
        waitKey(0);
    }



    版权声明:本文博客原创文章。博客,未经同意,不得转载。

  • 相关阅读:
    【YbtOJ#911】欧拉函数
    【CF590E】Birthday
    打印控件的区别
    RPA教程
    UiPath培训教程
    RPA视频教程
    搭建samba服务
    kvm虚拟机在线扩容
    zabbix监控交换机
    UiPath Level3讲解
  • 原文地址:https://www.cnblogs.com/lcchuguo/p/4719018.html
Copyright © 2020-2023  润新知