博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
梯度下降法
阅读量:5110 次
发布时间:2019-06-13

本文共 2420 字,大约阅读时间需要 8 分钟。

有一个:

\[ z(x,y) = \sqrt {x^2 + y^2} \]
它的开口向上, 最低点为原点\(O\). \(z\)方向的俯视图如下:
sgd1.bmp

现想象你自己是一个有意识的能移动的小质点, 站在曲面上一个非原点的位置\((x, y)\)上, 如图上的小红圈所示. 你看不到整个曲面的全貌, 能看到的只有以你为中心的,半径为\(r\)(为投影在xoy平面上的值)的水平视野内的地貌. 因为曲面是光滑的, 所以地貌全部内容就是坡度. 现在你要做的是, 根据眼前能看到的地貌, 一步一步走, 步长为\(r\), 以最少的步数到达最低点. 注意, 因为你是质子, 本身很小, 所以你的水平视野半径\(r\)也会很小很小, 几乎为0.

你唯一能做的就是走一步看一步. 每一步能依据的信息就是各个方向的坡度. 用射线段\(l\)表示方向, 起点是你自己, 长度为\(r\), 与\(x\)轴的夹角为\(\alpha\). 假如你选择了\(\alpha\)方向, 则走完这一步之后, 你的xoy坐标为\((x + rcos\alpha, y + rsin\alpha)\), 高度\(z\)变为\(z(x + rcos\alpha, y + rsin\alpha)\). 你唯一知晓的坡度, 即高度变化率, 可以量化为:

\[ 坡度 = \frac {z(x + rcos\alpha, y + rsin\alpha)}{r} = \frac {\partial z}{\partial x} cos\alpha + \frac {\partial z}{\partial y} sin\alpha \]
在微积分里, 它也被称为\(z\)沿射线段\(l\)的方向导数, 用\(\frac {\partial z}{\partial l}\)表示. 当然, 你能观察到的只是它的数值, 而非表达式.
将其写成两个向量的内积形式:
\[ \frac {\partial z}{\partial l} = (\frac {\partial z}{\partial x}, \frac {\partial z}{\partial y})(cos\alpha, sin\alpha)^T = grad^T n \]
\(grad = (\frac {\partial z}{\partial x}, \frac {\partial z}{\partial y})^T\)也称为\(z\)\((x, y)\)处的梯度. \(n=(cos\alpha, sin\alpha)^T\)则是\(l\)的单位方向向量.
因为\(\Delta z \propto \frac {\partial z}{\partial l} = grad^T n\), 所以:

  • \(n\)\(grad\)同向时, \(grad^T n\)为正数最大. 若沿\(n\)方向走一步, \(z\)值最大限度的增大.
  • \(n\)\(grad\)反向时, \(grad^T n\)为负数最大. 若沿\(n\)方向走一步, \(z\)值最大限度的降低.

由于你的目的是往下走, 所以应该选择\(-grad\)方向. 每走一步, \(x与y\)的变化方式为:

\[ (x, y) \gets (x, y) - r\frac {grad^T}{||grad||} \]

嗯, 记住你现在还是个质子, 你的\(r\)很小很小. 如果你离目的地(原点)还很远的话, 要费很多很多极多的步子才能到达. 切换到实际应用中求最小值点的场景, 就意味着很长很长的计算时间. 所以往往不是将\(r\)固定为一个极小的值, 而是将\(\frac r{|grad|}\)固定为一个值: \(lr\), 称作为step size. 在机器学习里就是learning rate, 学习速率. 所以上式改为:

\[ (x, y) \gets (x, y) - lr*(\frac {\partial z}{\partial x}, \frac {\partial z}{\partial y}) \]
路径如图中红线所示:
sgd2.bmp
这种数值方法又叫做(单纯的)牛顿梯度下降法, 用于求最小值(点), 可以放心的推广到更高维空间. 不过有一个前提是目标函数是凹的, 即乘以\(-1\)后是凸的. 不然, 最后有可能会停留在局部最优而非全局最优.

用于画出路径的matlab代码:

close all;phi = pi/6;a = -pi:.05*pi:pi;r = 0: .1: 2;[A, R] = meshgrid(a, r);X = R.* cos(A);Y = R.* sin(A);Z = cot(phi) * sqrt(X.^2 + Y.^2);surf(X, Y, Z);hold on;plot3([1],[1], cot(phi)*sqrt(2), 'ro');alpha(.8);Xs = [];Ys = [];Zs = [];lr = 0.001;x = 1;y = 1;%z = cot(phi) * sqrt(x^2 + x^2);for i = 1:10^4    x = x  - lr * x / sqrt(x^2 + y^2);    y = y  - lr * y / sqrt(x^2 + y^2);    z = cot(phi) * sqrt(x^2 + x^2) %   plot3(x,y, z, 'r.');    Xs = [Xs, x];    Ys = [Ys, y];    Zs = [Zs, z];end plot3(Xs, Ys, Zs, 'r.');

转载于:https://www.cnblogs.com/dengdan890730/p/5557024.html

你可能感兴趣的文章
在VMWare上安装Arch Linux
查看>>
[arc076F]Exhausted? 贪心+堆
查看>>
MYSQL用户操作管理大杂烩
查看>>
stdafx.h有什么用
查看>>
神经网络与深度学习(1):神经元和神经网络
查看>>
Python-事件驱动模型代码
查看>>
Linux-NFS原理介绍
查看>>
maven工程仿springboot手写代码区分开发测试生产
查看>>
javascript高级编程笔记01(基本概念)
查看>>
unicode 字符范围
查看>>
确保对象在被使用前的初始化
查看>>
理解[].forEach.call()
查看>>
sdcms IIS7 windows server 2008 配置后,无法修改模板页面
查看>>
【整理】HTML5游戏开发学习笔记(4)- 记忆力游戏
查看>>
学习资料 数据结构
查看>>
python 标准库中队列相关模块介绍(子博客)
查看>>
excel中的数据整列填充
查看>>
python爬虫(1)
查看>>
20171202作业1python入门
查看>>
VC 实现文件与应用程序关联
查看>>