Cola Life long learning

2D函数梯度优化及代码讲解

基本公式:

$$ f(x,y) = (x^2 + y - 11)^2 + (x + y^2 - 7)^2 $$

解释:这里2D函数就是只有两个未知数的函数(个人理解),这个方程是科学家专门设计出来检测优化器效果的。

第一步:画图

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


def himmelblau(x, y):
    return (x**2 + y - 11)**2 + (x + y**2 - 7)**2   ### 2d函数方程

x = np.arange(-6, 6, 0.1) # import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def himmelblau(x, y):
    return (x**2 + y - 11)**2 + (x + y**2 - 7)**2

x = np.arange(-6, 6, 0.1) ### x轴的范围 从-6到6 每隔0.1一个点 一共120个点
y = np.arange(-6, 6, 0.1) ### y轴的范围 与上面同理
X, Y = np.meshgrid(x, y)    ### 将x,y合并,生成数组(张量)
Z = himmelblau(X, Y)      ### 运算 求出Z

fig = plt.figure('himmelblau')    #含义见下文
ax = fig.add_subplot(projection='3d')  #含义见下文
ax.plot_surface(X, Y, Z)
ax.view_init(60, -30)
ax.set_xlabel('x')
ax.set_ylabel('y')
plt.show()
y = np.arange(-6, 6, 0.1) # y轴的范围
X, Y = np.meshgrid(x, y)
Z = himmelblau(X, Y)

fig = plt.figure('himmelblau')
ax = fig.add_subplot(projection='3d')
ax.plot_surface(X, Y, Z)
ax.view_init(60, -30)
ax.set_xlabel('x')
ax.set_ylabel('y')
plt.show()

(1)figure语法说明

  • figure(num=None, figsize=None, dpi=None, facecolor=None, edgecolor=None, frameon=True)
  • num:图像编号或名称,数字为编号 ,字符串为名称
  • figsize:指定figure的宽和高,单位为英寸;
  • dpi参数指定绘图对象的分辨率,即每英寸多少个像素,缺省值为80 1英寸等于2.5cm,A4纸是 21*30cm的纸张
  • facecolor:背景颜色
  • edgecolor:边框颜色
  • frameon:是否显示边框

(2)fig.add_subplot语法说明

  • fig.add_subplot(234),“234”表示“2×3网格,第四子图”。

第二步:优化参数,反向传播进行梯度下降

通过对机器学习的概念的了解可以很准切的明白下面步骤的含义:

                  数据 —> 模型 —> 损失函数 —> 优化器 —> 迭代训练
x = torch.tensor([0., 0.], requires_grad=True) # 设定初始值(0, 0)
optimizer = torch.optim.Adam([x], lr=1e-3)   # 定义一个优化器对X进行优化,设定学习率为0.001

for step in range(20000):
    pred = himmelblau(x[0], x[1])
    optimizer.zero_grad() # 梯度信息清零
    pred.backward()
    optimizer.step() # 进行一次优化器优化,根据梯度信息更新x[0]和x[1]
    
    if step % 2000 == 0:
        print("step{}: x={}, f(x) = {}".format(step, x.tolist(), pred.item()))
2 评论
    Mr.Ghosts 2022年11月17日 Reply

    牛啊,大彪

      Cola 2022年11月19日 Reply

      @Mr.Ghosts 俺就是个菜鸟

^