继续浏览精彩内容
慕课网APP
程序员的梦工厂
打开
继续
感谢您的支持,我会继续努力的
赞赏金额会直接到老师账户
将二维码发送给自己后长按识别
微信支付
支付宝支付

基于梯度下降的曲线拟合

潇潇雨雨
关注TA
已关注
手记 165
粉丝 25
获赞 128

背景

7月份的时候导师布置了个作业,他给了一条用程序生成的曲线,然后让我们用代码实现一个梯度下降算法来拟合曲线。具体要求:

data.csv文件中包含两列用逗号分隔的数据。第一列是x,第二列是y。完成如下工作:
(1)在data.csv中随机选择80%的数据作为训练集,剩余20%作为测试集。
(2)构造模型,采用梯度下降算法训练模型。
(3)用测试集对训练的模型进行评估,将测试集中的x作为输入,用模型计算y,计算预测值与实际值的RMSE。
(4)绘制data.csv中的点,绘制x ∈ [0,1] 之间模型的对应曲线。

数据格式如下:

0.000000000000000000,0.0000454019910096840.010010010010010010,0.0000674879083479180.020020020020020020,0.0000995166652482450.030030030030030030,0.0001455742214057580.040040040040040040,0.0002112477521525380.050050050050050046,0.0003041019360496450.060060060060060060,0.0004342776116289260.070070070070070073,0.0006152366314268930.080080080080080079,0.0008646872279901880.090090090090090086,0.0012057601227382130.100100100100100092,0.001668621265042236

上面的csv文件一共有1000行数据,在xy平面上绘制出来的曲线如下:


https://img4.mukewang.com/5d3088540001cf4b06290478.jpg

思路

老师的意思是先猜这条曲线是什么函数的曲线(先确定函数的基本形式),一开始函数的具体参数是不知道的,需要猜几个初始值,那么猜出来的曲线一定和实际曲线有较大差异,再用最优化的方法找到使差异最小化的函数参数,从而实现曲线的拟合。这里要求实现梯度下降算法来求解最小值。

从曲线的图像来看原始数据应该是几个均值方差不同的高斯函数叠加而成的,图中有4个峰,因此可以假设曲线的模型为:f(x)=\alpha_1e^{-\frac{(x-\mu_1)^2}{2\sigma^2_1}}+\alpha_2e^{-\frac{(x-\mu_2)^2}{2\sigma^2_2}}+\alpha_3e^{-\frac{(x-\mu_3)^2}{2\sigma^2_3}}+\alpha_4e^{-\frac{(x-\mu_4)^2}{2\sigma^2_4}}。
令误差函数为E=\sum\limits_{i=1}^{n} (f(x_i) - y_i)^2。则理想的模型参数:
(\alpha_1,\mu_1,\sigma_1,\alpha_2,...,\sigma_4)=\min\limits_{\alpha_1,...,\sigma_4}E

梯度下降算法每次求出函数(E)在某个点(当前参数)的梯度,因为梯度就是函数值增长最快的那个方向,所以让参数沿着梯度的负方向乘以一定的步长进行更新,就一定能抵达一个局部极小点。所以只要给定了这里的误差函数E(\alpha_1,\mu_1,\sigma_1,\alpha_2,\mu_2,\sigma_2,\alpha_3,\mu_3,\sigma_3,\alpha_4,\mu_4,\sigma_4),就可以通过梯度下降算法来找到使误差函数达到局部极小的12个参数。

为了便于计算,可以把\sigma^2当成一个整体,此时需要求出E在某个点的梯度的一般表示:(\frac{\partial E}{\partial \alpha_1},\frac{\partial E}{\partial \mu_1},\frac{\partial E}{\partial \sigma_1^2},\frac{\partial E}{\partial \alpha_2},\frac{\partial E}{\partial \mu_2},\frac{\partial E}{\partial \sigma_2^2},\frac{\partial E}{\partial \alpha_3},\frac{\partial E}{\partial \mu_3},\frac{\partial E}{\partial \sigma_3^2},\frac{\partial E}{\partial \alpha_4},\frac{\partial E}{\partial \mu_4},\frac{\partial E}{\partial \sigma_4^2},)。其中\frac{\partial E}{\partial \alpha_1}=2\sum\limits_{i=1}^{n}((f(x_i)-y_i)e^{-\frac{(x_i-\mu_1)^2}{2\sigma_1^2}}) \frac{\partial E}{\partial \mu_1}=2\sum\limits_{i=1}^{n}(\frac{\alpha_1(x_i-\mu_1)}{\sigma_1^2}(f(x_i)-y_i)e^{-\frac{(x_i-\mu_1)^2}{2\sigma_1^2}}) \frac{\partial E}{\partial \sigma_1^2}=2\sum\limits_{i=1}^{n}(\frac{\alpha_1(x_i-\mu_1)^2}{2\sigma_1^4}(f(x_i)-y_i)e^{-\frac{(x_i-\mu_1)^2}{2\sigma_1^2}}),其余参数的偏导数以此类推。

设定一个迭代次数,每次求出误差函数的梯度后,设定步长\eta,让参数沿梯度的负方向更新,如:\alpha_1=\alpha_1-\eta\frac{\partial E}{\partial \alpha_1},\mu_1=\mu_1-\eta\frac{\partial E}{\partial \mu_1},然后重复这个步骤,直到达到一定迭代次数或者总误差小于一定阈值停止迭代。

程序

程序使用Java实现。(C++写起来麻烦而且没有合适的图表显示库,Python太慢,Java写起来最顺手)

一开始我面临的问题就是选择一个图表显示库,简单地调研了一下选了XChart,但是去了该项目的Github主页发现居然没有打包好的 jar 包,于是需要 clone 下来然后用 mvn package 命令把 jar 包打出来。

然后我定义了一个模型类 Model,这个模型类的成员变量是 double数组,用来放待调的参数,比如上文中的f(x)对应的参数数组长度就为12。Model类有一些待实现的方法如函数的求值(val)、梯度的求值(grad)等,其派生类GaussianModel就是上文中的模型。另外,因为梯度下降会抵达最近的极小点而不是全局最小点,最终的收敛点极大依赖于参数的初始值,我每次随机选取了一部分数据点来求梯度以跳出局部极小。

Java代码如下:

package com.company;import org.knowm.xchart.QuickChart;import org.knowm.xchart.SwingWrapper;import org.knowm.xchart.XYChart;import java.io.File;import java.io.FileNotFoundException;import java.util.*;import java.util.function.Function;import java.util.stream.Collectors;import static java.lang.Math.E;import static java.lang.Math.pow;import static java.lang.Math.sqrt;import static java.lang.System.exit;public class Solver {    private List<Point> rawData = new ArrayList<>();    private List<Point> trainData = new ArrayList<>();    private List<Point> testData = new ArrayList<>();    private Model model = null;    private Function<Model, Double> loss = null;    public Solver(String csvPath) throws FileNotFoundException {
        Scanner scanner = new Scanner(new File(csvPath));        while (scanner.hasNextLine()) {
            String[] xy = scanner.nextLine().split(",");
            rawData.add(new Point(Double.valueOf(xy[0]), Double.valueOf(xy[1])));
        }
    }    private Function<Model, Double> mse = (m) -> {        double lossSum = 0.0;        for (Point p : trainData) {            double diff = m.val(p.x) - p.y;
            lossSum += (diff * diff);
        }        return lossSum / 2.0;
    };    private void divide(float ratio4Train) {
        trainData.clear();
        testData.clear();        if (ratio4Train <= 0) throw new IllegalArgumentException("Ratio <= 0");        int testCount = (int) (rawData.size() * (1 - ratio4Train));
        Random rand = new Random(System.currentTimeMillis());
        Set<Integer> exclusiveIndices4Test = new HashSet<>();        while (exclusiveIndices4Test.size() < testCount) {            int index = rand.nextInt(rawData.size());            if (! exclusiveIndices4Test.contains(index)) {
                testData.add(rawData.get(index));
                exclusiveIndices4Test.add(index);
            }
        }        for (int i = 0; i < rawData.size(); i ++) {            if (! exclusiveIndices4Test.contains(i)) {
                trainData.add(rawData.get(i));
            }
        }
    }    private void train() {
        System.out.println("Train data size: " + trainData.size());
        System.out.println("Test data size: " + testData.size());//        model = new PolyModel(4);
        model = new GaussianModel(5);
        loss = mse;        // ==========================================================
        for (int i = 0; i < 10000; i ++) {            double lossVal = loss.apply(model);            double[] gradVal = model.grad(trainData);
            System.out.println(String.format("Iter: %d, loss: %f ", i, lossVal));
            System.out.println(String.format("Theta: %f, %f, %f", model.theta[0], model.theta[1], model.theta[2]));
            System.out.println(String.format("Grad: %f, %f, %f\n", gradVal[0], gradVal[1], gradVal[2]));            if (Double.isNaN(lossVal)) {
                model.randomize(); i = 0;                continue;
            }            for (int j = 0; j < gradVal.length; j ++) {                double delta = model.rate(j) * gradVal[j];
                model.theta[j] -= delta;
            }//            if (lossVal < 1.06) break;
        }
        System.out.println(String.format("Theta: %f, %f, %f", model.theta[0], model.theta[1], model.theta[2]));
    }    private void validate() {        double RMSE = 0.0;        for (Point p : testData) {            double diff = model.val(p.x) - p.y;
            RMSE += (diff * diff);
        }
        RMSE /= testData.size();
        RMSE = sqrt(RMSE);
        System.out.println("RMSE: " + RMSE);
    }    private void plot() {
        XYChart chart = QuickChart.getChart(                "Result", "X", "Y", "y(x)",
                trainData.stream().map(point -> point.x).collect(Collectors.toList()),
                trainData.stream().map(point -> point.y).collect(Collectors.toList()));        double[] xPoints = new double[150];        double[] yPoints = new double[150];        for (int i = 0; i < 150; i ++) {
            xPoints[i] = i * 10.0 / 150;
            yPoints[i] = model.val(xPoints[i]);
        }
        chart.addSeries("model", xPoints, yPoints);        new SwingWrapper<XYChart>(chart).displayChart();
    }    public void solve() {
        divide(0.8f);
        train();
        validate();
        plot();
    }    public static void main(String[] args) throws FileNotFoundException {    // write your code here
        if (args.length < 1) {
            System.out.println("Usage: java -jar GradientDesent.jar data.csv");
            exit(0);
        }        new Solver(args[0]).solve();
    }    private static class Point {        double x;        double y;        public Point(double x, double y) {this.x = x; this.y = y;}

    }    private static abstract class Model {        double theta[] = null;        abstract double val(double x);        abstract double[] grad(List<Point> trainData);        abstract void randomize();        abstract double rate(int i);
    }    private static class PolyModel extends Model{        public PolyModel(int n) {            if (n < 2) throw new IllegalArgumentException("n MUST be larger than 2.");
            theta = new double[n];
            randomize();
        }        double val(double x) {            double result = 0.0;            for (int i = 0; i < theta.length; i ++) {
                result += theta[i] * pow(x, i);
            }            return result;
        }        @Override
        double[] grad(List<Point> trainData) {            double []gradVec = new double[theta.length];            for (int i = 0; i < gradVec.length; i ++) {
                gradVec[i] = 0.0;
                Random r = new Random();
                List<Point> data = new ArrayList<>();                for (int k = 0; k < 50; k ++)
                    data.add(trainData.get(r.nextInt(trainData.size())));                for (Point p : data) {                    double diff = val(p.x) - p.y;
                    gradVec[i] += (diff * pow(p.x, i));
                }
            }            return gradVec;
        }        @Override
        void randomize() {
            Random rand = new Random(System.currentTimeMillis());            for (int i = 0; i < theta.length; i ++) {
                theta[i] = rand.nextDouble() ;
            }
        }        @Override
        double rate(int i) {            return 0.00000002;
        }
    }    private static class GaussianModel extends Model{        /**
         * f(x) = a * e ^ (- (x - μ)^2 / σ^2)
         * (a, μ, σ2) <<----
         * @param n number of gaussian function
         */
        public GaussianModel(int n) {            if (n < 1) throw new IllegalArgumentException("n MUST be larger than 1.");
            theta = new double[n * 3];
            randomize();
        }        @Override
        double val(double x) {            double result = 0.0;            for (int i = 0; i < theta.length / 3; i ++) {                double alpha = theta[i * 3 + 0];                double miu = theta[i * 3 + 1];                double sigma2 = theta[i * 3 + 2];
                result += (alpha * pow(E, - pow((x - miu), 2) / sigma2 / 2));
            }            return result;
        }        @Override
        double[] grad(List<Point> trainData) {            double[] gradVec = new double[theta.length];            for (int i = 0; i < theta.length / 3; i ++) {
                gradVec[i * 3 + 0] = 0;
                gradVec[i * 3 + 1] = 0;
                gradVec[i * 3 + 2] = 0;                double alpha = theta[i * 3 + 0];                double miu = theta[i * 3 + 1];                double sigma2 = theta[i * 3 + 2];
                Random r = new Random();
                List<Point> stochasticData = new ArrayList<>();                for (int k = 0; k < 30; k ++)
                    stochasticData.add(trainData.get(r.nextInt(trainData.size())));                for (Point p : stochasticData) {                    double val = val(p.x);
                    gradVec[i * 3 + 0] += 2
                            * (val - p.y)
                            * (pow(E, - pow((p.x - miu), 2) / sigma2 / 2));
                    gradVec[i * 3 + 1] += (2
                            * alpha
                            * (val - p.y)
                            * pow(E, - pow((p.x - miu), 2) / sigma2 / 2)
                            * ((p.x - miu) / sigma2));
                    gradVec[i * 3 + 2] += (2
                            * alpha
                            * (val - p.y)
                            * pow(E, - pow((p.x - miu), 2) / sigma2 / 2)
                            * (pow((p.x - miu), 2) / pow(sigma2, 2) / 2)); //把sigma平方当成了一个整体
                }
            }            return gradVec;
        }        @Override
        void randomize() {
            Random rand = new Random(System.currentTimeMillis());            for (int i = 0; i < theta.length / 3; i ++) {
                theta[i * 3 + 0] = rand.nextDouble();
                theta[i * 3 + 1] = rand.nextDouble() * 5;
                theta[i * 3 + 2] = rand.nextDouble();
            }
        }        @Override
        double rate(int i) {            if (i % 3 == 0) {                return 0.0005;
            } else if (i % 3 == 1) { // miu
                return 0.0005;
            } else {                return 0.00005;
            }
        }        public String toString() {
            StringBuilder builder = new StringBuilder("Theta: ");            for (double t : theta) {
                builder.append(t);
                builder.append(", ");
            }
            builder.append("\nGrad: ");            return builder.toString();
        }
    }
}

最后的结果还是比较看人品的,并不是每次都能拟合地比较好,贴一个结果的图:


https://img3.mukewang.com/5d3088700001cccc06380474.jpg

结果

数据和代码我放到了我的Github:https://github.com/Jimmie00x0000/gradient_desent_demo



作者:JimmieZhou
链接:https://www.jianshu.com/p/7943a565d0b5


打开App,阅读手记
0人推荐
发表评论
随时随地看视频慕课网APP