• 视觉十四讲:第六讲_手写高斯牛顿法


    1.优化问题:

    (y=exp(ax^{2}+bx+c)+w),由y和x,求解a,b,c

    误差为:(e_{i}=y_{i}-exp(ax_{i}^{2}+bx_{i}+c))

    误差项对每一个待估计量进行求导:

    (frac{partial e_{i}}{partial a}=-x^{2}_{i}exp(ax^{2}_{i}+bx_{i}+c))
    (frac{partial e_{i}}{partial b}=-x_{i}exp(ax^{2}_{i}+bx_{i}+c))
    (frac{partial e_{i}}{partial c}=-exp(ax^{2}_{i}+bx_{i}+c))

    雅可比矩阵(J_{i}=[frac{partial e_{i}}{partial a},frac{partial e_{i}}{partial b},frac{partial e_{i}}{partial c}]^{T}),
    高斯的增量方程为: ((displaystyle sum^{100}_{i=1} J_{i}(sigma^{2})^{-1} J_{i}^{T})Delta x_{k}=displaystyle sum^{100}_{i=1} -J_{i}(sigma^{2})^{-1} e_{i})
    (HDelta x_{k}=b)

    噪声满足 w ~ ( 0,(sigma^{2})

    #include <iostream>
    #include <chrono>
    #include <opencv2/opencv.hpp>
    #include <Eigen/Core>
    #include <Eigen/Dense>
    
    using namespace std;
    using namespace Eigen;
    
    
    int main(int argc, char **argv) {
        double ar = 1.0, br = 2.0, cr = 1.0;    //真实参考值
        double ae = 20.0, be = -10.0, ce = 10.0; //初始值,不能太大,初始化很重要
    
        int N = 100;   //数据总点数
    
        double w_sigma = 1.0;  //噪声sigma值
        double inv_sigma = 1.0 / w_sigma;
    
        cv::RNG rng;   // opencv随机数产生
    
        vector<double> x_data, y_data;  //数据,生成真值数据加上随机数模拟实际采样值
        for(int i=0; i<N; i++){
            double x = i / 100.0;
            x_data.push_back(x);
            y_data.push_back( exp(ar*x*x + br*x + cr) + rng.gaussian(w_sigma*w_sigma) );
    
        }
    
        int iterations = 100;  //迭代次数
        double cost = 0, lastCost= 0;  //每次迭代的误差平方和,用于判断退出迭代次数
    
        chrono::steady_clock::time_point t1 = chrono::steady_clock::now();
        for ( int iter=0; iter<iterations; iter++ ){
            Matrix3d H = Matrix3d::Zero();
            Vector3d b = Vector3d::Zero();
            cost = 0;
    
            for(int i=0; i<N; i++){
                double xi = x_data[i], yi = y_data[i];
                double error = yi - exp( ae*xi*xi + be*xi + ce );  
                Vector3d J;   //雅克比矩阵
                J[0] = -xi * xi * exp(ae * xi * xi + be * xi + ce);  // de/da
                J[1] = -xi * exp(ae * xi * xi + be * xi + ce);  // de/db
                J[2] = -exp(ae * xi * xi + be * xi + ce);  // de/dc
    
                H += inv_sigma * inv_sigma * J * J.transpose();
                b += -inv_sigma * inv_sigma * error * J;
    
                cost += error * error;
            }
    
            //求解线性方程  Hx = b
            Vector3d dx = H.ldlt().solve(b);
            if (isnan(dx[0])) {
                cout << "result is nan!" << endl;
                break;
            }
    
            if (iter > 0 && cost >= lastCost) {   //误差变大,找到最小值,退出迭代
                cout << "cost: " << cost << ">= last cost: " << lastCost << ", break." << endl;
                break;
            }
    
            ae += dx[0];
            be += dx[1];
            ce += dx[2];
    
            lastCost = cost;
    
            cout << "total cost: " << cost << ", 		update: " << dx.transpose() <<
                 "		estimated params: " << ae << "," << be << "," << ce << endl;
        }
        chrono::steady_clock::time_point t2 = chrono::steady_clock::now();
        chrono::duration<double> time_used = chrono::duration_cast<chrono::duration<double>>(t2 - t1);
        cout << "solve time cost = " << time_used.count() << " seconds. " << endl;
    
        cout << "estimated abc = " << ae << ", " << be << ", " << ce << endl;
        return 0;
    }
    
    

    CMakelists.txt:

    cmake_minimum_required(VERSION 2.8)
    project(gaussnewton)
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
    find_package(OpenCV REQUIRED)
    include_directories(${OpenCV_INCLUDE_DIRS})
    include_directories("/usr/include/eigen3")
    set(SOURCE_FILES main.cpp)
    add_executable(gaussnewton ${SOURCE_FILES})
    target_link_libraries(gaussnewton ${OpenCV_LIBS})
    
  • 相关阅读:
    Comprehend-Elasticsearch-Demo5
    Mxnet使用TensorRT加速模型--Mxnet官方例子
    Mxnet模型转换ONNX,再用tensorrt执行前向运算
    MxNet模型转换Onnx
    基于Flask-APScheduler实现添加动态定时任务
    Golang习题
    算法题
    Celery使用指南
    flask拓展(数据库操作)
    flask进阶(上下文源管理源码浅析)
  • 原文地址:https://www.cnblogs.com/penuel/p/12941607.html
Copyright © 2020-2023  润新知