• 【EM】C++代码实现


    看了原理和比人的代码后,终于自己写了一个EM的实现。

    我从网上找了一些身高性别的数据,用EM算法通过身高信息来识别性别。

    实现的效果还行,正确率有84% (初始数据 男生170 女生160 方差都是10)

                                     79%  (初始数据 男生165 女生150 方差都是10)

    正确率与初始值有关。

    /*
    试图用EM算法来根据输入的身高来区分性别
    */
    
    #include<iostream>
    #include<fstream>
    #include<algorithm>
    #include<vector>
    using namespace std;
    
    #define PI 3.14159
    #define max(x,y) (x > y ? x : y)
    
    typedef struct FLOAT2
    {
        float f1;
        float f2;
    }FLOAT2;
    typedef struct Gaussian
    {
        float mean;
        float var;
    }Gaussian;
    
    typedef struct EMData
    {
        char sex;
        float fHeight;
    }EMData;
    
    //获取身高性别数据
    int getdata(vector<EMData> &Data)
    {
        ifstream fin;
        fin.open("data.txt");
        if(!fin)
        {
            cout<<"error: can't open the file."<<endl;
            return -1;
        }
    
        while(!fin.eof())
        {
            char c[10];
            float height;
            fin >> c >> height;
            EMData data;
            data.sex = c[0];
            data.fHeight = height;
            Data.push_back(data);
        }
    
        return 0;
    }
    
    //根据身高数据区分性别, 返回正确率
    float predict(vector<EMData> Data)
    {
        //设符合正态分布
        Gaussian sex[2];
        float a[2]; //男女生所占百分比
        float t = 1;
        float tlimit = 0.000001; //收敛条件
    
        //赋初值 下标0表示男生 1表示女生
        sex[0].mean = 180.0;
        sex[0].var = 10.0;
        sex[1].mean = 150.0;
        sex[1].var = 10.0;
        a[0] = 0.5;
        a[1] = 0.5;
    
        while(t > tlimit)
        {
            Gaussian sex_old[2];
            float a_old[2];
            sex_old[0] = sex[0];
            sex_old[1] = sex[1];
            a_old[0] = a[0];
            a_old[1] = a[1];
    
            //计算每个样本分别被两个模型抽中的概率
            vector<FLOAT2> px;
        
            vector<EMData>::iterator it;
            for(it = Data.begin(); it < Data.end(); it++)
            {
                FLOAT2 p;
                p.f1 = 1/(sqrt(2 * PI * sex[0].var)) * exp(-(it->fHeight - sex[0].mean) * (it->fHeight - sex[0].mean) / (2 * sex[0].var));
                p.f2 = 1/(sqrt(2 * PI * sex[1].var)) * exp(-(it->fHeight - sex[1].mean) * (it->fHeight - sex[1].mean) / (2 * sex[1].var));
                px.push_back(p);
            }
    
            //E步
            //计算每个样本属于男生或女生的概率
            vector<FLOAT2>::iterator it2;
            for(it2 = px.begin(); it2 < px.end(); it2++)
            {
                float sum = 0.0;
                (*it2).f1 *= a[0];
                sum += (*it2).f1;
                (*it2).f2 *= a[1];
                sum += (*it2).f2;
    
                (*it2).f1 = (*it2).f1/sum;
                (*it2).f2 = (*it2).f2/sum;
            }
    
            //M步
            float sum_male = 0, sum_female = 0;
            float sum_mean_male = 0, sum_mean_female = 0;
            for(it2 = px.begin(), it = Data.begin(); it2 < px.end(); it2++, it++)
            {
                sum_male += (*it2).f1;
                sum_female += (*it2).f2;
                sum_mean_male += (*it2).f1 * (it->fHeight);
                sum_mean_female += (*it2).f2 * (it->fHeight);
            }
            //更新a
            a[0] = sum_male/(sum_male + sum_female);
            a[1] = sum_female/(sum_male + sum_female);
    
            //更新均值
            sex[0].mean = sum_mean_male/ sum_male;
            sex[1].mean = sum_mean_female/ sum_female;
    
            //更新方差
            float sum_var_male = 0, sum_var_female = 0;
            for(it2 = px.begin(), it = Data.begin(); it2 < px.end(); it2++, it++)
            {
                sum_var_male += (*it2).f1 * ((it->fHeight) - sex[0].mean) * ((it->fHeight) - sex[0].mean);
                sum_var_female += (*it2).f2 * ((it->fHeight) - sex[1].mean) * ((it->fHeight) - sex[1].mean);
            }
            sex[0].var = sum_var_male / sum_male;
            sex[1].var = sum_var_female / sum_female;
    
            //计算变化率
            t = max((a[0] - a_old[0])/a_old[0], (a[1] - a_old[1])/a_old[1]);
            t = max(t, (sex[0].mean - sex_old[0].mean)/sex_old[0].mean);
            t = max(t, (sex[1].mean - sex_old[1].mean)/sex_old[1].mean);
            t = max(t, (sex[0].var - sex_old[0].var)/sex_old[0].var);
            t = max(t, (sex[1].var - sex_old[1].var)/sex_old[1].var);
        }
    
        //计算正确率
        int correct_num = 0;
        float correct_rate = 0;
        vector<EMData>::iterator it;
        for(it = Data.begin(); it < Data.end(); it++)
        {
            float p[2];
            char csex;
            for(int i = 0; i < 2; i++)
            {
                p[i] = 1/(sqrt(2 * PI * sex[i].var)) * exp(-(it->fHeight - sex[i].mean) * (it->fHeight - sex[i].mean) / (2 * sex[i].var));
            }
    
            csex = (p[0] > p[1]) ? 'm' : 'f';
            if(csex == it->sex)
                correct_num++;
        }
    
        correct_rate = (float)correct_num / Data.size();
        return correct_rate;
    }
    
    int main()
    {
        vector<EMData> Data;
        getdata(Data);
        float correct_rate = predict(Data);
        cout << "correct rate = "<< correct_rate << endl;
        return 0;
    }

    数据:data.txt内容

    male    164
    female    156
    male    168
    female    160
    female    162
    male    187
    female    162
    male    167
    female    160.5
    female    160
    female    158
    female    164
    female    165
    male    174
    female    166
    female    158
    male     162
    male    175
    male    170
    female    161
    female    169
    female    161
    female    160
    female    167
    male    176
    male    169
    male    178
    male    165
    female    155
    male    183
    male    171
    male    179
    female    154
    male    172
    female    172
    male    173
    male    172
    male    175
    male    160
    male    160
    male    160
    male    175
    male    163
    male    181
    male    172
    male    175
    male    175
    male    167
    male    172
    male    169
    male    172
    male    175
    male    172
    male    170
    male    158
    male    167
    male    164
    male    176
    male    182
    male    173
    male    176
    male    163
    male    166
    male    162
    male    169
    male    163
    male    163
    male    176
    male    169
    male    173
    male    163
    male    167
    male    176
    male    168
    male    167
    male    170
    female    155
    female    157
    female    165
    female    156
    female    155
    female    156
    female    160
    female    158
    female    162
    female    162
    female    155
    female    163
    female    160
    female    162
    female    165
    female    159
    female    147
    female    163
    female    157
    female    160
    female    162
    female    158
    female    155
    female    165
    female    161
    female    159
    female    163
    female    158
    female    155
    female    162
    female    157
    female    159
    female    152
    female    156
    female    165
    female    154
    female    156
    female    162
  • 相关阅读:
    osg 自定义图元
    osg model
    Qt 获取键盘输入
    TensorFlow Object Detection API —— 测试自己的模型
    labelimg data
    Qt 自定义信号SIGNAL
    qt ui
    QPixmap QImage 相互转化
    QString std::string 相互转 含中文
    ubuntu 安装百度云客户端
  • 原文地址:https://www.cnblogs.com/dplearning/p/3981578.html
Copyright © 2020-2023  润新知