• 最小二乘法多项式拟合的Java实现


    背景

    由于项目需要根据已有数据学习出一个形如 y=ax+b 的一元一次多项式:给定 x、y 的一些样本数据,通过梯度下降或最小二乘法做多项式拟合求出 a、b。解决该问题时,首先想到的是用 Spark MLlib 去学习,可是效果并不理想:相关文档很少,参数也很难调整。于是换了一种思路:采用最小二乘法做多项式拟合。

    下面简要描述最小二乘法多项式拟合(以下内容参考:https://blog.csdn.net/funnyrand/article/details/46742561):

    假设给定的数据点为 $(x_1, y_1), (x_2, y_2), \dots, (x_m, y_m)$,需要求一个多项式函数 $f(x) = a_0 + a_1 x + a_2 x^2 + \dots + a_n x^n$,使所有给定 $x_k$ 处的 $f(x_k)$ 与对应 $y_k$ 之差的平方和 $\sum_{k=1}^{m} \left(f(x_k) - y_k\right)^2$ 最小,也就是求多项式的各项系数 $a_0, a_1, \dots, a_n$。

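    根据最小二乘法的原理(令平方和对每个系数 $a_i$ 的偏导数为 0),该问题可转换为求线性方程组 $G a = B$ 的解,其中 $G_{ij} = \sum_{k=1}^{m} w_k x_k^{\,i+j}$,$B_i = \sum_{k=1}^{m} w_k x_k^{\,i} y_k$,$i, j = 0, 1, \dots, n$;$w_k$ 为样本权重,不加权时取 1。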

    所以从编程的角度来说需要做两件事情:

    1)确定线性方程组的各个系数:

    确定系数比较简单,对给定的 (x1, y1), (x2, y2), ... (xm, ym) 做相应的计算即可,相关代码:

    private void compute() {
      ...
    }

    2)解线性方程组:

    解线性方程组稍微复杂,这里用到了高斯消元法,基本思想是通过递归做矩阵转换,逐渐减少求解的多项式系数的个数,相关代码:

    private double[] calcLinearEquation(double[][] a, double[] b) {
      ...
    }

    Java代码

      /**
       * Weighted least-squares polynomial fitting.
       *
       * <p>Fits {@code f(x) = c[0] + c[1] * x + ... + c[n - 1] * x^(n - 1)} to the sample points by
       * solving the normal equations {@code G c = B}, where
       * {@code G[i][j] = sum_k w_k * x_k^(i + j)} and {@code B[i] = sum_k w_k * y_k * x_k^i}.
       * Note that {@code n} is the number of coefficients (polynomial degree + 1), so {@code n = 2}
       * fits the straight line {@code y = c[0] + c[1] * x}.
       */
      public class JavaLeastSquare {
          /** Convergence threshold on the Newton step used by {@link #solve(double, double)}. */
          private static final double NEWTON_EPS = 0.0000001d;
          /** Upper bound on Newton iterations so that a non-converging search cannot loop forever. */
          private static final int MAX_NEWTON_ITERATIONS = 1000;

          private final double[] x;
          private final double[] y;
          private final double[] weight;
          /** Number of polynomial coefficients, i.e. degree + 1. */
          private final int n;
          /** Fitted coefficients, lowest power first; {@code null} if the fit could not be computed. */
          private double[] coefficient;

          /**
           * Creates an unweighted fit (every sample has weight 1).
           *
           * @param x Array of x
           * @param y Array of y
           * @param n The number of polynomial coefficients (degree + 1), at least 2
           * @throws IllegalArgumentException if an array is null, fewer than 2 samples are given,
           *     {@code x} and {@code y} differ in length, or {@code n < 2}
           */
          public JavaLeastSquare(double[] x, double[] y, int n) {
              this(x, y, x == null ? null : unitWeights(x.length), n);
          }

          /**
           * Creates a weighted fit.
           *
           * <p>If there are fewer samples than coefficients, or the normal equations are singular, no
           * fit is computed: {@link #getCoefficient()} returns {@code null} and {@link #fit(double)}
           * returns 0.
           *
           * @param x      Array of x
           * @param y      Array of y
           * @param weight Array of weight (one per sample)
           * @param n      The number of polynomial coefficients (degree + 1), at least 2
           * @throws IllegalArgumentException if an array is null, fewer than 2 samples are given, the
           *     arrays differ in length, or {@code n < 2}
           */
          public JavaLeastSquare(double[] x, double[] y, double[] weight, int n) {
              validate(x, y, weight, n);
              this.x = x;
              this.y = y;
              this.weight = weight;
              this.n = n;
              compute();
          }

          /** Returns an array of the given length filled with 1.0 (the default sample weight). */
          private static double[] unitWeights(int length) {
              double[] ones = new double[length];
              for (int i = 0; i < length; i++) {
                  ones[i] = 1;
              }
              return ones;
          }

          /** Checks the constructor arguments, throwing IllegalArgumentException with details. */
          private static void validate(double[] x, double[] y, double[] weight, int n) {
              if (x == null || y == null || weight == null) {
                  throw new IllegalArgumentException("x, y and weight must not be null");
              }
              if (x.length < 2 || x.length != y.length || x.length != weight.length || n < 2) {
                  throw new IllegalArgumentException(
                          "Expected at least 2 samples, equal array lengths and n >= 2, but got x.length="
                                  + x.length + ", y.length=" + y.length + ", weight.length="
                                  + weight.length + ", n=" + n);
              }
          }

          /**
           * Get coefficient of polynomial.
           *
           * @return a copy of the coefficients, lowest power first, or {@code null} if no fit was
           *     computed
           */
          public double[] getCoefficient() {
              return coefficient == null ? null : coefficient.clone();
          }

          /**
           * Evaluates the fitted polynomial at {@code x} (Horner's scheme).
           *
           * @param x x
           * @return the fitted y, or 0 if no fit was computed
           */
          public double fit(double x) {
              if (coefficient == null) {
                  return 0;
              }
              double sum = 0;
              for (int i = coefficient.length - 1; i >= 0; i--) {
                  sum = sum * x + coefficient[i];
              }
              return sum;
          }

          /**
           * Uses Newton's method, starting at x = 1, to find an x with {@code fit(x) == y}.
           *
           * @param y y
           * @return x, {@link Double#NaN} if the iteration does not converge, or 0 if no fit was computed
           */
          public double solve(double y) {
              return solve(y, 1.0d);
          }

          /**
           * Uses Newton's method to find an x with {@code fit(x) == y}.
           *
           * <p>Stops when a step is at most {@value #NEWTON_EPS}. Returns {@link Double#NaN} when the
           * tangent is horizontal, when the iteration produces NaN, or after
           * {@value #MAX_NEWTON_ITERATIONS} iterations without convergence (e.g. when no real root
           * exists).
           *
           * @param y      y
           * @param startX The start point of x
           * @return x, {@link Double#NaN} on failure, or 0 if no fit was computed
           */
          public double solve(double y, double startX) {
              if (coefficient == null) {
                  return 0;
              }
              double current = startX;
              for (int iteration = 0; iteration < MAX_NEWTON_ITERATIONS; iteration++) {
                  double residual = fit(current) - y;
                  if (residual == 0) {
                      // Already an exact root; avoids 0 / 0 when the derivative also vanishes.
                      return current;
                  }
                  double slope = calcDerivative(current);
                  if (slope == 0) {
                      // Horizontal tangent: the Newton step is undefined.
                      return Double.NaN;
                  }
                  double next = current - residual / slope;
                  // Written as !(a > b) so that a NaN step also ends the search (yielding NaN).
                  if (!(Math.abs(current - next) > NEWTON_EPS)) {
                      return next;
                  }
                  current = next;
              }
              return Double.NaN;
          }

          /*
           * Evaluates the derivative f'(x) of the fitted polynomial (Horner's scheme).
           * Returns 0 if no fit was computed.
           */
          private double calcDerivative(double x) {
              if (coefficient == null) {
                  return 0;
              }
              double sum = 0;
              for (int i = coefficient.length - 1; i >= 1; i--) {
                  sum = sum * x + i * coefficient[i];
              }
              return sum;
          }

          /*
           * Builds the normal equations G c = B from the weighted samples and solves them.
           * Leaves coefficient null when there are fewer samples than coefficients or the
           * system is singular.
           */
          private void compute() {
              if (x == null || y == null || x.length <= 1 || x.length != y.length
                      || x.length < n || n < 2) {
                  return;
              }
              // s[i] = sum_k w_k * x_k^i for i in [0, 2n - 2]; b[i] = sum_k w_k * y_k * x_k^i.
              double[] s = new double[2 * n - 1];
              double[] b = new double[n];
              for (int k = 0; k < x.length; k++) {
                  double term = weight[k]; // w_k * x_k^i, built up incrementally instead of Math.pow
                  for (int i = 0; i < s.length; i++) {
                      s[i] += term;
                      if (i < n) {
                          b[i] += term * y[k];
                      }
                      term *= x[k];
                  }
              }
              // G is a Hankel matrix: G[i][j] depends only on i + j.
              double[][] a = new double[n][n];
              for (int i = 0; i < n; i++) {
                  for (int j = 0; j < n; j++) {
                      a[i][j] = s[i + j];
                  }
              }
              coefficient = calcLinearEquation(a, b);
          }

          /*
           * Solves the square linear system a * result = b by recursive Gaussian elimination with
           * full pivoting: the entry of largest magnitude is used as the pivot, its column is
           * eliminated from every other row, the reduced system is solved recursively, and the
           * pivot unknown is recovered by back-substitution.
           * Returns null for malformed input or a singular system. The arguments are not modified.
           */
          private double[] calcLinearEquation(double[][] a, double[] b) {
              if (a == null || b == null || a.length == 0 || a.length != b.length) {
                  return null;
              }
              for (double[] row : a) {
                  if (row == null || row.length != a.length) {
                      return null;
                  }
              }

              int size = a.length;
              int posx = -1;
              int posy = -1;
              double largest = 0.0d;
              for (int i = 0; i < size; i++) {
                  for (int j = 0; j < size; j++) {
                      double magnitude = Math.abs(a[i][j]);
                      if (magnitude > largest) {
                          largest = magnitude;
                          posx = i;
                          posy = j;
                      }
                  }
              }
              if (posx == -1) {
                  // Every entry is zero (or NaN): the system is singular.
                  return null;
              }

              double pivot = a[posx][posy];
              double[] result = new double[size];
              if (size == 1) {
                  result[0] = b[0] / pivot;
                  return result;
              }

              // Eliminate unknown posy from every row except the pivot row.
              int len = size - 1;
              double[][] aa = new double[len][len];
              double[] bb = new double[len];
              int row = 0;
              for (int i = 0; i < size; i++) {
                  if (i == posx) {
                      continue;
                  }
                  double factor = a[i][posy] / pivot;
                  bb[row] = b[i] - factor * b[posx];
                  int col = 0;
                  for (int j = 0; j < size; j++) {
                      if (j == posy) {
                          continue;
                      }
                      aa[row][col] = a[i][j] - factor * a[posx][j];
                      col++;
                  }
                  row++;
              }

              double[] sub = calcLinearEquation(aa, bb);
              if (sub == null) {
                  return null;
              }

              // Back-substitute into the pivot row to recover unknown posy.
              double sum = b[posx];
              int col = 0;
              for (int j = 0; j < size; j++) {
                  if (j == posy) {
                      continue;
                  }
                  sum -= a[posx][j] * sub[col];
                  result[j] = sub[col];
                  col++;
              }
              result[posy] = sum / pivot;
              return result;
          }

          public static void main(String[] args) {
              JavaLeastSquare leastSquare = new JavaLeastSquare(
                      new double[]{
                              2, 14, 20, 25, 26, 34,
                              47, 87, 165, 265, 365, 465,
                              565, 665
                      },
                      new double[]{
                              0.7 * 2 + 20 + 0.4,
                              0.7 * 14 + 20 + 0.5,
                              0.7 * 20 + 20 + 3.4,
                              0.7 * 25 + 20 + 5.8,
                              0.7 * 26 + 20 + 8.27,
                              0.7 * 34 + 20 + 0.4,

                              0.7 * 47 + 20 + 0.1,
                              0.7 * 87 + 20,
                              0.7 * 165 + 20,
                              0.7 * 265 + 20,
                              0.7 * 365 + 20,
                              0.7 * 465 + 20,

                              0.7 * 565 + 20,
                              0.7 * 665 + 20
                      },
                      2);

              double[] coefficients = leastSquare.getCoefficient();
              if (coefficients != null) {
                  for (double c : coefficients) {
                      System.out.println(c);
                  }
              }

              // Evaluate the fitted line at x = 4.
              System.out.println(leastSquare.fit(4));
          }
      }

    输出结果(类名之后,前两个数依次为拟合得到的系数 a0(截距)和 a1(斜率),第三个数为 fit(4) 的值):

    com.datangmobile.biz.leastsquare.JavaLeastSquare
    22.27966881467629
    0.6952475907448203
    25.06065917765557

    Process finished with exit code 0

    使用开源库

    也可使用Apache开源库commons math(http://commons.apache.org/proper/commons-math/userguide/fitting.html),提供的功能更强大:

    <dependency>  
        <groupId>org.apache.commons</groupId>  
        <artifactId>commons-math3</artifactId>  
        <version>3.5</version>  
    </dependency>  

    实现代码:

    import org.apache.commons.math3.fitting.PolynomialCurveFitter;
    import org.apache.commons.math3.fitting.WeightedObservedPoints;
    
    public class WeightedObservedPointsTest {
        /**
         * Fits a degree-2 polynomial to noisy samples of the line y = 0.7 * x + 20 with Commons Math
         * and prints the coefficients, constant term first.
         */
        public static void main(String[] args) {
            // {x, noise} pairs; each observation is y = 0.7 * x + 20 + noise (insertion order kept).
            double[][] samples = {
                    {2, 0.4},
                    {12, 0.3},
                    {32, 3.4},
                    {34, 5.8},
                    {58, 8.4},
                    {43, 0.28},
                    {27, 0.4},
            };

            WeightedObservedPoints observations = new WeightedObservedPoints();
            for (double[] sample : samples) {
                double x = sample[0];
                observations.add(x, 0.7 * x + 20 + sample[1]);
            }

            // Degree-2 fitter: c0 + c1 * x + c2 * x^2.
            double[] coefficients = PolynomialCurveFitter.create(2).fit(observations.toList());
            for (int i = 0; i < coefficients.length; i++) {
                System.out.println(coefficients[i]);
            }
        }
    }

    测试输出结果(依次为常数项、一次项、二次项系数):

    20.47425047847121
    0.6749744063035112
    0.002523043547711147

    Process finished with exit code 0

    使用org.ujmp(矩阵)实现最小二乘法:

    pom.xml中需要引入org.ujmp

    <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
        xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
        <modelVersion>4.0.0</modelVersion>
        <groupId>com.dtgroup</groupId>
        <artifactId>dtgroup</artifactId>
        <version>0.0.1-SNAPSHOT</version>
    
        <properties>
            <!-- The sources contain Chinese string literals; do not depend on the platform encoding. -->
            <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
            <!-- The examples call Java 8 APIs (List.sort); recent JDKs also reject the compiler
                 plugin's historical default source level (1.5/1.6). -->
            <maven.compiler.source>1.8</maven.compiler.source>
            <maven.compiler.target>1.8</maven.compiler.target>
        </properties>
    
        <repositories>
            <repository>
                <id>limaven</id>
                <name>aliyun maven</name>
                <!-- HTTPS: Maven 3.8.1+ blocks plain-HTTP repositories by default. -->
                <url>https://maven.aliyun.com/repository/public</url>
                <layout>default</layout>
                <releases>
                    <enabled>true</enabled>
                </releases>
                <snapshots>
                    <enabled>false</enabled>
                </snapshots>
            </repository>
        </repositories>
        <dependencies>
            <dependency>
                <groupId>org.ujmp</groupId>
                <artifactId>ujmp-core</artifactId>
                <version>0.3.0</version>
            </dependency>
        </dependencies>
    </project>

    java代码:

        /**
         * Fits a polynomial to the samples by ordinary least squares and prints the intermediate
         * matrices and the resulting coefficient vector theta.
         *
         * <p>Solves the normal equations theta = (X^T X)^-1 X^T Y, where row i of the design matrix X
         * is [1, x_i, x_i^2, ..., x_i^(fetureCout - 1)]. theta[0] is therefore the intercept and
         * theta[1] the coefficient of x; with fetureCout == 2 this fits the straight line
         * y = theta[1] * x + theta[0].
         *
         * @param sampleCount number of samples; must equal {@code samples.size()}
         * @param fetureCout  number of polynomial coefficients (degree + 1); at least 1
         * @param samples     sample points
         * @throws IllegalArgumentException if {@code samples} is null, {@code sampleCount} differs
         *     from {@code samples.size()}, or {@code fetureCout < 1}
         */
        private static void leastsequare(int sampleCount, int fetureCout, List<Sample> samples) {
            if (samples == null || sampleCount != samples.size() || fetureCout < 1) {
                throw new IllegalArgumentException("sampleCount (" + sampleCount
                        + ") must equal samples.size() and fetureCout (" + fetureCout + ") must be >= 1");
            }

            // Design matrix X (sampleCount x fetureCout) and observation vector Y (sampleCount x 1).
            Matrix matrixX = DenseMatrix.Factory.ones(sampleCount, fetureCout);
            Matrix matrixY = DenseMatrix.Factory.ones(sampleCount, 1);
            for (int i = 0; i < sampleCount; i++) {
                Sample sample = samples.get(i);
                // Column 0 keeps the 1.0 from ones(); column j holds x^j.
                double power = 1d;
                for (int j = 1; j < fetureCout; j++) {
                    power *= sample.getX();
                    matrixX.setAsDouble(power, i, j);
                }
                matrixY.setAsDouble(sample.getY(), i, 0);
            }
            System.out.println("--------------------------------------");

            Matrix matrixXTrans = matrixX.transpose();
            // X^T X
            Matrix matrixMtimes = matrixXTrans.mtimes(matrixX);
            System.out.println(matrixMtimes);

            System.out.println("--------------------------------------");
            // (X^T X)^-1
            Matrix matrixMtimesInv = matrixMtimes.inv();
            System.out.println(matrixMtimesInv);

            System.out.println("--------------------------------------");
            // (X^T X)^-1 X^T
            Matrix matrixMtimesInvMtimes = matrixMtimesInv.mtimes(matrixXTrans);
            System.out.println(matrixMtimesInvMtimes);

            System.out.println("--------------------------------------");
            // theta = (X^T X)^-1 X^T Y
            Matrix theta = matrixMtimesInvMtimes.mtimes(matrixY);
            System.out.println(theta);
        }

    测试代码:

        public static void main(String[] args) {
            // Target model: y = a * x + b with a in (0, 1], b in [5, 20], x in [0, 500], y >= 5.
            // The samples lie around y = 0.8 * x + 15 plus a small offset; the final sample is a
            // deliberate outlier. (When several samples share the same x, average their y values.)
            double[] xs = {1d, 4d, 3d, 24d, 5d, 10d, 14d, 7d};
            double[] noise = {1, 0.8, 0.7, 0.4, 0.3, 0.4, 0.2, 0.3};

            List<Sample> points = new ArrayList<Sample>();
            for (int k = 0; k < xs.length; k++) {
                points.add(new Sample(0.8d * xs[k] + 15 + noise[k], xs[k]));
            }
            // Outlier: its y value does not follow the line at x = 70.
            points.add(new Sample(0.8d * 1000 + 23 + 0.3, 70d));

            // Straight-line fit: two coefficients (intercept and slope).
            leastsequare(points.size(), 2, points);
        }

    过滤样本中的噪点:

        public static void main(String[] args) {
            // Target model: y = a * x + b with a in (0, 1], b in [5, 20], x in [0, 500], y >= 5.
            // The samples lie around y = 0.8 * x + 15 plus a small offset; the final sample is a
            // deliberate outlier. (When several samples share the same x, average their y values.)
            double[] xs = {1d, 4d, 3d, 24d, 5d, 10d, 14d, 7d};
            double[] noise = {1, 0.8, 0.7, 0.4, 0.3, 0.4, 0.2, 0.3};

            List<Sample> points = new ArrayList<Sample>();
            for (int k = 0; k < xs.length; k++) {
                points.add(new Sample(0.8d * xs[k] + 15 + noise[k], xs[k]));
            }
            // Outlier: its y value does not follow the line at x = 70.
            points.add(new Sample(0.8d * 1000 + 23 + 0.3, 70d));

            // Repeat slope-based filtering passes, re-sorting by x before each, until a pass
            // reports that nothing else needs to change.
            sortSample(points);
            FilterSampleByGradientResult pass = filterSampleByGradient(0, points);
            while (!pass.isComplete()) {
                List<Sample> remaining = pass.getSamples();
                sortSample(remaining);
                pass = filterSampleByGradient(pass.getIndex(), remaining);
            }
            points = pass.getSamples();

            for (Sample point : points) {
                System.out.println(point);
            }

            // Straight-line fit: two coefficients (intercept and slope).
            leastsequare(points.size(), 2, points);
        }
    
        /**
         * Sorts the samples in place by ascending x.
         *
         * <p>Uses {@link Double#compare} so that samples with equal x compare as equal, as the
         * {@link Comparator} contract requires; a comparator that never returns 0 for ties can make
         * {@code List.sort} throw "Comparison method violates its general contract!". The sort is
         * stable, so samples with equal x keep their relative order.
         *
         * @param samples sample points; sorted in place
         */
        private static void sortSample(List<Sample> samples) {
            samples.sort(new Comparator<Sample>() {
                @Override
                public int compare(Sample o1, Sample o2) {
                    return Double.compare(o1.getX(), o2.getX());
                }
            });
        }
    
        /**
         * Performs one pass of slope-based noise filtering over samples sorted by ascending x.
         *
         * <p>Starting at {@code index}, each adjacent pair (i, i + 1) is examined:
         * <ul>
         *   <li>if their x values are less than 1 apart, both are replaced by a single sample at the
         *       first point's x with the mean of their y values (appended to the end of the list, so
         *       the caller must re-sort before the next pass);</li>
         *   <li>otherwise, if the slope between them exceeds 1.5, sample i + 1 is removed, except for
         *       the first pair (i == 0), which is left untouched and scanning continues.</li>
         * </ul>
         * The pass stops at the first modification and returns an incomplete result whose index is
         * the position to resume from. It returns a complete result once no pair needs a change.
         *
         * @param index   position to resume scanning from (the index of the previous pass)
         * @param samples samples sorted by ascending x; modified in place
         * @return completion flag, resume index and the (same) sample list
         */
        private static FilterSampleByGradientResult filterSampleByGradient(int index, List<Sample> samples) {
            int sampleSize = samples.size();
            for (int i = index; i < sampleSize - 1; i++) {
                Sample current = samples.get(i);
                Sample next = samples.get(i + 1);
                double deltaX = current.getX() - next.getX();
                double deltaY = current.getY() - next.getY();

                // Samples whose x values are less than 1 apart are merged into one.
                if (Math.abs(deltaX) < 1) {
                    double newY = (current.getY() + next.getY()) / 2;
                    double newX = current.getX();

                    // Remove the higher index first so that index i still refers to `current`.
                    samples.remove(i + 1);
                    samples.remove(i);
                    samples.add(new Sample(newY, newX));

                    return new FilterSampleByGradientResult(false, i, samples);
                }

                double gradient = deltaY / deltaX;
                if (gradient > 1.5) {
                    if (i == 0) {
                        // The first pair is never filtered. Keep scanning: returning an unchanged,
                        // incomplete result here would make the caller repeat this pass forever.
                        continue;
                    }
                    samples.remove(i + 1);
                    return new FilterSampleByGradientResult(false, i, samples);
                }
            }

            return new FilterSampleByGradientResult(true, 0, samples);
        }

     使用距离来处理过滤:

        /**
         * Removes samples whose Euclidean distance to the centroid of all samples is unusual.
         *
         * <p>With u the mean distance and s = sqrt(sum((d - u)^2)), which is sqrt(n) times the
         * population standard deviation rather than the standard deviation itself, a sample is
         * removed when its distance lies outside the open interval (u - 0.5 * s, u + 0.5 * s).
         * Samples unusually close to the centroid are therefore removed as well as distant ones.
         * The index of every rejected sample is printed (from last to first) before removal.
         *
         * @param samples samples to filter; modified in place
         * @return the same list, with the rejected samples removed
         */
        private static List<Sample> filterSample(List<Sample> samples) {
            final int count = samples.size();

            // Centroid of all samples.
            double totalX = 0d;
            double totalY = 0d;
            for (int k = 0; k < count; k++) {
                totalX += samples.get(k).getX();
                totalY += samples.get(k).getY();
            }
            double centerX = totalX / count;
            double centerY = totalY / count;

            // Distance of every sample to the centroid.
            double[] distances = new double[count];
            for (int k = 0; k < count; k++) {
                Sample point = samples.get(k);
                distances[k] = Math.sqrt(Math.pow(point.getX() - centerX, 2) + Math.pow(point.getY() - centerY, 2));
            }

            // Mean distance and spread of the distances.
            double distanceTotal = 0d;
            for (double d : distances) {
                distanceTotal += d;
            }
            double meanDistance = distanceTotal / count;

            double squaredDeviationTotal = 0d;
            for (double d : distances) {
                squaredDeviationTotal += Math.pow((d - meanDistance), 2);
            }
            double spread = Math.sqrt(squaredDeviationTotal);

            double lower = meanDistance - 0.5 * spread;
            double upper = meanDistance + 0.5 * spread;

            // Collect rejected indices from last to first so that removing them in that order keeps
            // the indices still to be removed valid.
            int[] rejected = new int[count];
            int rejectedCount = 0;
            for (int k = count - 1; k >= 0; k--) {
                double d = distances[k];
                if (d <= lower || d >= upper) {
                    rejected[rejectedCount++] = k;
                    System.out.println("will be remove " + k);
                }
            }

            for (int r = 0; r < rejectedCount; r++) {
                samples.remove(rejected[r]);
            }

            return samples;
        }

    实际业务测试:

    package com.zjanalyse.spark.maths;
    
    import java.util.ArrayList;
    import java.util.Comparator;
    import java.util.List;
    
    import org.ujmp.core.DenseMatrix;
    import org.ujmp.core.Matrix;
    
    public class LastSquare {
        /**
         * Demo of a two-pass straight-line fit {@code y = path_loss * x + wear_loss}.
         *
         * <p>Business constraints on the model: y = a * x + b with a in (0, 1], b in [5, 20],
         * x in [0, 500] and y >= 5. The demo sorts the samples by x, drops samples outside the
         * business value range, fits a line, drops samples whose distance to that line is unusual,
         * and fits again on the remaining samples.
         */
        public static void main(String[] args) {
            // Samples around y = 0.8 * x + 15; the last three use a larger intercept and act as noise.
            // (When several samples share the same x, their y values should be averaged.)
            List<Sample> samples = new ArrayList<Sample>();
            samples.add(new Sample(0.8d * 11 + 15 + 1, 11d));
            samples.add(new Sample(0.8d * 24 + 15 + 0.8, 24d));
            samples.add(new Sample(0.8d * 33 + 15 + 0.7, 33d));
            samples.add(new Sample(0.8d * 24 + 15 + 0.4, 24d));
            samples.add(new Sample(0.8d * 47 + 15 + 0.3, 47d));
            samples.add(new Sample(0.8d * 60 + 15 + 0.4, 60d));
            samples.add(new Sample(0.8d * 14 + 15 + 0.2, 14d));
            samples.add(new Sample(0.8d * 57 + 15 + 0.3, 57d));
            samples.add(new Sample(0.8d * 70 + 60 + 0.3, 70d));
            samples.add(new Sample(0.8d * 80 + 60 + 0.3, 80d));
            samples.add(new Sample(0.8d * 40 + 30 + 0.3, 40d));

            sortSample(samples);
            System.out.println("原始样本数据");
            for (Sample sample : samples) {
                System.out.println(sample);
            }

            System.out.println("开始“所有点”通过“业务数据取值范围”剔除:");
            // Drop samples outside the business value range.
            filterByBusiness(samples);
            System.out.println("结束“所有点”通过“业务数据取值范围”剔除:");

            for (Sample sample : samples) {
                System.out.println(sample);
            }

            int sampleCount = samples.size();
            int fetureCout = 2;
            System.out.println("第一次拟合。。。");
            if (sampleCount < fetureCout) {
                // X^T X would be singular; there is nothing meaningful to fit.
                System.out.println("not enough samples to fit: " + sampleCount);
                return;
            }
            Matrix theta = leastsequare(sampleCount, fetureCout, samples);

            double wear_loss = theta.getAsDouble(0, 0);
            double path_loss = theta.getAsDouble(1, 0);

            System.out.println("wear loss " + wear_loss);
            System.out.println("path loss " + path_loss);

            System.out.println("开始“所有点”与“第一多项式拟合结果直线方式距离方差”剔除:");
            samples = filterSample(wear_loss, path_loss, samples);
            System.out.println("结束“所有点”与“第一多项式拟合结果直线方式距离方差”剔除:");

            for (Sample sample : samples) {
                System.out.println(sample);
            }

            System.out.println("第二次拟合。。。");
            sampleCount = samples.size();
            fetureCout = 2;

            if (sampleCount >= fetureCout) {
                theta = leastsequare(sampleCount, fetureCout, samples);

                wear_loss = theta.getAsDouble(0, 0);
                path_loss = theta.getAsDouble(1, 0);

                System.out.println("wear loss " + wear_loss);
                System.out.println("path loss " + path_loss);
            }
            System.out.println("complete...");
        }

        /**
         * Removes (in place) samples outside the business value range: x must lie in [0, 500) and
         * y in [5, x + 30]. Every removed sample is reported on standard output.
         */
        private static void filterByBusiness(List<Sample> samples) {
            for (int i = 0; i < samples.size(); i++) {
                double x = samples.get(i).getX();
                double y = samples.get(i).getY();
                if (x < 0 || x >= 500) {
                    System.out.println(x + " x值超出有效值范围[0,500)");
                    samples.remove(i);
                    i--; // the next sample has shifted into position i
                } else {
                    // Valid y range for this x, around the expected line y = 0.8 * x + 15.
                    double minY = 0 * x + 5;
                    double maxY = 1 * x + 30;
                    if (y < minY || y > maxY) {
                        System.out.println(
                                y + " y值超出有效值范围[(0*x+5),(1*x+30)]其中x=" + x + ",也就是:[" + minY + "," + maxY + ")");
                        samples.remove(i);
                        i--; // the next sample has shifted into position i
                    }
                }
            }
        }

        /**
         * Distance from the point (x1, y1) to the line A * x + B * y + C = 0.
         *
         * <p>Example: the distance from the point (0, 1) to the line y = x (that is, -x + y = 0) is
         * {@code getDistanceOfPerpendicular(0, 1, -1, 1, 0)}, approximately 0.7071.
         * If A and B are both 0 the line is undefined and the result is NaN or infinite.
         *
         * @param x1 point x coordinate
         * @param y1 point y coordinate
         * @param A  coefficient A of the line's general form
         * @param B  coefficient B of the line's general form
         * @param C  coefficient C of the line's general form
         * @return the perpendicular distance from the point to the line
         */
        private static double getDistanceOfPerpendicular(double x1, double y1, double A, double B, double C) {
            double distance = Math.abs((A * x1 + B * y1 + C) / Math.sqrt(A * A + B * B));
            return distance;
        }

        /**
         * Removes samples whose perpendicular distance to the fitted line
         * {@code y = path_loss * x + wear_loss} is unusual.
         *
         * <p>With u the mean distance and s = sqrt(sum((d - u)^2)), which is sqrt(n) times the
         * population standard deviation rather than the standard deviation itself, a sample is
         * removed when its distance lies outside the open interval (u - 0.25 * s, u + 0.25 * s).
         * Samples unusually close to the line are therefore removed as well as distant ones.
         *
         * @param wear_loss intercept of the fitted line
         * @param path_loss slope of the fitted line
         * @param samples   samples to filter; modified in place
         * @return the same list, with the rejected samples removed
         */
        private static List<Sample> filterSample(double wear_loss, double path_loss, List<Sample> samples) {
            int sampleCount = samples.size();

            // The line y = path_loss * x + wear_loss in general form: path_loss * x - y + wear_loss = 0.
            List<Double> distanItems = new ArrayList<Double>();
            for (int i = 0; i < sampleCount; i++) {
                Sample sample = samples.get(i);
                distanItems.add(getDistanceOfPerpendicular(sample.getX(), sample.getY(), path_loss, -1, wear_loss));
            }

            // Mean distance.
            double sumDistan = 0d;
            for (Double distance : distanItems) {
                sumDistan += distance;
            }
            double distanceU = sumDistan / sampleCount;

            // Spread: sqrt of the sum of squared deviations from the mean.
            double deltaPowSum = 0d;
            for (Double distance : distanItems) {
                deltaPowSum += Math.pow((distance - distanceU), 2);
            }
            double distanceTheta = Math.sqrt(deltaPowSum);

            double minDistance = distanceU - 0.25 * distanceTheta;
            double maxDistance = distanceU + 0.25 * distanceTheta;

            // Collect rejected indices from last to first so that removing them in that order keeps
            // the indices still to be removed valid.
            List<Integer> willbeRemoveIdxs = new ArrayList<Integer>();
            for (int i = distanItems.size() - 1; i >= 0; i--) {
                Double distance = distanItems.get(i);
                if (distance <= minDistance || distance >= maxDistance) {
                    System.out.println(distance + " out of range [" + minDistance + "," + maxDistance + "]");
                    willbeRemoveIdxs.add(i);
                } else {
                    System.out.println(distance);
                }
            }

            for (int willbeRemoveIdx : willbeRemoveIdxs) {
                Sample sample = samples.get(willbeRemoveIdx);
                System.out.println("remove " + sample);
                samples.remove(willbeRemoveIdx);
            }

            return samples;
        }

        /**
         * Sorts the samples in place by ascending x.
         *
         * <p>Uses {@link Double#compare} so that samples with equal x compare as equal, as the
         * {@link Comparator} contract requires; a comparator that never returns 0 for ties can make
         * {@code List.sort} throw "Comparison method violates its general contract!".
         *
         * @param samples sample points; sorted in place
         */
        private static void sortSample(List<Sample> samples) {
            samples.sort(new Comparator<Sample>() {
                @Override
                public int compare(Sample o1, Sample o2) {
                    return Double.compare(o1.getX(), o2.getX());
                }
            });
        }

        /**
         * Fits a polynomial to the samples by ordinary least squares.
         *
         * <p>Solves the normal equations theta = (X^T X)^-1 X^T Y, where row i of the design matrix X
         * is [1, x_i, x_i^2, ..., x_i^(fetureCout - 1)]. theta is a fetureCout x 1 column vector:
         * theta(0, 0) is the intercept and theta(1, 0) the coefficient of x.
         *
         * @param sampleCount number of samples; must equal {@code samples.size()}
         * @param fetureCout  number of polynomial coefficients (degree + 1); at least 1
         * @param samples     sample points
         * @return the coefficient column vector theta
         * @throws IllegalArgumentException if {@code samples} is null, {@code sampleCount} differs
         *     from {@code samples.size()}, or {@code fetureCout < 1}
         */
        private static Matrix leastsequare(int sampleCount, int fetureCout, List<Sample> samples) {
            if (samples == null || sampleCount != samples.size() || fetureCout < 1) {
                throw new IllegalArgumentException("sampleCount (" + sampleCount
                        + ") must equal samples.size() and fetureCout (" + fetureCout + ") must be >= 1");
            }

            // Design matrix X (sampleCount x fetureCout) and observation vector Y (sampleCount x 1).
            Matrix matrixX = DenseMatrix.Factory.ones(sampleCount, fetureCout);
            Matrix matrixY = DenseMatrix.Factory.ones(sampleCount, 1);
            for (int i = 0; i < sampleCount; i++) {
                Sample sample = samples.get(i);
                // Column 0 keeps the 1.0 from ones(); column j holds x^j.
                double power = 1d;
                for (int j = 1; j < fetureCout; j++) {
                    power *= sample.getX();
                    matrixX.setAsDouble(power, i, j);
                }
                matrixY.setAsDouble(sample.getY(), i, 0);
            }

            Matrix matrixXTrans = matrixX.transpose();
            // X^T X
            Matrix matrixMtimes = matrixXTrans.mtimes(matrixX);
            // (X^T X)^-1
            Matrix matrixMtimesInv = matrixMtimes.inv();
            // (X^T X)^-1 X^T
            Matrix matrixMtimesInvMtimes = matrixMtimesInv.mtimes(matrixXTrans);
            // theta = (X^T X)^-1 X^T Y
            return matrixMtimesInvMtimes.mtimes(matrixY);
        }
    }
    View Code
  • 相关阅读:
    win10系统下office 2019激活
    如何根据【抖音分享链接】去掉抖音水印
    Java多线程学习之ThreadLocal源码分析
    Java多线程学习之synchronized总结
    Java多线程学习之线程的取消与中断机制
    Java多线程学习之Lock与ReentranLock详解
    Java多线程学习之线程池源码详解
    MyBatis 一、二级缓存和自定义缓存
    Spring 高级依赖注入方式
    Spring 依赖注入的方式
  • 原文地址:https://www.cnblogs.com/yy3b2007com/p/8711946.html
Copyright © 2020-2023  润新知