• k-means 算法的 C++ 实现


    k-means(k 均值)算法是最常见的聚类算法之一,实现简单,效果也比较好。它把 n 个对象划分成指定的 k 个簇,每个簇中所有对象的均值即为该簇的聚点(中心)。

    k均值算法有如下五个步骤:

    1. 随机生成初始的 k 个簇心。可以从样本中随机选取,也可以根据每个特征的取值范围随机生成(下面的实现采用后者)。
    2. 对每个样本计算它到每个簇心的欧氏距离,将样本划分到距离最近的簇心(聚点)所在的簇。
    3. 对划分到同一个簇心(聚点)的样本计算平均值,用该均值更新簇心(聚点)。
    4. 若有样本所属的簇发生变化,转到第 2 步;若所有样本的划分都不再变化(簇心随之稳定),转到第 5 步。
    5. 输出划分结果
#include <algorithm>
#include <cassert>
#include <climits>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <ctime>
#include <exception>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

using namespace std;
     11 namespace terse {
     12 class Kmeans {
     13 private:
     14     vector<vector<double>> m_dataSet;
     15     int m_k;
     16     vector<int> m_clusterResult;         // result of cluster
     17     vector<vector<double>> m_cluserCent; //center of k clusters
     18 
     19 private:
     20     vector<string> split(const string& s, string pattern) {
     21         vector<string> res;
     22         size_t start = 0;
     23         size_t end = 0;
     24         while (start < s.size()) {
     25             end = s.find_first_of(pattern, start);
     26             if (end == string::npos) {
     27                 res.push_back(s.substr(start, end - start - 1));
     28                 return res;
     29             }
     30             res.push_back(s.substr(start, end - start));
     31             start = end + 1;
     32         }
     33         return res;
     34     }
     35 
     36     void loadDataSet(const char* fileName) {
     37         ifstream dataFile(fileName);
     38         if (!dataFile.is_open()) {
     39             cerr << "open file " << fileName << "failed!
    ";
     40             return;
     41         }
     42         string tmpstr;
     43         vector<double> data;
     44         while (!dataFile.eof()) {
     45             data.clear();
     46             tmpstr.clear();
     47             getline(dataFile, tmpstr);
     48             vector<string> tmp = split(tmpstr, ",");
     49             for (string str : tmp) {
     50                 data.push_back(stod(str));
     51             }
     52             this->m_dataSet.push_back(data);
     53         }
     54         dataFile.close();
     55     }
     56 
     57     //compute Euclidean distance of two vector
     58     double distEclud(vector<double>& v1, vector<double>& v2) {
     59         assert(v1.size() == v2.size());
     60         double dist = 0;
     61         for (size_t i = 0; i < v1.size(); i++) {
     62             dist += (v1[i] - v2[i]) * (v1[i] - v2[i]);
     63         }
     64         return sqrt(dist);
     65     }
     66 
     67     void generateRandCent() {
     68         int numOfFeats = this->m_dataSet[0].size();
     69         size_t numOfSamples = this->m_dataSet.size();
     70 
     71         //first:min second:max
     72         vector<pair<double, double>> minMaxOfFeat(numOfFeats);
     73         for (int i = 0; i < numOfFeats; i++) {
     74             minMaxOfFeat[i].first = this->m_dataSet[0][i];
     75             minMaxOfFeat[i].second = this->m_dataSet[0][i];
     76         }
     77         for (size_t i = 1; i < numOfSamples; i++) {
     78             for (int j = 0; j < numOfFeats; j++) {
     79                 if (this->m_dataSet[i][j] > minMaxOfFeat[j].second) {
     80                     minMaxOfFeat[j].second = this->m_dataSet[i][j];
     81                 }
     82                 if (this->m_dataSet[i][j] < minMaxOfFeat[j].first) {
     83                     minMaxOfFeat[j].first = this->m_dataSet[i][j];
     84                 }
     85             }
     86         }
     87         srand(time(NULL));
     88         for (int i = 0; i < this->m_k; i++) {
     89             for (int j = 0; j < numOfFeats; j++) {
     90                 this->m_cluserCent[i][j] = minMaxOfFeat[j].first
     91                         + (minMaxOfFeat[j].second - minMaxOfFeat[j].first)
     92                                 * (rand() / (double) RAND_MAX);
     93             }
     94         }
     95 
     96     }
     97 
     98     void printClusterCent(int iter) {
     99         int m = this->m_cluserCent.size();
    100         int n = this->m_cluserCent[0].size();
    101         cout << "iter =  " << iter;
    102         for (int i = 0; i < m; i++) {
    103             cout << " {";
    104             for (int j = 0; j < n; j++) {
    105                 cout << this->m_cluserCent[i][j] << ",";
    106             }
    107             cout << "};";
    108         }
    109         cout << endl;
    110     }
    111 
    112     void writeResult(const char* fileName = "res.txt") {
    113         ofstream fout(fileName);
    114         if (!fout.is_open()) {
    115             cerr << "open file " << fileName << "failed!";
    116             return;
    117         }
    118         for (size_t i = 0; i < this->m_dataSet.size(); i++) {
    119             for (size_t j = 0; j < this->m_dataSet[0].size(); j++) {
    120                 fout << this->m_dataSet[i][j] << "	";
    121             }
    122             fout << setprecision(5) << this->m_clusterResult[i] << "
    ";
    123         }
    124         fout.close();
    125     }
    126 
    127 public:
    128     Kmeans(int k, const char* fileName) {
    129         this->m_k = k;
    130         this->loadDataSet(fileName);
    131         this->m_clusterResult.reserve(this->m_dataSet.size());
    132         this->m_cluserCent = vector<vector<double>>(k,
    133                 vector<double>(this->m_dataSet[0].size()));
    134         generateRandCent();
    135     }
    136 
    137     Kmeans(int k, vector<vector<double>>& data) {
    138         this->m_k = k;
    139         this->m_dataSet = data;
    140         this->m_clusterResult.reserve(this->m_dataSet.size());
    141         this->m_cluserCent = vector<vector<double>>(k,
    142                 vector<double>(this->m_dataSet[0].size()));
    143         generateRandCent();
    144     }
    145 
    146     //verbose = 1,printClusterCent();
    147     void kmeansCluster(int verbose = 1) {
    148         int iter = 0;
    149         bool isClusterChanged = true;
    150         while (isClusterChanged) {
    151             isClusterChanged = false;
    152             //step 1: find the nearest centroid of each point
    153             int numOfFeats = this->m_dataSet[0].size();
    154             size_t numOfSamples = this->m_dataSet.size();
    155             for (size_t i = 0; i < numOfSamples; i++) {
    156                 int minIndex = -1;
    157                 double minDist = INT_MAX;
    158                 for (int j = 0; j < this->m_k; j++) {
    159                     double dist = distEclud(this->m_cluserCent[j],
    160                             m_dataSet[i]);
    161                     if (dist < minDist) {
    162                         minDist = dist;
    163                         minIndex = j;
    164                     }
    165                 }
    166                 if (m_clusterResult[i] != minIndex) {
    167                     isClusterChanged = true;
    168                     m_clusterResult[i] = minIndex;
    169                 }
    170             }
    171 
    172             //step 2: update cluster center
    173             vector<size_t> cnt(this->m_k, 0);
    174             this->m_cluserCent = vector<vector<double>>(this->m_k,
    175                     vector<double>(numOfFeats, 0.0));
    176             for (size_t i = 0; i < numOfSamples; i++) {
    177                 for (int j = 0; j < numOfFeats; j++) {
    178                     this->m_cluserCent[this->m_clusterResult[i]][j] +=
    179                             this->m_dataSet[i][j];
    180                 }
    181                 cnt[this->m_clusterResult[i]]++;
    182             }
    183             // mean of the vector belong to a cluster
    184             for (int i = 0; i < this->m_k; i++) {
    185                 for (int j = 0; j < numOfFeats; j++) {
    186                     this->m_cluserCent[i][j] /= cnt[i];
    187                 }
    188             }
    189             if (verbose)
    190                 printClusterCent(iter++);
    191         }
    192         writeResult();
    193     }
    194 };
    195 
    196 };
    197 
    198 int main(){
    199     terse::Kmeans kmeans(4,"datafile.txt");
    200     kmeans.kmeansCluster();
    201     return 0;
    202 }
    203 /*namespace terse*/
  • 相关阅读:
    Nim教程【八】(博客园撰写工具客户端更新)
    图解 MongoDB 地理位置索引的实现原理(转)
    MongoDB学习笔记(索引)(转)
    Hadoop集群WordCount运行详解(转)
    java操作mongodb(连接池)(转)
    面向对象设计七大原则(转)
    Spring中IOC和AOP的详细解释(转)
    java单例模式使用及注意事项
    java.io包的总体框架图(转)
    Java常见异常(Runtime Exception )小结(转)
  • 原文地址:https://www.cnblogs.com/wxquare/p/6754485.html
Copyright © 2020-2023  润新知