6

最小二乘法学习

 2 years ago
source link: https://veviz.github.io/2016/12/24/Least%20squares/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

  最小二乘法(又称最小平方法)是一种数学优化技术。它通过最小化误差的平方和寻找数据的最佳函数匹配。利用最小二乘法可以简便地求得未知的数据,并使得这些求得的数据与实际数据之间误差的平方和为最小。最小二乘法还可用于曲线拟合。其他一些优化问题也可通过最小化能量或最大化熵用最小二乘法来表达。

  监督学习中,如果预测的变量是离散的,我们称其为分类(如决策树、支持向量机等),如果预测的变量是连续的,我们称其为回归。在回归分析中,如果只包括一个自变量和一个因变量,且二者的关系可用一条直线近似表示,则称为一元线性回归分析。如果回归分析中包括两个或者两个以上的自变量,且因变量和自变量之间是线性关系,则称为多远线性回归分析。
  对于二维空间线性是一条直线;对于三维空间线性是一个平面;对于多维空间线性是一个超平面。  


2. 一元线性回归

一元线性回归模型如公式1:

Y=kX+b+uY=kX+b+u

  在上面的公式中,YY 表示因变量(即响应变量,被预测变量),kk 表示斜率,bb 表示截距,XX 表示自变量(解释变量,预测变量),最后的uu 表示随机误差。
  于是,可以推出公式2:

u=Y−kX−bu=Y−kX−b

3. 参数估计-最小二乘法

  对于一元线性回归模型,假设有观测值$(x_1,y_1)…(x_n,yn),对于这n个点,需要用一条直线去拟合。综合起来看,这条直线应该处于样本数据的中心位置最合理。选择该直线的标准为:使总的拟合误差(即总残差)达到最小。  一般来说,有几种方法可以计算总残差:直接对残差进行求和、求出残差的绝对值的和以及求出残差的平方和。经分析,以残差平方和最小确定直线位置最为合理。这也是最小二乘法的原则。  由公式2可以计算出所有样本的残差平方和,见公式3:,对于这n个点,需要用一条直线去拟合。综合起来看,这条直线应该处于样本数据的中心位置最合理。选择该直线的标准为:使总的拟合误差(即总残差)达到最小。一般来说,有几种方法可以计算总残差:直接对残差进行求和、求出残差的绝对值的和以及求出残差的平方和。经分析,以残差平方和最小确定直线位置最为合理。这也是最小二乘法的原则。由公式2可以计算出所有样本的残差平方和,见公式3:$Q=\sum{i=1}^n ui^2=\sum{i=1}^n (Y_i-kXi-b)^2

只要求出使得$Q$最小,即确定$k$和$b$两个参数。以$k$和$b$为变量,把他们看作Q的函数,由于Q函数是一个凸函数,其极值点也就是其最大值和最小值。求极值点可以通过求导数获得。求$Q$的两个待估计参数的偏倒数。只要求出使得$Q$最小,即确定$k$和$b$两个参数。以$k$和$b$为变量,把他们看作Q的函数,由于Q函数是一个凸函数,其极值点也就是其最大值和最小值。求极值点可以通过求导数获得。求$Q$的两个待估计参数的偏倒数。
\frac{\partial Q}{\partial b}=2\sum

{i=1}^n (Y_i-kXi-b)(-1)=0\frac{\partial Q}{\partial k}=2\sum{i=1}^n (Y_i-kX_i-b)(-X_i)=0$$
  求出两个参数,带入公式1,即可以得到n个点的拟合直线方程。


4. 最小二乘法的C++实现

  这段代码网上有很多,我这里贴出来。

#include<iostream>
#include<fstream>
#include<vector>
using namespace std;
class LeastSquare{
    double a, b;
public:
    LeastSquare(const vector<double>& x, const vector<double>& y)
    {
        double t1=0, t2=0, t3=0, t4=0;
        for(int i=0; i<x.size(); ++i)
        {
            t1 += x[i]*x[i];
            t2 += x[i];
            t3 += x[i]*y[i];
            t4 += y[i];
        }
        a = (t3*x.size() - t2*t4) / (t1*x.size() - t2*t2);
        //b = (t4 - a*t2) / x.size();
        b = (t1*t4 - t2*t3) / (t1*x.size() - t2*t2);
    }

    double getY(const double x) const
    {
        return a*x + b;
    }

    void print() const
    {
        cout<<"y = "<<a<<"x + "<<b<<"\n";
    }

};

int main(int argc, char *argv[])
{
    if(argc != 2)
    {
        cout<<"Usage: DataFile.txt"<<endl;
        return -1;
    }
    else
    {
        vector<double> x;
        ifstream in(argv[1]);
        for(double d; in>>d; )
            x.push_back(d);
        int sz = x.size();
        vector<double> y(x.begin()+sz/2, x.end());
        x.resize(sz/2);
        LeastSquare ls(x, y);
        ls.print();

        cout<<"Input x:\n";
        double x0;
        while(cin>>x0)
        {
            cout<<"y = "<<ls.getY(x0)<<endl;
            cout<<"Input x:\n";
        }
    }
}

About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK