• java EM算法


    学习总会有厌倦期,这次的 EM 算法就不自己写了,转帖一份可以运行的源码。我个人不喜欢给算法加图形界面,所以对这份代码也没有深入研究,只确认过它可以运行。转载的最初链接已无从查起,如有侵权,请与我联系;带来不便,敬请谅解。不多啰嗦,直接上源码(个人感觉 EM 算法的高斯混合模型好难 - -、)。注意:下面的源码实现的是"掷硬币"(两枚硬币混合)的 EM 算法,并不是高斯混合模型;每行代码开头的数字是转帖时带入的行号,编译前需要去掉。

      1 import java.awt.Dimension;
      2 import java.awt.EventQueue;
      3 import java.awt.Toolkit;
      4 import java.awt.event.ActionEvent;
      5 import java.awt.event.ActionListener;
      6 import java.util.ArrayList;
      7 
      8 import java.lang.Math;
      9 import java.text.DecimalFormat;
     10 
     11 import javax.swing.JButton;
     12 import javax.swing.JFrame;
     13 import javax.swing.JLabel;
     14 import javax.swing.JScrollPane;
     15 import javax.swing.JTable;
     16 import javax.swing.JTextField;
     17 
     18 
     19 public class MachineTranslation{
     20     private static final long serialVersionUID = 2904270580467455923L;
     21     public static void main(String[] args) {
     22         EventQueue.invokeLater(new Runnable() {
     23             public void run() {
     24                 Display frame = new Display();
     25                 frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
     26                 frame.setVisible(true);
     27             }
     28         });
     29     }
     30 }
     31 
     32 class Display extends JFrame {
     33     private static final int DEFAULT_WIDTH = 532;
     34     private static final int DEFAULT_HEIGHT = 508;
     35     
     36     private JTextField textField_3;
     37     private JTextField textField_2;
     38     private JTextField textField_1;
     39     private static final long serialVersionUID = -2794537679802534502L;
     40     
     41     private JTable table;
     42     private JTextField textField;
     43     private final JButton emButton;
     44     
     45     public Display() {
     46         super();
     47         getContentPane().setLayout(null);
     48         setSize(DEFAULT_WIDTH,DEFAULT_HEIGHT);
     49         
     50 
     51         textField = new JTextField("<HHH>,<TTT>,<HHH>,<TTT>");
     52         textField.setBounds(61, 42, 237, 33);
     53         getContentPane().add(textField);
     54 
     55         emButton = new JButton();
     56         emButton.setText("EM");
     57         emButton.setBounds(344, 44, 106, 28);
     58         getContentPane().add(emButton);
     59 
     60 
     61         final JLabel wLabel = new JLabel();
     62         wLabel.setText("λ");
     63         wLabel.setBounds(61, 98, 25, 18);
     64         getContentPane().add(wLabel);
     65 
     66         final JLabel p1Label = new JLabel();
     67         p1Label.setText("P1");
     68         p1Label.setBounds(198, 98, 25, 18);
     69         getContentPane().add(p1Label);
     70 
     71         final JLabel p2Label = new JLabel();
     72         p2Label.setText("P2");
     73         p2Label.setBounds(349, 98, 25, 18);
     74         getContentPane().add(p2Label);
     75 
     76         textField_1 = new JTextField();
     77         textField_1.setText("0.3");
     78         textField_1.setBounds(86, 96, 87, 22);
     79         getContentPane().add(textField_1);
     80 
     81         textField_2 = new JTextField();
     82         textField_2.setText("0.3");
     83         textField_2.setBounds(229, 96, 87, 22);
     84         getContentPane().add(textField_2);
     85 
     86         textField_3 = new JTextField();
     87         textField_3.setText("0.6");
     88         textField_3.setBounds(380, 96, 87, 22);
     89         getContentPane().add(textField_3);
     90         
     91         setTitle("掷硬币EM算法之实现");
     92         
     93         Toolkit tk =this.getToolkit();//得到窗口工具条
     94         Dimension dm = tk.getScreenSize();
     95         this.setLocation((int)(dm.getWidth()-DEFAULT_WIDTH)/2,(int)(dm.getHeight()-DEFAULT_HEIGHT)/2);//显示在屏幕中央
     96 
     97         
     98         emButton.addActionListener(new ActionListener() {
     99             public void actionPerformed(ActionEvent event) {
    100                 EmAlgorithm em = new EmAlgorithm(textField.getText(),textField_1.getText(),textField_2.getText(),textField_3.getText());
    101                 em.maximizeExpectation();
    102                 final JScrollPane scrollPane = new JScrollPane();
    103                 scrollPane.setBounds(37, 121, 453, 261);
    104                 getContentPane().add(scrollPane);
    105 
    106                 table= new JTable(em.getCells(),em.getColumnNames());
    107                 scrollPane.setViewportView(table);
    108                 System.out.println(1E-300-1E-301>0);
    109             }
    110         });
    111     }
    112 }
    113 
    114 class EmAlgorithm{
    115     private String str;
    116     private ArrayList<Sub> sub=new ArrayList<Sub>();
    117     private ArrayList<Double> w=new ArrayList<Double>();
    118     private ArrayList<Double> P1=new ArrayList<Double>();
    119     private ArrayList<Double> P2=new ArrayList<Double>();
    120     private ArrayList<ArrayList<Double>> p;
    121     
    122     private Object[][] cells;
    123     private String[] columnNames;
    124 
    125     public EmAlgorithm(String str, String str1,String str2, String str3){
    126         this.str=str;
    127         w.add(Double.parseDouble(str1));
    128         P1.add(Double.parseDouble(str2));
    129         P2.add(Double.parseDouble(str3));
    130 
    131         textSplit();    
    132         p=new ArrayList<ArrayList<Double>>();
    133         for(int j = 0; j < sub.size(); ++j)
    134             p.add(new ArrayList<Double>());
    135         System.out.println(p);
    136     }
    137     public void textSplit(){
    138         String[] sList;
    139         sList=str.substring(1,str.length()-1).split(">,<");
    140 
    141         for (int i = 0; i < sList.length; i++){
    142             sub.add(new Sub(sList[i]));
    143         }
    144         System.out.println(sub);
    145     }
    146     public void maximizeExpectation(){
    147         int iterate=0;
    148         if (!P1.get(P1.size() - 1).equals(P2.get(P2.size() - 1))) {
    149             do {
    150                 iterate++;
    151                 compute();
    152             } while (Math.abs(P1.get(P1.size() - 1) - P1.get(P1.size() - 2)) > 0.0000001);
    153             // }while(Math.abs(P1.get(P1.size()-1)-P1.get(P1.size()-2))>1.E-300);
    154         }
    155         else{
    156             iterate=7;
    157             for(int i=0;i<iterate;i++){
    158                 compute();
    159             }
    160         }
    161         cells=new Object[p.get(0).size()][4+sub.size()];
    162         
    163         DecimalFormat df_t=new DecimalFormat("0.0000");
    164         
    165         for(int i=0;i<iterate;i++){
    166             cells[i][0]=i;
    167             cells[i][1]=df_t.format(w.get(i));
    168             cells[i][2]=df_t.format(P1.get(i));
    169             cells[i][3]=df_t.format(P2.get(i));
    170             for(int j=0;j<sub.size();j++){
    171                 cells[i][j+4]=df_t.format(p.get(j).get(i));
    172             }
    173         }
    174         columnNames=new String[4+sub.size()];
    175         columnNames[0]="Iteration";
    176         columnNames[1]="w";
    177         columnNames[2]="P1";
    178         columnNames[3]="P2";
    179         for(int i=0;i<sub.size();i++)
    180             columnNames[i+4]="p"+(i+1);
    181     }
    182     public void compute(){
    183         double w=this.w.get(this.w.size()-1);
    184         double P1=this.P1.get(this.P1.size()-1);
    185         double P2=this.P2.get(this.P2.size()-1);
    186         //System.out.println(w+","+P1+","+P2+"     "+this.w.size());
    187         for (int i = 0; i < sub.size(); i++) {
    188             float PP1 = (float) Math.pow(P1, sub.get(i).getH())
    189                     * (float) Math.pow(1 - P1,sub.get(i).getT());
    190             float PP2 = (float) Math.pow(P2, sub.get(i).getH())
    191                     * (float) Math.pow(1 - P2, sub.get(i).getT());
    192             p.get(i).add(w * PP1 / (w * PP1 + (1 - w) * PP2));
    193         }
    194         double sump=0;
    195         double sumP1=0;
    196         double sumP2=0;float sumpp=0;
    197         
    198         for(int i=0;i<sub.size();i++){
    199             sump+=p.get(i).get(p.get(i).size()-1);
    200             sumP1+=(sub.get(i).getH()/3.0)*p.get(i).get(p.get(i).size()-1);
    201             sumP2+=(sub.get(i).getH()/3.0)*(1-p.get(i).get(p.get(i).size()-1));
    202             sumpp+=(1-p.get(i).get(p.get(i).size()-1));
    203         }
    204         this.w.add(sump/sub.size());
    205         this.P1.add(sumP1/sump);
    206         this.P2.add(sumP2/sumpp);
    207     }
    208     public Object[][] getCells(){
    209         return cells;
    210     }
    211     public String[] getColumnNames(){
    212         return columnNames;
    213     }
    214 }
    /**
     * Tallies the 'H' and 'T' characters of a string.
     */
    class CountHT{
        private final int hCount;
        private final int tCount;

        /**
         * Scans {@code str} once. Every character that is neither 'H' nor 'T' is reported on
         * standard output and otherwise ignored.
         *
         * @param str the text to scan
         */
        public CountHT(String str){
            int heads = 0;
            int tails = 0;
            for (char symbol : str.toCharArray()) {
                switch (symbol) {
                    case 'H':
                        heads++;
                        break;
                    case 'T':
                        tails++;
                        break;
                    default:
                        System.out.println("输入有问题");
                        break;
                }
            }
            hCount = heads;
            tCount = tails;
        }

        /** @return the number of 'H' characters seen */
        public int getH(){
            return hCount;
        }

        /** @return the number of 'T' characters seen */
        public int getT(){
            return tCount;
        }
    }
    240 
    241 class Sub{
    242     private String str;
    243     private int hCount;
    244     private int tCount;
    245     
    246     public Sub(String str){
    247         this.str=str;
    248         CountHT countHT=new CountHT(str);
    249         hCount=countHT.getH();
    250         tCount=countHT.getT();
    251     }
    252     public String toString(){
    253         return str+" "+"hCount="+hCount+" "+"tCount="+tCount;
    254     }
    255     public int getH(){
    256         return hCount;
    257     }
    258     public int getT(){
    259         return tCount;
    260     }
    261 }
  • 相关阅读:
    HttpWatch 有火狐版本?
    JQgrid的最新API
    jqgrid
    JSON的学习网站
    array创建数组
    Numpy安装及测试
    SQLite3删除数据_7
    SQLite3修改数据_6
    SQLite3查询一条数据_5
    SQLite3查询所有数据_4
  • 原文地址:https://www.cnblogs.com/wn19910213/p/3337538.html
Copyright © 2020-2023  润新知