• 使用Cross-validation (CV) 调整Extreme learning Machine (ELM) 最优参数的实现(matlab)


    ELM算法模型是最近几年得到广泛重视的模型,它不同于现在广为火热的DNN。 ELM使用传统的三层神经网络,只包含一个隐含层,但又不同于传统的神经网络。ELM是一种简单易用、有效的单隐层前馈神经网络SLFNs学习算法。2006年由南洋理工大学黄广斌副教授提出。传统的神经网络学习算法(如BP算法)需要人为设置大量的网络训练参数,并且很容易产生局部最优解。极限学习机只需要设置网络的隐层节点个数,在算法执行过程中不需要调整网络的输入权值以及隐元的偏置,并且产生唯一的最优解,因此具有学习速度快且泛化性能好的优点。但是隐含层节点个数的设置需要经过人工大量实验得到或者通过最常见的CV方法可以得到。 下面,matlab实现10-fold CV 寻找最优隐含层节点个数的。作为以前工作的一个小记录。ELM使用的是主页http://www.ntu.edu.sg/home/egbhuang/ 源码。源程序包含两个脚本文件cv_para.m  和 Data2txt.m

    Description:

    @cv_para.m 这个是主程序

    结构体:function [best_para]=cv_para(data,para_set)。其中使用到两个参数,data代表我们的完整数据,也就是没有划分训练集和测试集的完整数据。para_set代表隐含层节点个数的一个数组,例如在[1:60]之间选择一个最优的隐含层节点个数。

     

    @Data2txt.m 这个脚本文件是为了满足ELM算法训练将数据转化为源码ELM可以使用的文本文件。数据格式在ELM主页已经给出example。

      1 function [best_para]=cv_para(data,para_set)
      2 
      3 num_folds=10;     % 10-fold cross validation
      4 
      5 n=size(data,1);
      6 
      7 n_paras=length(para_set);
      8 
      9 idx=randperm(n);    % idx 代表n个数据中索引的任意排列
     10 
     11 n_test=floor(n/num_folds); % n_test: 测试集包含的数据集的个数
     12 
     13 test_idx=zeros(num_folds,n_test); % test_idx: 储存num_folds次测试集的索引
     14 
     15 train_idx=zeros(num_folds,n-n_test); %train_idx: 原理同test_idx
     16 
     17  
     18 
     19 % 下面程序操作的主要是索引,只要将训练集地址和测试集地址划分出来
     20 
     21 for i=1:num_folds
     22 
     23     test_idx(i,:)=idx((i-1)*n_test+1:i*n_test);  
     24 
     25     tmp=1:n;
     26 
     27     tmp(test_idx(i,:))=[];
     28 
     29     train_idx(i,:)=tmp;
     30 
     31 end
     32 
     33 best_accs=inf;
     34 
     35 best_para=1; % 保存最优的隐含节点个数
     36 
     37 for i=1:n_paras
     38 
     39         one_accs=0;
     40 
     41         for j=1:num_folds  
     42 
     43             % 这里就是将数据集转化为文本文件形式,以满足elm源码的需求
     44 
     45              train_data=data(train_idx(j,:),:);
     46 
     47              test_data=data(test_idx(j,:),:);
     48 
     49              Data2txt(train_data,'trainfile');
     50 
     51              Data2txt(test_data,'testfile');
     52 
     53             
     54 
     55              [TrainingTime, TrainingAccuracy] = elm_train('trainfile', 0, para_set(1,i), 'sig');
     56 
     57              [TestingTime, acc] = elm_predict('testfile');
     58 
     59             one_accs=one_accs+acc;
     60 
     61            
     62 
     63             delete('trainfile');
     64 
     65             delete('testfile');
     66 
     67         end
     68 
     69         if(best_accs>one_accs)
     70 
     71             best_para=para_set(1,i);
     72 
     73             best_accs=one_accs;
     74 
     75         end
     76 
     77 end
     78 
     79 end
     80 
     81  
     82 
     83 @Data2txt 源码
     84 
     85 function[]=Data2txt(Data,file)
     86 
     87     fid=fopen(file,'w');%дÈëÎļþ·¾¶
     88 
     89     [m,n]=size(Data);
     90 
     91      for i=1:1:m
     92 
     93          for j=1:1:n
     94 
     95             if j==n
     96 
     97                 fprintf(fid,'%g
    ',Data(i,j));
     98 
     99             else
    100 
    101                 fprintf(fid,'%g	',Data(i,j));
    102 
    103          end
    104 
    105         end
    106 
    107      end
    108 
    109     fclose(fid);
    110 
    111 end

     

  • 相关阅读:
    2020-2021第一学期2024"DCDD"小组第十二周讨论
    2020-2021第一学期《网络空间安全导论》第十二周自习总结
    2020-2021第一学期2024"DCDD"小组第十一周讨论
    2020-2021第一学期《网络空间安全导论》第十一周自习总结
    2020-2021第一学期2024"DCDD"小组第十周讨论
    2019-2020-1 20165213 20165224 20165311 《信息安全系统设计基础》实验五 通讯协议设计
    2019-2020-1 实验三 并发程序
    2019-2020-1 20165213 20165224 20165311 实验二
    2019-2020-1 20165213 20165224 20165311 实验一 开发环境的熟悉
    20165213 Exp9 Web安全基础
  • 原文地址:https://www.cnblogs.com/7899-89/p/3613169.html
Copyright © 2020-2023  润新知