手记

在浏览器中进行深度学习:TensorFlow.js (二)第一个模型,线性回归

在这一遍文章里,我们来看一看如何利用TensorFlow.js来构建数学模型,以及进行学习的基本过程。


学习的过程基本如下:


  1. 准备训练数据

  2. 构建一个模型

  3. 利用训练数据和模型,进行迭代的学习

  4. 模型训练完毕,用这个模型对新的数据进行预测(这里我们先略过对模型的验证部分)


好了,我们以最简单的线性回归为例子,看看这个过程。


准备数据



如上图所示,我在二维坐标系中生成了7个点,让它们在我假想的某条直线附近。我以这几个点作为我的训练数据。


训练数据的初始化代码如下,这里tx是所有点数据的x坐标,ty是所有点数据的坐标。


const train_x = tf.tensor1d(tx);const train_y = tf.tensor1d(ty);


模型选择


所有的模型都是错的,有的模型更好。


所谓的模型,也就是一个函数f,对应于某个输入数据,计算出某些输出数据。模型可以复杂,可以简单。简单的模型不一定不好,负责的模型也不一定好。


我们用线性模型举例,数学上就是假定 Y = wX + b


在这个模型中,有两个参数需要确定,w和b。


模型既然是个函数,那么它的代码也就很容易理解了:


const f = x => w.mul(x).add(b);


当然你也可以这样写:


const f = function(x){  return w.mul(x).add(b);
  }
}


迭代学习


学习的过程我们称作训练,训练通常是一个迭代的过程,这个过程中,通常需要这几样东西:


  • 一个损失函数(loss function),损失函数定义了模型是不是足够好,通常loss越小越好。

  • 一个优化器 (optimizer),优化器通过某种算法来决定如何改变参数的值,使得损失函数最小化。

  • 迭代循环, 通过循环 -> 调用优化器,得到新的参数,计算损失, 最终当损失足够小时,可以认为训练结束了。


训练代码如下:


初始化参数,这里使用随机数来作为参数的初始值。(注意,初始参数并不总是随机选择的。)


const w = tf.variable(tf.scalar(Math.random()));const b = tf.variable(tf.scalar(Math.random()));


初始化学习参数,


  • numIterations是迭代的次数,一般次数越多,模型的拟合就越好,但是就需要花费越多的计算

  • learningRate是学习率,这个值越大,学的速度就越快,但是也会更加容易错过极值点。


const numIterations = 200;const learningRate = 1;


选择一个优化器,这里我选择了adam。TensorFlow.js提供了多种优化器,例如sgd,momentum等等,大家可以根据自己的需要来选择。


const optimizer = tf.train.adam(learningRate);


对于损失函数,我们采用的是均方差 



const loss = (pred, label) => pred.sub(label).square().mean();


或者可以写作:


function loss(predictions, labels) {  const meanSquareError = predictions.sub(labels).square().mean();  return meanSquareError;
}


然后就是训练的过程啦:


for (let iter = 0; iter < numIterations; iter++) {
    optimizer.minimize(() => {      const loss_var = loss(f(train_x), train_y);
      loss_var.print();
      return loss_var;
    })
}


在训练过程中,我们调用tensor的print()方法打印出损失的值,看看训练过程是不是收敛。当选择的模型,参数,优化器不合适的时候,有可能训练过程并不收敛。


训练的结果我们就等到了w和b的值。也就是确定了直线的斜率和截距。


我们可以看到学习过程中是如何慢慢收敛到最后的结果的直线。


总结


本文描述了一个使用tensoflow.js来进行最简单的线性回归模型的学习的过程。希望大家可以通过这个简单的例子了解机器学习的基本思路。

作者:naughty                            

来源:https://my.oschina.net/taogang/blog/1793835


0人推荐
随时随地看视频
慕课网APP