• java实现gbdt


    DATA类

    import java.io.File;
    import java.io.FileNotFoundException;
    import java.util.ArrayList;
    import java.util.Scanner;
    
    /**
     * Loads a comma-separated data file into memory as rows of raw string fields.
     *
     * <p>Each non-blank line of the file becomes one row; the line is trimmed and
     * split on {@code ","}. Individual fields are stored exactly as split (inner
     * whitespace around a field is NOT removed), so callers that compare field
     * values should trim them first.
     */
    public class Data {
        /** Location read by the no-argument constructor. */
        private static final String DEFAULT_DATA_PATH="D://javajavajava//dbdt//src//script//data//adult.data.csv";

        private ArrayList<ArrayList<String>> trainData=new ArrayList<ArrayList<String>>();

        /**
         * @return the rows that were loaded, in file order; empty when the file
         *         could not be opened
         */
        public ArrayList<ArrayList<String>> getTrainData() {
            return this.trainData;
        }

        /** Loads the data set from the default hard-coded location. */
        public Data() {
            this(DEFAULT_DATA_PATH);
        }

        /**
         * Loads the data set from the given file.
         *
         * <p>Loading is best effort: when the file does not exist the problem is
         * reported on standard error and the instance simply holds no rows.
         *
         * @param dataPath path of the comma-separated file to read
         */
        public Data(String dataPath) {
            // try-with-resources guarantees the underlying file handle is released.
            try (Scanner in = new Scanner(new File(dataPath))) {
                while (in.hasNextLine()) {
                    String line = in.nextLine().trim();
                    if (line.isEmpty()) {
                        // A blank line would otherwise yield a one-field row that
                        // breaks every caller indexing into the expected columns.
                        continue;
                    }
                    ArrayList<String> row = new ArrayList<>();
                    for (String field : line.split(",")) {
                        row.add(field);
                    }
                    this.trainData.add(row);
                }
            } catch (FileNotFoundException e) {
                System.err.println("Data file not found: " + dataPath + " (" + e.getMessage() + ")");
            }
        }

        public static void main(String[] args) {
            Data d = new Data();
            System.err.println(d.getTrainData().size());
        }

    }
    

      TREE类

    import java.util.ArrayList;
    import java.util.HashSet;
    import java.util.Iterator;
    import java.util.Random;
    import java.util.spi.TimeZoneNameProvider;
    
    /**
     * A binary regression tree used as the weak learner of the gradient boosting
     * model. Every instance is one node: either an internal node holding a split
     * (attribute index plus split value) and two children, or a leaf holding a
     * prediction value.
     *
     * <p>Numeric attributes split on {@code value <= threshold} (left) versus the
     * rest (right); categorical attributes split on {@code value equals category}
     * (left) versus the rest (right).
     */
    public class Tree {
        /** Shared random source for sampling numeric split thresholds. */
        private static final Random RANDOM = new Random();

        // Children are attached by constructTree. They must NOT be allocated in a
        // field initialiser: every Tree would then create two more Trees and the
        // construction would recurse until the stack overflows.
        private Tree leftTree;
        private Tree rightTree;
        private double loss=-1;
        private int attributeSplit=0;
        private String attributeSplitType="";
        boolean isLeaf;
        double leafValue;
        private ArrayList<Integer> leafNodeSet=new ArrayList<>();

        /**
         * Collects the distinct (trimmed) values found in one column.
         *
         * @param trainData all rows
         * @param idx       column index
         * @return the distinct trimmed values of that column, in no particular order
         */
        public ArrayList<String> getAttributeSet(ArrayList<ArrayList<String>> trainData,int idx)
        {
            HashSet<String> distinct = new HashSet<>();
            for (ArrayList<String> row : trainData) {
                // Trim so the values agree with the trimmed comparisons used when splitting.
                distinct.add(row.get(idx).trim());
            }
            return new ArrayList<>(distinct);
        }

        /** Distinct trimmed values of column {@code idx}, restricted to the given rows. */
        private ArrayList<String> distinctValues(ArrayList<ArrayList<String>> trainData, ArrayList<Integer> rows, int idx)
        {
            HashSet<String> distinct = new HashSet<>();
            for (Integer row : rows) {
                distinct.add(trainData.get(row).get(idx).trim());
            }
            return new ArrayList<>(distinct);
        }

        /** Picks {@code count} distinct entries at random; returns all of them when there are not more than {@code count} (or {@code count <= 0}). */
        private ArrayList<String> sampleCandidates(ArrayList<String> values, int count)
        {
            if (count <= 0 || values.size() <= count) {
                return values;
            }
            ArrayList<String> pool = new ArrayList<>(values);
            // Partial Fisher-Yates shuffle: only the first `count` slots are needed.
            for (int i = 0; i < count; i++) {
                int j = i + RANDOM.nextInt(pool.size() - i);
                String swap = pool.get(i);
                pool.set(i, pool.get(j));
                pool.set(j, swap);
            }
            return new ArrayList<>(pool.subList(0, count));
        }

        /**
         * Numeric "less than or equal" on two integer strings.
         *
         * @return {@code true} when {@code str1 <= str2} after trimming and parsing
         * @throws NumberFormatException when either string is not an integer
         */
        public boolean myCmpLess(String str1,String str2)
        {
            return Integer.parseInt(str1.trim()) <= Integer.parseInt(str2.trim());
        }

        /**
         * Impurity of a set of target values: the square root of the sum of squared
         * deviations from their mean.
         *
         * @param values target values of the samples in one node
         * @return the impurity; {@code 0} for an empty list
         */
        public double computeLoss(ArrayList<Double> values)
        {
            if (values.isEmpty()) {
                return 0;
            }
            double total = 0;
            for (double v : values) {
                total += v;
            }
            double mean = total / values.size();
            double squared = 0;
            for (double v : values) {
                squared += (v - mean) * (v - mean);
            }
            return Math.sqrt(squared);
        }

        /**
         * Leaf value for K-class boosting:
         * {@code (K-1)/K * sum(r) / sum(|r| * (1 - |r|))} over the samples of the leaf.
         *
         * @param K      number of classes
         * @param subIdx row indices of the samples in the leaf
         * @param target residuals, indexed by row index
         * @return the leaf value; {@code 0} when the denominator vanishes
         *         (empty leaf, or every residual is 0 or +/-1)
         */
        public double getPredictValue(int K, ArrayList<Integer> subIdx,ArrayList<Double> target) {
            double numerator = 0;
            double denominator = 0;
            for (Integer row : subIdx) {
                double r = target.get(row);
                double magnitude = Math.abs(r);
                numerator += r;
                denominator += magnitude * (1 - magnitude);
            }
            if (Math.abs(denominator) < 1e-12) {
                return 0;
            }
            // (K - 1.0) forces floating-point division; (K-1)/K on ints is always 0.
            return (K - 1.0) / K * numerator / denominator;
        }

        /** @return the {@code leafValue} stored on the given node */
        public double getPredictValue(Tree root)
        {
            return root.leafValue;
        }

        /**
         * Routes one instance down the tree and returns the value of the leaf reached.
         *
         * @param root     node to start from
         * @param instance feature values of the sample
         * @param isDigit  per-attribute flag: {@code true} for numeric attributes
         * @return the leaf value, or {@code 0} when a missing child is reached
         */
        public double getPredictValue(Tree root,ArrayList<String> instance,Boolean isDigit[])
        {
            if (root == null) {
                return 0;
            }
            if (root.isLeaf) {
                return root.leafValue;
            }
            String value = instance.get(root.attributeSplit).trim();
            boolean goesLeft = isDigit[root.attributeSplit]
                    ? myCmpLess(value, root.attributeSplitType)
                    : value.equals(root.attributeSplitType);
            return getPredictValue(goesLeft ? root.leftTree : root.rightTree, instance, isDigit);
        }

        /**
         * Recursively grows a tree on the rows listed in {@code subIdx}.
         *
         * <p>While {@code depth < maxDepth[0]} the best split over all attributes is
         * searched (minimising the summed impurity of both sides); numeric attributes
         * only try {@code splitPoints} randomly chosen thresholds. A node becomes a
         * leaf when the depth limit is reached, when it holds fewer than two samples,
         * or when no split separates the samples into two non-empty groups.
         *
         * @param leafNodes   output: receives the row-index list of every leaf created
         * @param leafValues  output: receives the value of every leaf, parallel to {@code leafNodes}
         * @param K           number of classes
         * @param splitPoints number of random thresholds tried per numeric attribute
         * @param isDigit     per-attribute flag: {@code true} for numeric attributes
         * @param subIdx      row indices belonging to this node
         * @param trainData   all rows (feature values only)
         * @param target      residuals, indexed by row index
         * @param maxDepth    single-element array holding the depth limit
         * @param depth       depth of the node being built (root is 0)
         * @return the node that was built
         */
        public Tree constructTree(ArrayList<ArrayList<Integer>> leafNodes,ArrayList<Double> leafValues,int K,int splitPoints, Boolean isDigit[],ArrayList<Integer> subIdx,ArrayList<ArrayList<String>> trainData,ArrayList<Double> target,int maxDepth[],int depth)
        {
            if (depth < maxDepth[0] && subIdx.size() > 1 && !trainData.isEmpty())
            {
                // Never look at more columns than there are type flags for.
                int dim = Math.min(trainData.get(0).size(), isDigit.length);
                int bestAttribute = -1;
                String bestValue = "";
                double bestLoss = -1;
                ArrayList<Integer> bestLeft = new ArrayList<>();
                ArrayList<Integer> bestRight = new ArrayList<>();

                for (int i = 0; i < dim; i++)
                {
                    ArrayList<String> candidates = distinctValues(trainData, subIdx, i);
                    if (isDigit[i]) {
                        candidates = sampleCandidates(candidates, splitPoints);
                    }
                    for (String candidate : candidates)
                    {
                        // Fresh partitions for every candidate split.
                        ArrayList<Integer> left = new ArrayList<>();
                        ArrayList<Integer> right = new ArrayList<>();
                        ArrayList<Double> leftTarget = new ArrayList<>();
                        ArrayList<Double> rightTarget = new ArrayList<>();
                        for (Integer row : subIdx)
                        {
                            String value = trainData.get(row).get(i);
                            boolean goesLeft = isDigit[i]
                                    ? myCmpLess(value, candidate)
                                    : value.trim().equals(candidate);
                            if (goesLeft) {
                                left.add(row);
                                leftTarget.add(target.get(row));
                            } else {
                                right.add(row);
                                rightTarget.add(target.get(row));
                            }
                        }
                        if (left.isEmpty() || right.isEmpty()) {
                            continue; // does not separate anything
                        }
                        double candidateLoss = computeLoss(leftTarget) + computeLoss(rightTarget);
                        if (bestAttribute < 0 || candidateLoss < bestLoss)
                        {
                            bestAttribute = i;
                            bestValue = candidate;
                            bestLoss = candidateLoss;
                            bestLeft = left;
                            bestRight = right;
                        }
                    }
                }

                if (bestAttribute >= 0)
                {
                    Tree node = new Tree();
                    node.attributeSplit = bestAttribute;
                    node.attributeSplitType = bestValue;
                    node.loss = bestLoss;
                    node.isLeaf = false;
                    node.leftTree = constructTree(leafNodes,leafValues,K,splitPoints, isDigit, bestLeft, trainData, target, maxDepth, depth+1);
                    node.rightTree = constructTree(leafNodes,leafValues,K,splitPoints, isDigit, bestRight, trainData, target, maxDepth, depth+1);
                    return node;
                }
            }

            Tree leaf = new Tree();
            leaf.isLeaf = true;
            leaf.leafValue = getPredictValue(K, subIdx, target);
            leaf.leafNodeSet.addAll(subIdx);
            leafNodes.add(subIdx);
            leafValues.add(leaf.leafValue);
            return leaf;
        }

        public static void main(String[] args) {
            Tree aTree=new Tree();
        }

    }
    

      

    GBDT类

    import java.rmi.server.SkeletonNotFoundException;
    import java.util.ArrayList;
    import java.util.HashSet;
    import java.util.Iterator;
    import java.util.Map;
    import java.util.Map.Entry;
    import java.util.Random;
    import java.util.Set;
    
    
    /**
     * Gradient boosted decision trees for a two-class problem.
     *
     * <p>For every boosting iteration a random subset of the samples is drawn, the
     * softmax residuals of that subset are computed, and one regression
     * {@code Tree} per class is fitted to them. The per-class score matrix
     * {@code F} (n rows, K columns) is then moved by {@code learningRate} times the
     * tree's output for every sample.
     *
     * <p>Labels are stored as {@code -1} (class index 0, the {@code "<=50K"} rows)
     * and {@code 1} (class index 1).
     */
    public class GBDT {

        private ArrayList<ArrayList<String>> datas=new ArrayList<ArrayList<String>>();
        private ArrayList<String> labelSets=new ArrayList<>();
        private ArrayList<ArrayList<Double>> F=new ArrayList<ArrayList<Double>>();
        private ArrayList<ArrayList<Double>> residual=new ArrayList<ArrayList<Double>>();
        private ArrayList<ArrayList<String>> trainData=new ArrayList<ArrayList<String>>();
        private ArrayList<Integer> labelTrainData=new ArrayList<Integer>();
        private int K;
        private Boolean isDigit[];
        private int dim;
        private int n;
        private double learningRate;

        private ArrayList<ArrayList<Tree>> trees=new ArrayList<ArrayList<Tree>>(); // all fitted trees: one list of K trees per iteration

        private int max_iter;
        private double sampleRate;
        private int maxDepth;
        private int splitPoints;

        /** One random source for the whole model instead of a new one per draw. */
        private final Random random = new Random();

        /**
         * Recomputes the residual {@code y_k - p_k} of every class for the given
         * samples, where {@code p} is the softmax of the sample's row in {@code F}
         * and {@code y} is the one-hot encoding of its label.
         *
         * @param subId row indices whose residuals are refreshed
         */
        public void computeResidual(ArrayList<Integer> subId)
        {
            for (int idx : subId)
            {
                ArrayList<Double> scores = this.F.get(idx);
                int trueClass = this.labelTrainData.get(idx) == -1 ? 0 : 1;

                // Subtract the largest score before exponentiating so exp() cannot overflow.
                double maxScore = Double.NEGATIVE_INFINITY;
                for (double s : scores) {
                    maxScore = Math.max(maxScore, s);
                }
                double sum = 0;
                for (double s : scores) {
                    sum += Math.exp(s - maxScore);
                }
                for (int k = 0; k < scores.size(); k++)
                {
                    double p = Math.exp(scores.get(k) - maxScore) / sum;
                    double y = (k == trueClass) ? 1.0 : 0.0;
                    this.residual.get(idx).set(k, y - p);
                }
            }
        }

        /**
         * Draws distinct random integers from {@code [0, maxNum)}.
         *
         * @param maxNum exclusive upper bound of the values
         * @param num    how many values are wanted; capped at {@code maxNum}
         * @return the drawn values; empty when either argument is not positive
         */
        public ArrayList<Integer> myrandom(int maxNum,int num)
        {
            ArrayList<Integer> picked = new ArrayList<>();
            if (maxNum <= 0 || num <= 0) {
                return picked;
            }
            // Asking for more distinct values than exist can never be satisfied.
            int count = Math.min(num, maxNum);
            int[] pool = new int[maxNum];
            for (int i = 0; i < maxNum; i++) {
                pool[i] = i;
            }
            // Partial Fisher-Yates shuffle: the first `count` slots form the sample.
            for (int i = 0; i < count; i++)
            {
                int j = i + random.nextInt(maxNum - i);
                int swap = pool[i];
                pool[i] = pool[j];
                pool[j] = swap;
                picked.add(pool[i]);
            }
            return picked;
        }

        /** Creates a model with the default hyper-parameters and loads the data set. */
        public GBDT()
        {
            this.max_iter=50;
            this.sampleRate=0.8;
            this.K=2; // two-class problem
            this.maxDepth=6;
            this.splitPoints=3;
            this.learningRate=0.01;
            getData();
        }

        /**
         * Runs {@code max_iter} boosting iterations, fitting K trees in each and
         * storing them in {@code trees}. Does nothing when no training data was loaded.
         */
        public void train()
        {
            if (this.n == 0) {
                return;
            }
            int numSubset = (int)(this.n*this.sampleRate);
            int maxdepths[] = {this.maxDepth};
            for (int iter = 0; iter < max_iter; iter++)
            {
                ArrayList<Integer> subSet = myrandom(this.n, numSubset);
                computeResidual(subSet);
                ArrayList<Tree> iterationTrees = new ArrayList<>();
                for (int j = 0; j < this.K; j++)
                {
                    // The tree looks residuals up by ROW index, so the target list must
                    // span all n rows, not just the sampled ones. Rows outside the
                    // subset are present but never read by the tree.
                    ArrayList<Double> target = new ArrayList<>(this.n);
                    for (int row = 0; row < this.n; row++) {
                        target.add(residual.get(row).get(j));
                    }
                    ArrayList<ArrayList<Integer>> leafNodes = new ArrayList<ArrayList<Integer>>();
                    ArrayList<Double> leafValues = new ArrayList<>();
                    Tree builder = new Tree();
                    Tree iterTree = builder.constructTree(leafNodes,leafValues,K,splitPoints, isDigit, subSet, trainData, target,maxdepths,0);
                    iterationTrees.add(iterTree);
                    updateFvalue(isDigit, subSet,leafNodes,leafValues,j,iterTree);
                }
                trees.add(iterationTrees);
            }
        }

        /**
         * Adds {@code learningRate * treeOutput} to column {@code label} of {@code F}
         * for every sample. Samples used to fit the tree take the value of the leaf
         * they were assigned to; all other samples are routed through the tree.
         *
         * @param isDigit    per-attribute flag: {@code true} for numeric attributes
         * @param subIdx     row indices the tree was fitted on
         * @param leafNodes  row-index list of every leaf of the tree
         * @param leafValues value of every leaf, parallel to {@code leafNodes}
         * @param label      class column of {@code F} to update
         * @param root       root of the fitted tree
         */
        public void updateFvalue(Boolean isDigit[], ArrayList<Integer> subIdx,ArrayList<ArrayList<Integer>> leafNodes,ArrayList<Double> leafValues,int label,Tree root)
        {
            boolean[] inSample = new boolean[this.n];
            for (int row : subIdx) {
                inSample[row] = true;
            }

            // In-sample rows: the leaf membership is already known from training.
            for (int leaf = 0; leaf < leafNodes.size(); leaf++)
            {
                double step = this.learningRate * leafValues.get(leaf);
                for (int row : leafNodes.get(leaf))
                {
                    ArrayList<Double> scores = this.F.get(row);
                    scores.set(label, scores.get(label) + step);
                }
            }

            // Out-of-sample rows: route each one through the tree.
            for (int row = 0; row < this.n; row++)
            {
                if (inSample[row]) {
                    continue;
                }
                double leafV = root.getPredictValue(root, this.trainData.get(row), isDigit);
                ArrayList<Double> scores = this.F.get(row);
                scores.set(label, scores.get(label) + this.learningRate * leafV);
            }
        }

        /**
         * @param str text to inspect
         * @return {@code true} when {@code str} is non-empty and consists only of the
         *         characters {@code '0'}..{@code '9'}
         */
        public boolean checkDigit(String str) {
            if (str.isEmpty()) {
                return false;
            }
            for (int i = 0; i < str.length(); i++)
            {
                char c = str.charAt(i);
                if (c < '0' || c > '9') {
                    return false;
                }
            }
            return true;
        }

        /**
         * (Re)loads the data set and rebuilds all derived state: the feature rows,
         * the labels, the attribute type flags and the zero-initialised {@code F}
         * and {@code residual} matrices. Calling it again replaces the previous
         * state instead of appending to it.
         *
         * <p>NOTE(review): row 0 supplies {@code labelSets} and the numeric/categorical
         * flags, and only rows from index 1 onward become training samples. Whether
         * row 0 is a header or a data row is not visible here - confirm against the file.
         */
        public void getData() {
            this.trainData.clear();
            this.labelTrainData.clear();
            this.labelSets.clear();
            this.F.clear();
            this.residual.clear();

            Data d = new Data();
            this.datas = d.getTrainData();
            if (this.datas.isEmpty())
            {
                // Nothing could be loaded (e.g. missing file): leave an empty model.
                this.dim = 0;
                this.isDigit = new Boolean[0];
                this.n = 0;
                return;
            }

            ArrayList<String> firstRow = this.datas.get(0);
            this.dim = Math.max(0, firstRow.size() - 1);
            this.isDigit = new Boolean[this.dim];
            for (int i = 0; i < this.dim; i++)
            {
                labelSets.add(firstRow.get(i));
                // Fields keep the blank that follows each comma, so trim before testing.
                this.isDigit[i] = checkDigit(firstRow.get(i).trim());
            }

            // Data cleaning: drop rows that are too short or contain the missing-value marker "?".
            for (int i = 1; i < this.datas.size(); i++)
            {
                ArrayList<String> row = this.datas.get(i);
                if (row.size() <= this.dim) {
                    continue;
                }
                boolean clean = true;
                for (int j = 0; j <= this.dim; j++)
                {
                    if (row.get(j).trim().equals("?"))
                    {
                        clean = false;
                        break;
                    }
                }
                if (!clean) {
                    continue;
                }
                // Keep only the feature columns; the label is stored separately.
                trainData.add(new ArrayList<>(row.subList(0, this.dim)));
                if (row.get(this.dim).trim().equals("<=50K")) {
                    labelTrainData.add(-1);
                } else {
                    labelTrainData.add(1);
                }
            }
            this.n = this.labelTrainData.size();

            // F and residual are n x K, all zeros. Each matrix needs its OWN row
            // objects; sharing rows would make a residual write overwrite the scores.
            for (int i = 0; i < this.n; i++)
            {
                ArrayList<Double> scoreRow = new ArrayList<Double>();
                ArrayList<Double> residualRow = new ArrayList<Double>();
                for (int j = 0; j < this.K; j++)
                {
                    scoreRow.add(0.0);
                    residualRow.add(0.0);
                }
                this.F.add(scoreRow);
                this.residual.add(residualRow);
            }
        }

        public static void main(String[] args) {
            // The constructor already loads the data; no second getData() call is needed.
            GBDT dGbdt=new GBDT();
            System.err.println(dGbdt.n);
        }
    }
    

      

  • 相关阅读:
    PostgreSQL事务特性之嵌套事务
    __attribute__((format(printf, a, b)))
    N个数依次入栈,出栈顺序有多少种?
    操作系统页面置换算法(opt,lru,fifo,clock)实现
    codeforces Round #320 (Div. 2) C. A Problem about Polyline(数学) D. "Or" Game(暴力,数学)
    基于X86平台的PC机通过网络发送一个int(32位)整数的字节顺序
    c/c++多线程模拟系统资源分配(并通过银行家算法避免死锁产生)
    Windows下使用Dev-C++开发基于pthread.h的多线程程序
    斐波那契的四种求法
    红黑树的插入
  • 原文地址:https://www.cnblogs.com/wuxiangli/p/6287624.html
Copyright © 2020-2023  润新知